From d91269e8ce309437c1f849b5ab3362d69b178ef4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Nov 2025 17:20:39 +0000 Subject: [PATCH 001/230] Revert "[ROCm] enable fastSpecializedAtomicAdd for gfx950 (#167661)" This reverts commit 1b43d6cd4e01b63f6bcf5238fdca5dc41e9121ae. Reverted https://github.com/pytorch/pytorch/pull/167661 on behalf of https://github.com/yangw-dev due to break internal tests and build, please reach out meta fellas to have fix it and reland again, error examplke: hip/KernelUtils.cuh:74:5: error: no matching function for call to 'unsafeAtomicAdd' ([comment](https://github.com/pytorch/pytorch/pull/167661#issuecomment-3548737051)) --- aten/src/ATen/native/cuda/KernelUtils.cuh | 60 ++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index fd406829707a1..5c8b98105bb26 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -5,11 +5,69 @@ #include #endif +// ROCm 6.3 is planned to have these functions, but until then here they are. #if defined(USE_ROCM) #include #include #include -#define ATOMICADD unsafeAtomicAdd + +__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); + union { + __hip_bfloat162_raw bf162_raw; + vec_short2 vs2; + } u{static_cast<__hip_bfloat162_raw>(value)}; + u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); + return static_cast<__hip_bfloat162>(u.bf162_raw); +#else + static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); + union u_hold { + __hip_bfloat162_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} + +__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) + // The api expects an ext_vector_type of half + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); + union { + __half2_raw h2r; + vec_fp162 fp16; + } u {static_cast<__half2_raw>(value)}; + u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); + return static_cast<__half2>(u.h2r); +#else + static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); + union u_hold { + __half2_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} +#define ATOMICADD preview_unsafeAtomicAdd #define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) #else #define ATOMICADD atomicAdd From 57927a620d6033b5812d130dbe2452648dcf38b0 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Tue, 18 Nov 2025 17:56:50 +0000 Subject: [PATCH 002/230] [Profiler] Deprecate export_memory_timeline method (#168036) Summary: The export_memory_timeline method in torch.profiler is being deprecated in favor of the newer memory snapshot API (torch.cuda.memory._record_memory_history and torch.cuda.memory._export_memory_snapshot). This change adds the deprecated decorator from typing_extensions and updates the docstring to guide users to the recommended alternative. The decorator will emit a FutureWarning at runtime, and the docstring now includes a .. deprecated:: directive for documentation visibility. Test Plan: Manual verification that the decorator is properly applied and the deprecation message is informative. Differential Revision: D87272399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168036 Approved by: https://github.com/valentinandrei --- torch/profiler/profiler.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f3400e438a2d3..056a5fcc21fdd 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -9,7 +9,7 @@ from enum import Enum from functools import partial from typing import Any, Optional -from typing_extensions import Self +from typing_extensions import deprecated, Self from warnings import warn import torch @@ -408,6 +408,11 @@ def _memory_profile(self) -> MemoryProfile: ) return MemoryProfile(self.profiler.kineto_results) + @deprecated( + "`export_memory_timeline` is deprecated and will be removed in a future version. " + "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.", + category=FutureWarning, + ) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: """Export memory event information from the profiler collected tree for a given device, and export a timeline plot. There are 3 @@ -429,6 +434,11 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non ``torch.profiler._memory_profiler.Category``. Output: Memory timeline written as gzipped JSON, JSON, or HTML. + + .. deprecated:: + ``export_memory_timeline`` is deprecated and will be removed in a future version. + Please use ``torch.cuda.memory._record_memory_history`` and + ``torch.cuda.memory._export_memory_snapshot`` instead. """ # Default to device 0, if unset. Fallback on cpu. if device is None: From 20cae808f7ee9b96578c7daa5669a408d8a9f136 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:57:30 +0000 Subject: [PATCH 003/230] `ComplexTensor` subclass (#167621) This PR introduces a `Tensor` subclass which represents a complex tensor in terms of two real ones. Ops are decomposed as individual ops on the real and imaginary parts. It is compatible with `torch.compile`, so long as the real ops used are also compatible. Autograd "works", but is WIP due to different edge-case behaviour. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167621 Approved by: https://github.com/ezyang --- test/complex_tensor/test_complex_tensor.py | 238 +++++ test/complex_tensor/utils.py | 214 ++++ torch/_subclasses/complex_tensor/__init__.py | 9 + torch/_subclasses/complex_tensor/_core.py | 151 +++ .../complex_tensor/_ops/__init__.py | 5 + torch/_subclasses/complex_tensor/_ops/aten.py | 921 ++++++++++++++++++ .../_subclasses/complex_tensor/_ops/common.py | 317 ++++++ .../_subclasses/complex_tensor/_ops/prims.py | 34 + 8 files changed, 1889 insertions(+) create mode 100644 test/complex_tensor/test_complex_tensor.py create mode 100644 test/complex_tensor/utils.py create mode 100644 torch/_subclasses/complex_tensor/__init__.py create mode 100644 torch/_subclasses/complex_tensor/_core.py create mode 100644 torch/_subclasses/complex_tensor/_ops/__init__.py create mode 100644 torch/_subclasses/complex_tensor/_ops/aten.py create mode 100644 torch/_subclasses/complex_tensor/_ops/common.py create mode 100644 torch/_subclasses/complex_tensor/_ops/prims.py diff --git a/test/complex_tensor/test_complex_tensor.py b/test/complex_tensor/test_complex_tensor.py new file mode 100644 index 0000000000000..dbb14d93f972a --- /dev/null +++ b/test/complex_tensor/test_complex_tensor.py @@ -0,0 +1,238 @@ +# Owner(s): ["module: complex"] +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist + + +# Support both when imported from elsewhere or directly as a file +try: + from .utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) +except ImportError: + from utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) + +from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + OpDTypes, + ops, +) +from torch.testing._internal.common_utils import ( + run_tests, + TestGradients, + unMarkDynamoStrictTest, +) + + +if TYPE_CHECKING: + from torch.testing._internal.opinfo.core import OpInfo + +aten = torch.ops.aten + +SKIPS = { + Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.reciprocal, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.true_divide, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.any, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.all, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.allclose, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten._conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.cumprod, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.index_add, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.diagonal_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.flip, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_fill, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.rsub, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.ne, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.squeeze, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.index_select, variant=Variant.Distributed + ): "Sharding propagation failed", + Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support", +} + +EXTRA_KWARGS = { + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-6, + "atol": 1e-2, + }, +} + + +class TestComplexTensor(TestCase): + _default_dtype_check_enabled = True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported, + allowed_dtypes=list(COMPLEX_DTYPES), + ) + def test_consistency(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_maybe_error(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + +@unMarkDynamoStrictTest +class TestComplexBwdGradients(TestGradients): + _default_dtype_check_enabled = True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported_backward, + allowed_dtypes=[torch.complex128], + ) + def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None: + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=Variant.GradCheck, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!") + + with ComplexTensorMode(): + op.gradcheck_fast_mode = False + self._grad_test_helper(device, dtype, op, op.get_op()) + + +instantiate_device_type_tests(TestComplexTensor, globals()) +instantiate_device_type_tests(TestComplexBwdGradients, globals()) + + +if dist.is_available(): + from torch.testing._internal.common_distributed import MultiProcessTestCase + + @unMarkDynamoStrictTest + class TestComplexDistributed(TestCase, MultiProcessTestCase): + @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_distributed(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Distributed) + + instantiate_device_type_tests(TestComplexDistributed, globals()) + +if __name__ == "__main__": + run_tests() diff --git a/test/complex_tensor/utils.py b/test/complex_tensor/utils.py new file mode 100644 index 0000000000000..d2a1e1d312264 --- /dev/null +++ b/test/complex_tensor/utils.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from enum import auto, Enum +from typing import Any, TYPE_CHECKING + +import torch +import torch.distributed as dist +from torch._subclasses.complex_tensor._ops.common import ( + _as_complex_tensor, + _as_interleaved, + _get_op_name, + COMPLEX_OPS_TABLE, + COMPLEX_TO_REAL, + FORCE_TEST_LIST, + OpOverloadPacket, +) +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import TestCase as PytorchTestCase +from torch.utils._pytree import tree_flatten + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.distributed.tensor import DTensor + from torch.testing._internal.opinfo.core import OpInfo + +COMPLEX_DTYPES = set(COMPLEX_TO_REAL) + + +class Variant(Enum): + Op = auto() + GradCheck = auto() + Distributed = auto() + + +def _as_local(arg: DTensor | Any) -> torch.Tensor | Any: + if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)): + return arg + + return arg.full_tensor() + + +def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any: + if not isinstance(arg, torch.Tensor): + return arg + + return dist.tensor.DTensor.from_local(_as_complex_tensor(arg)) + + +TRANSFORM_FUNCS = { + Variant.Op: _as_complex_tensor, + Variant.Distributed: _as_complex_dtensor, +} + + +@dataclass(frozen=True, kw_only=True) +class Descriptor: + op: OpOverloadPacket + variant: Variant | None + device_type: str | None = field(default=None) + dtype: torch.dtype | None = field(default=None) + + def matches(self, other: Descriptor) -> bool: + fields1 = fields(self) + fields2 = fields(other) + if fields1 != fields2: + return False + + for f in fields1: + f1 = getattr(self, f.name) + f2 = getattr(other, f.name) + if f1 is not None and f2 is not None and f1 != f2: + return False + + return True + + +class TestCase(PytorchTestCase): + def assertSameResult( + self, + expected: Callable[[], Any], + actual: Callable[[], Any], + *args, + **kwargs, + ) -> None: + try: + result_e = expected() + exception_e = None + except Exception as e: # noqa: BLE001 + result_e = None + exception_e = e + + try: + result_a = actual() + exception_a = None + except Exception as e: # noqa: BLE001 + result_a = None + exception_a = e + + if (exception_e is None) != (exception_a is None): + if exception_a is not None and exception_e is None: + raise exception_a + self.assertIs( + type(exception_e), + type(exception_a), + f"\n{exception_e=}\n{exception_a=}", + ) + + if exception_e is None: + flattened_e, spec_e = tree_flatten(result_e) + flattened_a, spec_a = tree_flatten(result_a) + + self.assertEqual( + spec_e, + spec_a, + "Both functions must return a result with the same tree structure.", + ) + for value_e, value_a in zip(flattened_e, flattened_a, strict=True): + value_e = _as_interleaved(_as_local(value_e)) + value_a = _as_interleaved(_as_local(value_a)) + + self.assertEqual(value_e, value_a, *args, **kwargs) + + def check_consistency( + self, device: str, dtype, op: OpInfo, variant: Variant + ) -> None: + try: + from .test_complex_tensor import EXTRA_KWARGS, SKIPS + except ImportError: + from test_complex_tensor import EXTRA_KWARGS, SKIPS + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=variant, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + sample_inputs = op.sample_inputs(device, dtype) + transform_fn = TRANSFORM_FUNCS[variant] + + for sample_input in sample_inputs: + + def expected(sample_input=sample_input): + return op(sample_input.input, *sample_input.args, **sample_input.kwargs) + + subclass_sample = sample_input.transform(transform_fn) + + def actual(subclass_sample=subclass_sample): + return op( + subclass_sample.input, + *subclass_sample.args, + **subclass_sample.kwargs, + ) + + self.assertSameResult(expected, actual, **kwargs) + + +aten = torch.ops.aten + +complex_op_db = tuple( + filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db) +) + + +def get_overload_packet_from_name(name: str) -> OpOverloadPacket: + for domain_name in torch.ops: + op_namespace = getattr(torch.ops, domain_name) + op: OpOverloadPacket | None = getattr(op_namespace, name, None) + if op is not None: + return op + + raise RuntimeError(f"No op with {name=} found.") + + +force_test_names = set(map(_get_op_name, FORCE_TEST_LIST)) +implemented_op_names = ( + set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names +) +implemented_op_db = tuple( + filter(lambda op: op.name in implemented_op_names, complex_op_db) +) +force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db)) + +tested_op_names = {op.name for op in implemented_op_db} | { + op.name for op in force_test_op_db +} +non_tested_ops = { + op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names +} + + +# TODO (hameerabbasi): There are a number of ops that don't have any associated +# OpInfos. We still need to write tests for those ops. +if len(non_tested_ops) != 0: + import textwrap + import warnings + + list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops])) + warnings.warn( + "Not all implemented ops are tested. List of ops missing tests:" + f"\n{textwrap.indent(list_missing_ops, ' ')}", + UserWarning, + stacklevel=2, + ) diff --git a/torch/_subclasses/complex_tensor/__init__.py b/torch/_subclasses/complex_tensor/__init__.py new file mode 100644 index 0000000000000..1ab4a816261dc --- /dev/null +++ b/torch/_subclasses/complex_tensor/__init__.py @@ -0,0 +1,9 @@ +from ._core import ComplexTensor +from ._ops import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"] + +ComplexTensor.__module__ = __name__ +ComplexTensorMode.__module__ = __name__ +is_complex_tensor.__module__ = __name__ diff --git a/torch/_subclasses/complex_tensor/_core.py b/torch/_subclasses/complex_tensor/_core.py new file mode 100644 index 0000000000000..edd7568b2ef06 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_core.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +from torch import Tensor +from torch.autograd import Function + + +if TYPE_CHECKING: + from torch._ops import OpOverload + from torch._prims_common import DeviceLikeType + from torch.autograd.function import FunctionCtx + + +class ComplexTensor(Tensor): + """A class that decomposes all ops on complex Tensors into their real and imaginary parts.""" + + _re: Tensor + _im: Tensor + + def __new__(cls, real: Tensor, imag: Tensor) -> Self: + """Initialize a ComplexTensor from its real and imaginary parts.""" + from ._ops.common import REAL_TO_COMPLEX + + shape = real.shape + device = real.device + + # TODO (hameerabbasi): `torch.compile` sometimes fails here without making these + # contiguous. Why? + real = real.contiguous() + imag = imag.contiguous() + + # TODO (hameerabbasi): + # What should we do with dtype? + # We could convert to the complex type (float32 -> complex64), but we + # can't use that model for say `bfloat16` which does not have a + # corresponding complex dtype. + # If we want to support this complex rep using any float type (see + # https://github.com/pytorch/pytorch/issues/95100) + # We either need to: + # 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere + # else. + # 2) We use the real float dtype here, and it is up to the user to know + # that dtype=float here really means complex<2xSize> with dtype + # matching that of re/im parts alone + # I'm going with 1 for now, so that I can make gradcheck and some complex + # ops work properly, but might want to discuss this in the RFP. + dtype = REAL_TO_COMPLEX.get(real.dtype) + if dtype is None: + raise TypeError( + "Unsupported dtype for constituent tensors. Supported dtypes are: " + f"{set(REAL_TO_COMPLEX.keys())!r}." + ) + storage_offset = real.storage_offset() + strides = real.stride() + layout = real.layout + pin_memory = real.is_pinned() + + assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}" + assert device == imag.device, ( + f"Expected imag device {device}, got {imag.device}" + ) + assert real.dtype == imag.dtype, ( + f"Expected imag dtype {real.dtype}, got {imag.dtype}" + ) + assert pin_memory == imag.is_pinned(), ( + f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}" + ) + + res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=device, + dtype=dtype, + storage_offset=storage_offset, + strides=strides, + pin_memory=pin_memory, + layout=layout, + requires_grad=False, + ) + res._re = real.clone().detach() + res._im = imag.clone().detach() + + return res + + @property + def re(self) -> Tensor: + return self._re + + @property + def im(self) -> Tensor: + return self._im + + @classmethod + def __torch_dispatch__( + cls, + func: OpOverload, + types: tuple[type, ...], + args: tuple = (), + kwargs: dict | None = None, + ): + from ._ops.common import lookup_complex + + kwargs = {} if kwargs is None else kwargs + + impl = lookup_complex(func, *args, **kwargs) + if impl is None: + return NotImplemented + + return impl(*args, **kwargs) + + @staticmethod + def from_interleaved(t: Tensor) -> ComplexTensor: + t_real = torch.real(t) + t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real) + return Complex.apply(t_real, t_imag) + + def as_interleaved(self) -> Tensor: + return torch.complex(self.real, self.imag) + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Tensor], + meta: Any, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + ) -> ComplexTensor: + assert meta is None + re, im = inner_tensors["re"], inner_tensors["im"] + return ComplexTensor(re, im) + + def __tensor_flatten__(self) -> tuple[list[str], Any]: + return ["re", "im"], None + + def __repr__(self, *, tensor_contents=None) -> str: + return f"ComplexTensor(real={self.re!r}, imag={self.im!r})" + + def is_pinned(self, device: DeviceLikeType | None = None) -> bool: + return self.re.is_pinned(device) + + +class Complex(Function): + @staticmethod + def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override] + return ComplexTensor(real, imag) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override] + return grad_output.real, grad_output.imag diff --git a/torch/_subclasses/complex_tensor/_ops/__init__.py b/torch/_subclasses/complex_tensor/_ops/__init__.py new file mode 100644 index 0000000000000..c07bdf6099b65 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/__init__.py @@ -0,0 +1,5 @@ +from . import aten, prims +from .common import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"] diff --git a/torch/_subclasses/complex_tensor/_ops/aten.py b/torch/_subclasses/complex_tensor/_ops/aten.py new file mode 100644 index 0000000000000..15e09c3b314f0 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/aten.py @@ -0,0 +1,921 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .._core import ComplexTensor +from .common import ( + _get_func_name, + COMPLEX_TO_REAL, + complex_to_real_dtype, + is_complex, + OpType, + promote_tensors, + register_binary_nonlinear, + register_complex, + register_error, + register_force_test, + register_simple, + split_complex_arg, + split_complex_tensor, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +aten = torch.ops.aten + + +def register_binary_linear(op: OpType): + def impl_with_alpha( + lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs + ) -> ComplexTensor: + return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs) + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + u = op(a_r, b_r, *args, **kwargs) + v = op(a_i, b_i, *args, **kwargs) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + return register_complex(op, impl) + + +@register_complex(aten.real) +def real_impl(self: ComplexTensor) -> torch.Tensor: + re, _ = split_complex_tensor(self) + return re + + +@register_complex(aten.imag) +def imag_impl(self: ComplexTensor) -> torch.Tensor: + _, im = split_complex_tensor(self) + return im + + +@register_complex(aten.is_pinned) +def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool: + return self.is_pinned(device) + + +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.unsqueeze_, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + aten.zero_, + aten.transpose, + aten.t, + aten.gather, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[_get_func_name(simple_op)] = register_simple(simple_op) + +# TODO (hameerabbasi): Not being tested +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, + aten.embedding, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[_get_func_name(simple_op)] = register_force_test( + simple_op, register_simple(simple_op) + ) + +del simple_op + +# some binary ops which we can stamp out +mul_impl = register_binary_nonlinear(aten.mul) +mul__impl = register_binary_nonlinear(aten.mul_) +mm_impl = register_binary_nonlinear(aten.mm) +dot_impl = register_binary_nonlinear(aten.dot) +bmm_impl = register_binary_nonlinear(aten.bmm) + +# TODO (hameerabbasi): Not being tested +convolution_impl = register_force_test( + aten.convolution, register_binary_nonlinear(aten.convolution) +) + +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + +add_impl = register_binary_linear(aten.add) +add__impl = register_binary_linear(aten.add_) +sub_impl = register_binary_linear(aten.sub) +sub__impl = register_binary_linear(aten.sub_) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) +fill__impl = register_binary_linear(aten.fill_) + + +@register_complex(aten.rsub) +def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor: + if alpha is None: + return torch.sub(rhs, lhs) # type: ignore[bad-return] + return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return] + + +@register_complex(aten.div) +@register_complex(aten.true_divide) +def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): + if rounding_mode is not None: + raise NotImplementedError( + "`rounding_mode` other than `None` not implemented for`ComplexTensor`." + ) + a_r, a_i = split_complex_tensor(lhs) + if not is_complex(rhs): + return ComplexTensor(a_r / rhs, a_i / rhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + num_r = a_r * b_r + a_i * b_i + num_i = a_i * b_r - a_r * b_i + den = b_r * b_r + b_i * b_i + return ComplexTensor( + (num_r / den).to(out_dt), + (num_i / den).to(out_dt), + ) + + +@register_complex(aten.reciprocal) +def reciprocal_impl(self: ComplexTensor): + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + den = self_r * self_r + self_i * self_i + return ComplexTensor( + aten.div(self_r, den).to(out_dt), + aten.div(-self_i, den).to(out_dt), + ) + + +# reductions +@register_complex(aten.prod) +def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + dtype = kwargs.pop("dtype", out_dt) + kwargs["dtype"] = complex_to_real_dtype(self.dtype) + + prod_r = torch.prod(torch.abs(self), *args, **kwargs) + sum_phi = torch.sum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return] + + +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + out_dt, (self, exponent) = promote_tensors(self, exponent) + return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.cumprod) +def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + +# unary funcs, +# most of these are simple or require some kind of identity +@register_complex(aten.abs) +def abs_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + result = torch.hypot(x, y) + return result.to(out_dt) + + +@register_complex(aten.angle) +def angle_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.atan2(y, x) + + +@register_complex(aten.acos) +def acos_impl(self: ComplexTensor) -> ComplexTensor: + _, y = split_complex_tensor(self) + acosh_z = torch.acosh(self) + assert isinstance(acosh_z, ComplexTensor) + acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z) + sign_im = 2 * torch.signbit(y) - 1 + return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re)) + + +@register_complex(aten.asin) +def asin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + asinh_iz = torch.asinh(ComplexTensor(-y, x)) + assert isinstance(asinh_iz, ComplexTensor) + asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz) + return ComplexTensor(asinh_iz_im, -asinh_iz_re) + + +@register_complex(aten.atan) +def atan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.atanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.asinh) +def asinh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.acosh) +def acosh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.atanh) +def atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + ret = 0.5 * ( + torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y)) + ) + assert isinstance(ret, ComplexTensor) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.cos) +def cos_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return] + + +@register_complex(aten.cosh) +def cosh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.cosh(x) * torch.cos(y) + v = torch.sinh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + sinh_iz = torch.sinh(ComplexTensor(-y, x)) + assert isinstance(sinh_iz, ComplexTensor) + sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz) + return ComplexTensor(sinh_iz_im, -sinh_iz_re) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.sinh(x) * torch.cos(y) + v = torch.cosh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.tanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + _2x = 2 * x + _2y = 2 * y + _d = torch.cosh(_2x) + torch.cos(_2y) + _2xsh = torch.sinh(_2x) + + out_re = _2xsh / _d + out_im = torch.sin(_2y) / _d + + return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt)) + + +@register_complex(aten.exp) +def exp_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + ex = torch.exp(x) + u = ex * torch.cos(y) + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.expm1) +def expm1_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + # TODO (hameerabbasi): The two lines below may have numerical issues + ex = torch.exp(x) + u = ex * torch.cos(y) - 1 + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return] + + +@register_complex(aten.any) +def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs) + + +@register_complex(aten.all) +def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs) + + +@register_complex(aten.eq) +def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_arg(self) + b_r, b_i = split_complex_arg(rhs) + return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.ne) +def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.isnan) +def isnan_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isnan(re) | torch.isnan(im) + + +@register_complex(aten.isinf) +def isinf_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isinf(re) | torch.isinf(im) + + +@register_complex(aten.isfinite) +def isfinite_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isfinite(re) & torch.isfinite(im) + + +@register_complex(aten.isclose) +def isclose_impl( + self: ComplexTensor, + rhs: ComplexTensor, + rtol=1e-5, + atol=1e-8, + equal_nan: bool = False, +) -> torch.Tensor: + abs_diff = torch.abs(self - rhs) + abs_other = torch.abs(rhs) + basic_condition = abs_diff <= (rtol * abs_other + atol) + + # This is the nontrivial part + if equal_nan: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + + a_r_nan = torch.isnan(a_r) + b_r_nan = torch.isnan(b_r) + a_i_nan = torch.isnan(a_i) + b_i_nan = torch.isnan(b_i) + a_nan = a_r_nan | a_i_nan + + # This logical expression makes sure that the isnan of both the real and imaginary parts + # matches (so 1 + nan*i doesn't equal nan + 1*i) + equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan + return basic_condition | equal_nan_condition + + return basic_condition + + +ERROR_OPS_LIST = [ + aten.lt, + aten.le, + aten.gt, + aten.ge, + aten.amin, + aten.amax, + aten.clamp, + aten.ceil, + aten.floor, + aten.minimum, + aten.maximum, + aten.trunc, + aten.sign, + aten.argmax, + aten.argmin, + aten.sort, + aten.topk, + aten.round, + aten.fmod, +] + + +ERROR_TYPES = { + aten.minimum: RuntimeError, + aten.maximum: RuntimeError, + aten.argmax: RuntimeError, + aten.argmin: RuntimeError, + aten.sort: RuntimeError, + aten.topk: RuntimeError, +} + + +for err_op in ERROR_OPS_LIST: + globals()[_get_func_name(err_op)] = register_error( + err_op, ERROR_TYPES.get(err_op, NotImplementedError) + ) + +del err_op + + +@register_complex(aten.masked_scatter) +def masked_scatter_impl( + self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor +) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + source_r, source_i = split_complex_arg(source) + ret_r = torch.masked_scatter(self_r, mask, source_r) + ret_i = torch.masked_scatter(self_i, mask, source_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.where) +def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor: + x_r, x_i = split_complex_arg(x) + y_r, y_i = split_complex_arg(y) + + ret_r = torch.where(mask, x_r, y_r) + ret_i = torch.where(mask, x_i, y_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.full_like) +def full_like_impl( + input: ComplexTensor, + fill_value: complex, + *args, + dtype: torch.dtype | None = None, + **kwargs, +) -> torch.Tensor | ComplexTensor: + # Note: Cannot be merged with the cases below due to the `fill_value` argument + input_r, input_i = split_complex_tensor(input) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + fv_r, fv_i = split_complex_arg(fill_value) + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) + + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] + +for like_op in LIKE_OPS_LIST: + globals()[_get_func_name(like_op)] = register_like(like_op) + +del like_op + + +@register_complex(aten.cat) +def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: + tensors_r = [] + tensors_i = [] + + for t in tensors: + t_r, t_i = split_complex_arg(t) + tensors_r.append(t_r) + tensors_i.append(t_i) + + ret_r = torch.cat(tensors_r, dim=dim) + ret_i = torch.cat(tensors_i, dim=dim) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.sgn) +def sgn_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = (self_r != 0) | (self_i != 0) + masked_sgn = ComplexTensor( + (self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt) + ) + return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return] + + +@register_complex(aten.sqrt) +def sqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_sqrt = torch.sqrt(torch.abs(self)) + self_half_angle = 0.5 * torch.angle(self) + + ret_r = self_abs_sqrt * torch.cos(self_half_angle) + ret_i = self_abs_sqrt * torch.sin(self_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.rsqrt) +def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_rsqrt = torch.rsqrt(torch.abs(self)) + self_neg_half_angle = -0.5 * torch.angle(self) + + ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) + ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.addmm) +def addmm_impl( + input: ComplexTensor, + mat1: ComplexTensor, + mat2: ComplexTensor, + out_dtype: torch.dtype | None = None, + beta: complex = 1, + alpha: complex = 1, +) -> ComplexTensor: + ret = beta * input + alpha * torch.mm(mat1, mat2) + assert isinstance(ret, ComplexTensor) + ret_r, ret_i = split_complex_tensor(ret) + if out_dtype is not None: + out_dtype = COMPLEX_TO_REAL[out_dtype] + ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype) + return ComplexTensor(ret_r, ret_i) + + +def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return (re != 0) | (im != 0) + + +def register_nonzero_impl(op: OpType): + def nonzero_impl( + self: ComplexTensor, other: ComplexTensor, *args, **kwargs + ) -> torch.Tensor: + return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs) + + func_name = _get_func_name(op) + nonzero_impl.__name__ = func_name + nonzero_impl.__qualname__ = func_name + + return register_complex(op, nonzero_impl) + + +logical_and_impl = register_nonzero_impl(aten.logical_and) +logical_or_impl = register_nonzero_impl(aten.logical_or) +logical_xor_impl = register_nonzero_impl(aten.logical_xor) + + +@register_complex(aten.logical_not) +def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + +@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl(self: ComplexTensor, src, *args, **kwargs): + self_re, self_im = split_complex_tensor(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + return torch.all( + torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() # type: ignore[bad-return] + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, torch._neg_view(im)) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl( + self: ComplexTensor, pad, value: complex | None = None +) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) + + +@register_complex(aten.scatter_add) +def scatter_add_impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = torch.scatter_add(self_re, dim, index, src_re) + ret_im = torch.scatter_add(self_im, dim, index, src_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.scatter_add_) +def scatter_add__impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + out_re = self_re.scatter_add_(dim, index, src_re) + out_im = self_im.scatter_add_(dim, index, src_im) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.index_put_) +def index_put__impl( + self: ComplexTensor, + indices: tuple[torch.Tensor, ...], + values: ComplexTensor, + accumulate: bool = False, +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + values_re, values_im = split_complex_arg(values) + + out_re = self_re.index_put_(indices, values_re, accumulate=accumulate) + out_im = self_im.index_put_(indices, values_im, accumulate=accumulate) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.tanh_backward) +def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor): + return out_grad * (1.0 - y * y).conj_physical() + + +@register_complex(aten.diagonal_backward) +def diagonal_backward( + grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any: + if not isinstance(dt, torch.dtype): + return dt + + return COMPLEX_TO_REAL[dt] + + +def register_to_impl(op: OpType): + """Register an op similar to `aten.to`, but may have different signatures.""" + + def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + x, y = split_complex_tensor(self) + try: + args = tuple(_dt_to_real(a) for a in args) + kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()} + except KeyError: + return op(x, *args, **kwargs) + + return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +to_impl = register_to_impl(aten.to) +_to_copy_impl = register_to_impl(aten._to_copy) diff --git a/torch/_subclasses/complex_tensor/_ops/common.py b/torch/_subclasses/complex_tensor/_ops/common.py new file mode 100644 index 0000000000000..88532efe224bb --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/common.py @@ -0,0 +1,317 @@ +from collections.abc import Callable +from typing import Any, overload, TypeAlias +from typing_extensions import TypeIs + +import torch +from torch import Tensor +from torch._decomp import get_decompositions +from torch._ops import OpOverload, OpOverloadPacket +from torch._refs import is_complex as _is_complex +from torch.types import Number +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from .._core import ComplexTensor + + +OpType: TypeAlias = OpOverloadPacket | OpOverload + +TableType: TypeAlias = dict[OpType, Callable] + +# Mapping from ops to implementations +COMPLEX_OPS_TABLE: TableType = {} + +COMPLEX_TO_REAL = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()} + +# Used to promote dtypes in `promote_real_cpu_tensors` +PROMOTE_TYPES = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.complex32: torch.complex64, +} + + +def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]: + r"""Returns True if the input is a ComplexTensor, else False + + Args: + a: any input + + Examples: + + >>> # xdoctest: +SKIP + >>> from torch.complex import ComplexTensor + >>> data = torch.zeros((3, 2), dtype=torch.complex64) + >>> ct = ComplexTensor.from_interleaved(data) + >>> is_complex_tensor(ct) + True + """ + return isinstance(obj, ComplexTensor) + + +@overload +def promote_tensors( + *tensors: ComplexTensor, +) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ... + + +@overload +def promote_tensors( + *tensors: Tensor, +) -> tuple[torch.dtype, tuple[Tensor, ...]]: ... + + +def promote_tensors( + *tensors: Tensor | ComplexTensor, +) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]: + """ + Promotes all tensors to a common dtype. + Additionally promotes CPU tensors to at least `float32`. + """ + tensor = next(t for t in tensors if isinstance(t, Tensor)) + out_dt = tensor.dtype + for t in tensors: + if isinstance(t, Tensor): + out_dt = torch.promote_types(out_dt, t.dtype) + + prom_dt = PROMOTE_TYPES.get(out_dt, out_dt) + return out_dt, tuple( + t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt) + for t in tensors + ) + + +def register_complex( + op: OpType, + func_impl: Callable | None = None, +): + """Decorator to register an implementation for some ops in some dispatch tables""" + + def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError(f"Attempted to register multiple functions for {op}") + COMPLEX_OPS_TABLE[op] = func + return func + + if func_impl is None: + return inner + + return inner(func_impl) + + +FORCE_TEST_LIST: list[OpType] = [] + + +def register_force_test(op: OpType, *args, **kwargs): + """Will attempt to test these ops even if they err on "normal" inputs""" + FORCE_TEST_LIST.append(op) + return register_complex(op, *args, **kwargs) + + +DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload] + + +def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None: + """ + Lookup an impl from the table. + + Try the particular overload first, then the overload packet. + + If nothing is found, try the decompositions with both. + """ + return COMPLEX_OPS_TABLE.get( + func, + COMPLEX_OPS_TABLE.get( + func.overloadpacket, + DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)), + ), + ) + + +def is_complex(x: Any, /) -> bool: + """Utility to detect if a given object is (known) to be complex.""" + return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex) + + +@overload +def split_complex_arg( + arg: Tensor | ComplexTensor, +) -> tuple[Tensor, Tensor]: ... + + +@overload +def split_complex_arg( + arg: complex | Number, +) -> tuple[Number, Number]: ... + + +def split_complex_arg( + arg: Tensor | ComplexTensor | complex | Number, +) -> tuple[Tensor, Tensor] | tuple[Number, Number]: + """ + Split a complex argument into a real/imaginary component. + + If real, use zero for the imaginary part. + """ + if isinstance(arg, ComplexTensor): + return split_complex_tensor(arg) + if isinstance(arg, Tensor): + if is_complex(arg): + return arg.real, arg.imag + return arg, torch.zeros_like(arg) + # TODO (hameerabbasi): Should there be a `torch.SymComplex`? + if isinstance(arg, complex): + return arg.real, arg.imag + if isinstance(arg, float | torch.SymFloat): + return arg, 0.0 + if isinstance(arg, int | torch.SymInt): + return arg, 0 + if isinstance(arg, bool | torch.SymBool): + return arg, False + raise TypeError(f"Expected tensor or number got, {type(arg)}") + + +def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]: + """Split a ComplexTensor into its real and imaginary parts.""" + return complex_tensor.re, complex_tensor.im + + +def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype: + """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is.""" + return COMPLEX_TO_REAL.get(dtype, dtype) + + +def _get_op_name(op: OpType) -> str: + """Get the op name from the op.""" + if isinstance(op, OpOverload): + op = op.overloadpacket + return str(op).split(".", 1)[1] + + +def _get_func_name(op: OpType) -> str: + """Get the name of the implementation function from the op.""" + return f"{_get_op_name(op)}_impl" + + +def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError): + msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`." + + def ordered_impl(*args, **kwargs): + raise exc_type(msg) + + func_name = _get_func_name(op) + ordered_impl.__name__ = func_name + ordered_impl.__qualname__ = func_name + + return register_force_test(op, ordered_impl) + + +def register_binary_nonlinear(op: OpType) -> Callable: + """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ...""" + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) + imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs) + return ComplexTensor(real.to(out_dt), imag.to(out_dt)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def register_simple(op: OpType): + """Register an op which can be applied independently to the real and complex parts to get the result.""" + + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: + x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError( + "Non-complex `dtype` specified, please write custom impl." + ) + + if dtype in COMPLEX_TO_REAL: + assert dtype is not None + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + + u_flat, u_spec = tree_flatten(u) + v_flat, v_spec = tree_flatten(v) + assert u_spec == v_spec + out_flat = [ + ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False) + ] + return tree_unflatten(out_flat, u_spec) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any: + """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is.""" + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any: + """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is.""" + if isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexTensorMode(TorchDispatchMode): + _compile: bool + + """ A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """ + + def __init__(self, _dispatch_key=None, *, _compile: bool = False): + """Initialize a ComplexTensorMode. + + Args: + _dispatch_key: passed on to TorchDispatchMode + _compile: Compile the op before the computation + """ + super().__init__(_dispatch_key) + self._compile = _compile + + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[type], + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + + # TODO (hameerabbasi): Test perf with `_compile` set to `True` + if self._compile: + func = torch.compile(func) # type: ignore[bad-assignment] + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/torch/_subclasses/complex_tensor/_ops/prims.py b/torch/_subclasses/complex_tensor/_ops/prims.py new file mode 100644 index 0000000000000..9a237b32d9904 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/prims.py @@ -0,0 +1,34 @@ +import torch + +from .._core import ComplexTensor +from .common import ( + complex_to_real_dtype, + register_complex, + register_force_test, + split_complex_tensor, +) + + +prims = torch.ops.prims +aten = torch.ops.aten + + +# TODO (hameerabbasi): Not being tested +@register_force_test(prims.convert_element_type) +def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: + dtype = complex_to_real_dtype(dtype) + u, v = split_complex_tensor(x) + u_out = prims.convert_element_type(u, dtype) + v_out = prims.convert_element_type(v, dtype) + + return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) From 0e13964b7482576dab1c2ec81e89899e4f17a361 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Nov 2025 09:47:49 -0800 Subject: [PATCH 004/230] [CI] Disable ET tests (again) (#168090) Repeatition of https://github.com/pytorch/pytorch/pull/155708 Has been broken for a while, and ET pin in Pytorch are so old that `torch==2.10.0.dev20250915` could no longer be found in nightly indices Pull Request resolved: https://github.com/pytorch/pytorch/pull/168090 Approved by: https://github.com/atalman, https://github.com/yangw-dev --- .github/workflows/docker-builds.yml | 3 ++- .github/workflows/trunk.yml | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 408a8f0000504..5700c8e3c74b3 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -75,7 +75,8 @@ jobs: pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, - pytorch-linux-jammy-py3-clang12-executorch, + # TODO: Re-enable me when docker pin update happens + # pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu, pytorch-linux-noble-riscv64-py3.12-gcc14 ] diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 667c37727045b..6e775da47fc1e 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -283,6 +283,7 @@ jobs: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build.yml needs: get-label-type + if: false # Has been broken for a while with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-executorch From 5333e511950d460fca7f4f8a2f868db2275ffc76 Mon Sep 17 00:00:00 2001 From: Aidyn-A Date: Tue, 18 Nov 2025 18:45:47 +0000 Subject: [PATCH 005/230] [CUDA][Thor] Enable CUTLASS matmuls on Thor (#164836) This PR enables special matmuls on Thor devices. This includes row-wise scaled matmul on `fp8` and group gemm on `bfloat16`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164836 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/GroupMM.cu | 3 ++- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 5 +++-- cmake/Codegen.cmake | 10 ++++++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index a917b0d6163fa..3f4f998d92cd6 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -346,8 +346,9 @@ void dispatch_bf16_grouped_kernel_on_tile_size( bool small = (M <= 128 || N <= 128); cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; - if (sm10x) { + if (sm10x || sm11x) { if (small){ bf16bf16_grouped_gemm_impl_sm90_sm100< cutlass::arch::Sm100, diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 382a5a065b300..8971e05094651 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -958,8 +958,9 @@ void dispatch_fp8_rowwise_kernel_on_sm( const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; const bool sm9x = properties != nullptr && properties->major == 9; const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; const bool sm12x = properties != nullptr && properties->major == 12; - if (!(sm89 || sm9x || sm10x || sm12x)) { + if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } @@ -968,7 +969,7 @@ void dispatch_fp8_rowwise_kernel_on_sm( dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm90, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (sm10x) { + } else if (sm10x || sm11x) { dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm100, Types...>(XQ, WQ, x_scale, w_scale, bias, out); diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index bac1fa7daac01..5faad21f9f6cd 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -113,6 +113,12 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a") endif() endif() + # We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+ + if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + if(_existing_arch_flags MATCHES ".*compute_110.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_110a,code=sm_110a") + endif() + endif() if("${_arch}" STREQUAL "120a") if(_existing_arch_flags MATCHES ".*compute_120.*") list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") @@ -132,13 +138,13 @@ if(INTERN_BUILD_ATEN_OPS) _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" - "89;90a;100a;103a;120a;121a") + "89;90a;100a;103a;110a;120a;121a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" "90a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" - "90a;100a;103a") + "90a;100a;103a;110a") endif() From d1f6dd61055a71ec2072dc0d006e9f12d8e23501 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 18 Nov 2025 19:00:24 +0000 Subject: [PATCH 006/230] distributed/debug: add an HTTP server for debugging running jobs (#167395) This adds a debug HTTP server for debugging stuck or slow jobs. It runs the WorkerServer on every worker and then launches a separate flask process on rank 0 to have users connect to for debugging. This can easily be improved to trigger profilers as well as visualize the data much better. Initial handlers: * pytorch profiler * FlightRecorder data * Python stacks ``` os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "2000" from torch.distributed.debug import enable_debug_server enable_debug_server() ``` Test plan: ``` torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/debug_test.py ``` 20251117_16h58m18s_grim 20251117_16h58m11s_grim 20251117_16h58m03s_grim Pull Request resolved: https://github.com/pytorch/pytorch/pull/167395 Approved by: https://github.com/fduwjj, https://github.com/malfet, https://github.com/atalman --- .ci/docker/requirements-ci.txt | 3 + docs/source/distributed.md | 18 + test/distributed/test_debug.py | 56 +++ torch/_C/_distributed_c10d.pyi | 17 +- torch/_C/_profiler.pyi | 1 + .../c10d/control_plane/Handlers.cpp | 10 + .../c10d/control_plane/Handlers.hpp | 8 + .../c10d/control_plane/WorkerServer.cpp | 6 + .../c10d/control_plane/WorkerServer.hpp | 5 + torch/csrc/distributed/c10d/init.cpp | 31 +- torch/distributed/debug/__init__.py | 82 ++++ torch/distributed/debug/_frontend.py | 353 ++++++++++++++++++ torch/distributed/debug/_handlers.py | 22 ++ torch/distributed/debug/_store.py | 24 ++ 14 files changed, 629 insertions(+), 7 deletions(-) create mode 100644 test/distributed/test_debug.py create mode 100644 torch/distributed/debug/__init__.py create mode 100644 torch/distributed/debug/_frontend.py create mode 100644 torch/distributed/debug/_handlers.py create mode 100644 torch/distributed/debug/_store.py diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index f3636071714f8..242cbaafa059e 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -402,3 +402,6 @@ scikit-build==0.18.1 pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI + +Jinja2==3.1.6 +#Description: required for torch.distributed.debug diff --git a/docs/source/distributed.md b/docs/source/distributed.md index ca1fe3b5e9099..6840bbb893bf7 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the {func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs. + +### torch.distributed.debug HTTP Server + +The `torch.distributed.debug` module provides a HTTP server that can be used to debug distributed applications. The server can +be started by calling {func}`torch.distributed.debug.start_debug_server`. This +allows users to collect data across all workers at runtime. + +```{eval-rst} +.. automodule:: torch.distributed.debug + :members: + :undoc-members: + :show-inheritance: + :special-members: __init__ + :member-order: bysource + +``` + + ## Logging In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py new file mode 100644 index 0000000000000..ff6a203bcf160 --- /dev/null +++ b/test/distributed/test_debug.py @@ -0,0 +1,56 @@ +# Owner(s): ["oncall: distributed"] + +import os + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +import torch +import torch.distributed as dist +from torch.distributed.debug import start_debug_server, stop_debug_server +from torch.testing._internal.common_utils import run_tests, TestCase + + +session = requests.Session() +retry_strategy = Retry(total=5, backoff_factor=0.5) +adapter = HTTPAdapter(max_retries=retry_strategy) +session.mount("http://", adapter) +session.mount("https://", adapter) + + +class TestDebug(TestCase): + def test_basics(self) -> None: + store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(store.port) + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + port = 25999 + + def fetch(path: str) -> str: + resp = session.get(f"http://localhost:{port}{path}") + resp.raise_for_status() + return resp.text + + start_debug_server(port=port) + + self.assertIn("torch profiler", fetch("/")) + self.assertIn("View 0", fetch("/profile?duration=0.01")) + self.assertIn("test_basics", fetch("/stacks")) + self.assertIn("pg_status", fetch("/fr_trace")) + + if torch.cuda.is_available(): + self.assertIn("pg_status", fetch("/fr_trace_nccl")) + + # test errors + resp = session.get(f"http://localhost:{port}/blah") + self.assertEqual(resp.status_code, 404) + self.assertIn("Handler not found: /blah", resp.text) + + stop_debug_server() + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 752bd594d066f..477b35b1811e4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -100,7 +100,9 @@ class Logger: def _set_static_graph(self) -> None: ... class _WorkerServer: - def __init__(self, socket_path: str) -> None: ... + port: int + + def __init__(self, host_or_file: str, port: int = ...) -> None: ... def shutdown(self) -> None: ... def get_debug_level(): ... @@ -206,6 +208,7 @@ class Store: desired_value: str, ) -> bytes: ... def delete_key(self, key: str) -> bool: ... + def multi_get(self, keys: list[str]) -> list[bytes]: ... def num_keys(self) -> int: ... def set_timeout(self, timeout: timedelta): ... @overload @@ -872,3 +875,15 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... + +class _Request: + def body(self) -> bytes: ... + def get_param(self, str) -> str: ... + +class _Response: + def set_content(self, content: str | bytes, content_type: str) -> None: ... + def set_status(self, status: int) -> None: ... + +def _register_handler( + name: str, handler: Callable[[_Request, _Response], None] +) -> None: ... diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index d60d89a6a4796..de12af50c1855 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -60,6 +60,7 @@ class _ExperimentalConfig: verbose: bool = ..., performance_events: list[str] = ..., enable_cuda_sync_events: bool = ..., + profile_all_threads: bool = ..., ) -> None: ... class ProfilerConfig: diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 10274d053b995..fe8f831a23bb1 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { res.setStatus(200); }}; +RegisterHandler frTracehandler( + "fr_trace_json", + [](const Request&, Response& res) { + auto trace = ::c10d::dump_fr_trace_json(true, true); + res.setContent(std::move(trace), "application/json"); + res.setStatus(200); + }); + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 70333a3a4844c..58ae9368ea212 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -18,6 +18,14 @@ class TORCH_API Request { virtual const std::string& body() const = 0; virtual const std::multimap& params() const = 0; + + std::string getParam(const std::string& key) const { + auto it = params().find(key); + if (it != params().end()) { + return it->second; + } + return ""; + } }; // Response represents a response to the handler. This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 8bbe857620790..eda6ee3a91488 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { TORCH_CHECK( server_.bind_to_port(hostOrFile, 80), fmt::format("Error binding to {}", hostOrFile)); + } else if (port == 0) { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + port_ = server_.bind_to_any_port(hostOrFile); + TORCH_CHECK( + port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port)); } else { C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); TORCH_CHECK( server_.bind_to_port(hostOrFile, port), fmt::format("Error binding to {}:{}", hostOrFile, port)); + port_ = port; } serverThread_ = std::thread([this]() { diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp index 41c1356fc01f3..20d05b7509e92 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target { void shutdown(); + int port() { + return port_; + } + private: httplib::Server server_; std::thread serverThread_; + int port_; }; } // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 94a8c0bbe228b..255e793eaa4df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include #include @@ -4209,7 +4210,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. }), py::arg("host_or_file"), py::arg("port") = -1) - .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown) + .def_property_readonly( + "port", &::c10d::control_plane::WorkerServer::port); module.def( "_get_handler", @@ -4225,6 +4228,25 @@ such as `dist.all_reduce(tensor, async_op=True)`. Returns the handler with the specified name. )"); + module.def( + "_register_handler", + [](const std::string& name, const py::function& handler) { + ::c10d::control_plane::registerHandler( + name, + [handler]( + const ::c10d::control_plane::Request& req, + ::c10d::control_plane::Response& res) { + py::gil_scoped_acquire acquire; + handler(std::ref(req), std::ref(res)); + }); + }, + + py::arg("name"), + py::arg("handler"), + R"( + Registers a handler by name. + )"); + module.def( "_get_handler_names", &::c10d::control_plane::getHandlerNames, @@ -4242,12 +4264,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. // Default constructor. .def(py::init<>()) .def("body", &::c10d::control_plane::Request::body) - .def("params", &::c10d::control_plane::Request::params); + .def("get_param", &::c10d::control_plane::Request::getParam); - py::class_< - ::c10d::control_plane::Response, - std::shared_ptr<::c10d::control_plane::Response>, - PythonResponse>( + py::class_<::c10d::control_plane::Response, PythonResponse>( module, "_Response", R"( diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py new file mode 100644 index 0000000000000..46267a686e86d --- /dev/null +++ b/torch/distributed/debug/__init__.py @@ -0,0 +1,82 @@ +import logging +import multiprocessing +import socket + +# import for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +logger: logging.Logger = logging.getLogger(__name__) + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + """ + Start the debug server stack on all workers. The frontend debug server is + only started on rank0 while the per rank worker servers are started on all + ranks. + + This server provides an HTTP frontend that allows for debugging slow and + deadlocked distributed jobs across all ranks simultaneously. This collects + data such as stack traces, FlightRecorder events, and performance profiles. + + WARNING: This is intended to only be used in trusted network environments. + The debug server is not designed to be secure and should not be exposed to + the public internet. See SECURITY.md for more details. + + WARNING: This is an experimental feature and may change at any time. + + Args: + port (int): The port to start the frontend debug server on. + worker_port (int): The port to start the worker server on. Defaults to 0, which + will cause the worker server to bind to an ephemeral port. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + logger.info("Starting debug server on port %d", port) + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._frontend import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + """ + Shutdown the debug server and stop the frontend debug server process. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + logger.info("Stopping debug server") + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py new file mode 100644 index 0000000000000..622c41ca8bd64 --- /dev/null +++ b/torch/distributed/debug/_frontend.py @@ -0,0 +1,353 @@ +import json +import logging +import socket +import threading +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse + +import requests +from jinja2 import DictLoader, Environment + +from torch.distributed.debug._store import get_world_size, tcpstore_client + + +logger: logging.Logger = logging.getLogger(__name__) + + +def fetch_all( + endpoint: str, args: str = "" +) -> tuple[list[str], Iterator[requests.Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + with ThreadPoolExecutor(max_workers=10) as executor: + resps = executor.map(requests.post, addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+ {% block header %}{% endblock %} + {% block content %}{% endblock %} +
+ """, + "index.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}Index{% endblock %}

+{% endblock %} +{% block content %} +Hi +{% endblock %} + """, + "raw_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{title}}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "json_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ format_json(resp.text) }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "profile.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}torch.profiler{% endblock %}

+{% endblock %} + +{% block content %} +
+ + + +
+ + + + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, +} + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore + request_queue_size: int = 1024 + + +class HTTPRequestHandler(BaseHTTPRequestHandler): + frontend: "FrontendServer" + + def do_GET(self): + self.frontend._handle_request(self) + + def get_path(self) -> str: + return urlparse(self.path).path + + def get_query(self) -> dict[str, list[str]]: + return parse_qs(urlparse(self.path).query) + + def get_query_arg( + self, name: str, default: object = None, type: type = str + ) -> object: + query = self.get_query() + if name not in query: + return default + return type(query[name][0]) + + +class FrontendServer: + def __init__(self, port: int): + # Setup templates + loader = DictLoader(templates) + self._jinja_env = Environment(loader=loader, enable_async=True) + self._jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, + ) + + # Create routes + self._routes = { + "/": self._handle_index, + "/stacks": self._handle_stacks, + "/fr_trace": self._handle_fr_trace, + "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/profile": self._handle_profiler, + } + + # Create HTTP server + RequestHandlerClass = type( + "HTTPRequestHandler", + (HTTPRequestHandler,), + {"frontend": self}, + ) + + server_address = ("", port) + self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) + + self._thread = threading.Thread( + target=self._serve, + args=(), + daemon=True, + ) + self._thread.start() + + def _serve(self) -> None: + try: + self._server.serve_forever() + except Exception: + logger.exception("got exception in checkpoint server") + + def join(self) -> None: + self._thread.join() + + def _handle_request(self, req: HTTPRequestHandler) -> None: + path = req.get_path() + if path not in self._routes: + req.send_error(404, f"Handler not found: {path}") + return + + handler = self._routes[path] + try: + resp = handler(req) + except Exception as e: + logger.exception( + "Exception in checkpoint server when handling %s", + path, + ) + req.send_error(500, str(e)) + return + + req.send_response(200) + req.send_header("Content-type", "text/html") + req.end_headers() + req.wfile.write(resp) + + def _render_template(self, template: str, **kwargs: object) -> bytes: + return self._jinja_env.get_template(template).render(**kwargs).encode() + + def _handle_index(self, req: HTTPRequestHandler) -> bytes: + return self._render_template("index.html") + + def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_traceback") + return self._render_template( + "raw_resp.html", title="Stacks", addrs=addrs, resps=resps + ) + + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: + duration = req.get_query_arg("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return self._render_template("profile.html", addrs=addrs, resps=resps) + + +def main(port: int) -> None: + server = FrontendServer(port=port) + logger.info("Frontend server started on port %d", server._server.server_port) + server.join() diff --git a/torch/distributed/debug/_handlers.py b/torch/distributed/debug/_handlers.py new file mode 100644 index 0000000000000..ba951b7bda075 --- /dev/null +++ b/torch/distributed/debug/_handlers.py @@ -0,0 +1,22 @@ +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, _Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(open(f.name, "rb").read(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/torch/distributed/debug/_store.py b/torch/distributed/debug/_store.py new file mode 100644 index 0000000000000..70c6cd0f3dde1 --- /dev/null +++ b/torch/distributed/debug/_store.py @@ -0,0 +1,24 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client() -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + store = dist.PrefixStore("debug_server", store) + return store From aa22d41f9b1214042c4b834aaf15bc5e625c3a41 Mon Sep 17 00:00:00 2001 From: Aniket Panse Date: Tue, 18 Nov 2025 19:48:15 +0000 Subject: [PATCH 007/230] [refcycle-logger] Output tensor size in the refcycle visualization (#167079) Summary: As title. Knowing the size of the leaked tensor is useful, it allows us to focus on the largest leaks. Differential Revision: D86218574 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167079 Approved by: https://github.com/kausv --- torch/utils/viz/_cycles.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 8abb547d500f8..df4bf34db2114 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -249,6 +249,8 @@ def format_sequence(obj): if len(filename) > FRAME_FILENAME_LIMIT: filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):] return f"frame\n{filename}:{obj.f_lineno}" + elif is_cuda_tensor(obj): + return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})" else: return f"object\n{type(obj).__module__}.{type(obj).__name__}" From 14f370f5518c02c51dae1d424584f6ad4602b4bf Mon Sep 17 00:00:00 2001 From: "Liao, Wei" Date: Tue, 18 Nov 2025 19:49:40 +0000 Subject: [PATCH 008/230] [xpu][test] port some distributed tensor test files for Intel GPU (#161703) it's another pr to port distributed tensor test for Intel GPU, while the other pr is https://github.com/pytorch/pytorch/pull/161604 We could enable Intel GPU with following methods and try the best to keep the original code styles: Use torch.accelerator for general gpu Skip the case if running on xpu which has known issues Pull Request resolved: https://github.com/pytorch/pytorch/pull/161703 Approved by: https://github.com/guangyey, https://github.com/d4l3k, https://github.com/albanD --- .../tensor/debug/test_comm_mode.py | 13 +++++--- test/distributed/tensor/test_dtensor.py | 14 ++++---- .../tensor/test_dtensor_compile.py | 32 +++++++++++-------- test/distributed/tensor/test_redistribute.py | 4 +-- test/distributed/tensor/test_tensor_ops.py | 2 +- .../distributed/_tensor/common_dtensor.py | 2 +- 6 files changed, 36 insertions(+), 31 deletions(-) diff --git a/test/distributed/tensor/debug/test_comm_mode.py b/test/distributed/tensor/debug/test_comm_mode.py index c87164750c684..d122a9f716fcd 100644 --- a/test/distributed/tensor/debug/test_comm_mode.py +++ b/test/distributed/tensor/debug/test_comm_mode.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor.debug import CommDebugMode -from torch.testing._internal.common_distributed import requires_nccl +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore @@ -14,6 +14,9 @@ c10d_functional = torch.ops.c10d_functional c10d_ops = torch.ops.c10d +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) class TestCommMode(TestCase): @@ -28,7 +31,7 @@ def setUp(self): dist.init_process_group( backend="fake", rank=1, world_size=self.world_size, store=store ) - self.device_type = "cuda" if torch.cuda.is_available() else "cpu" + self.device_type = device_type self.world_pg = dist.distributed_c10d._get_default_group() def checksAssert(self, comm_mode, key, expected_value, expected_total_value): @@ -111,12 +114,12 @@ def f(x, y): self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1) self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_comm_mode_with_c10d(self): - if not torch.cuda.is_available(): + if not torch.accelerator.is_available(): return - inp = torch.rand(2, 8, 16).cuda() + inp = torch.rand(2, 8, 16).to(device_type) all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) comm_mode = CommDebugMode() diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index e99734c6b8437..c47ff79091493 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -658,11 +658,11 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh = self.build_device_mesh() - # construct from a cpu local tensor with cuda device mesh - # should automatically convert the dist tensor to cuda + # construct from a cpu local tensor with gpu device mesh + # should automatically convert the dist tensor to gpu placements = [Shard(0)] local_tensor = torch.randn(3, 3) dist_tensor = DTensor.from_local(local_tensor, mesh, placements) @@ -711,7 +711,7 @@ def test_dtensor_api_device_mesh_context_manager(self): @with_comms def test_dtensor_2d_mesh(self): mesh_tensor = torch.arange(self.world_size).reshape(2, 4) - # construct a cuda device mesh + # construct a gpu device mesh mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 2d device mesh and test if works @@ -733,7 +733,7 @@ def test_dtensor_2d_mesh(self): @with_comms def test_device_mesh_nd(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 3d device mesh and test if works @@ -1064,8 +1064,8 @@ def _create_tensor(self, size): # Keep everything deterministic. torch.manual_seed(0) tensor = torch.rand(size) - if self.device_type == "cuda": - return tensor.cuda() + if self.device_type != "cpu": + return tensor.to(self.device_type) else: return tensor diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index ddba3150b05fb..22493c4451d63 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -39,6 +39,7 @@ RowwiseParallel, ) from torch.distributed.tensor.placement_types import _StridedShard +from torch.testing._internal.common_device_type import skipXPUIf from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( @@ -47,8 +48,6 @@ run_tests, skipIfHpu, skipIfTorchDynamo, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -95,6 +94,10 @@ def extract_graph(fx_g, _, graph_cell): partition_fn=min_cut_rematerialization_partition, ) +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) + def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh): """ @@ -141,7 +144,7 @@ def tearDown(self): @property def device_type(self) -> str: - return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu" + return device_type @property def world_size(self) -> int: @@ -160,9 +163,9 @@ def fn(x): res = fn(x) res.to_local().sum().backward() - @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available") def test_dtensor_basic_export(self): - mesh = DeviceMesh("cuda", torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) param = torch.randn(4, 4) param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False) @@ -188,10 +191,10 @@ def forward(self, x): ) self.assertExpectedInline( str(ep.graph_module.code).strip(), - """\ + f"""\ def forward(self, b_buffer, x): _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None - to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None + to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None view_as = torch.ops.aten.view_as.default(to, to); to = None dtensor___init__0 = self.dtensor___init__0 dtensor_const_func_spec0 = self.dtensor_const_func_spec0 @@ -206,10 +209,10 @@ def forward(self, b_buffer, x): # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), - """\ + f"""\ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None - _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None + _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None @@ -377,6 +380,7 @@ def fn(x): self.assertEqual(res, ref) @skipIfHpu + @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981") def test_dtensor_dynamic_loss_parallel_log_softmax(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -815,13 +819,13 @@ def fn(x, y, z): out = layer_norm.permute(0, 2, 1) return out - x = torch.randn(4, 2, 4, requires_grad=True, device="cuda") + x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type) x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False) - y = torch.randn(4, requires_grad=True, device="cuda") + y = torch.randn(4, requires_grad=True, device=self.device_type) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) - z = torch.randn(4, requires_grad=True, device="cuda") + z = torch.randn(4, requires_grad=True, device=self.device_type) z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) @@ -919,7 +923,7 @@ def test_dtensor_dynamo_device_mesh_attrs(self): # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn def fn(x_dt): - if x_dt.device_mesh.device_type == "cuda": + if x_dt.device_mesh.device_type == f"{self.device_type}": return x_dt + 1 else: return x_dt + 2 @@ -1051,7 +1055,7 @@ def forward(self, input): model = FakeTransformer().to(self.device_type) - tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) + tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",)) # apply sequence parallel parallel_plan = { diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 381660e47927d..86bb567a39616 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -27,8 +27,6 @@ instantiate_parametrized_tests, parametrize, run_tests, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -541,7 +539,7 @@ def test_redistribute_shard_dim_change(self, dtype): local_out_dt = out_dt.to_local() local_expected_dt = expected_dt.to_local() self.assertEqual(out_dt.to_local(), expected_dt.to_local()) - if TEST_HPU or TEST_CUDA: + if torch.accelerator.is_available(): self.assertEqual( comm_mode.get_comm_counts()[ torch.ops._dtensor.shard_dim_alltoall diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 80968fb52e904..4748db4f7377b 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -296,8 +296,8 @@ def test_zeros_like(self): self.assertEqual(dist_tensor.dtype, torch.float32) self.assertEqual(zeros_like_dt.dtype, torch.bfloat16) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_stack(self): mesh_2d = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 6ce7d4b2ca507..9666765b01e71 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -386,7 +386,7 @@ def device_type(self) -> str: @property def backend(self) -> str: - backend = dist.get_default_backend_for_device(DEVICE_TYPE) + backend = dist.get_default_backend_for_device(self.device_type) return backend def init_manual_seed_for_rank(self) -> None: From e3c5b789994e19eb40681da91af38b7a1321af90 Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 18 Nov 2025 08:38:45 -0800 Subject: [PATCH 009/230] small changes (#167852) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167852 Approved by: https://github.com/fmassa --- .../test_aten_comm_compute_reordering.py | 4 +- torch/_inductor/comm_analysis.py | 58 +++++++++++++++++-- .../_inductor/fx_passes/overlap_scheduling.py | 37 ++++++++---- 3 files changed, 79 insertions(+), 20 deletions(-) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 426f77e379f8f..97cb8c02c8b1b 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -30,7 +30,7 @@ from torch.testing._internal.inductor_utils import HAS_GPU -def estimate_aten_runtime(fx_node, compute_multiplier=1.0): +def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0): # for tests, assume a matmul can hide a single collective if "c10" in str(fx_node.target): return 1.0 @@ -1112,7 +1112,7 @@ def test_multiple_hiding_nodes_bucketing(self): # Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden def estimate_with_half_compute(fx_node, override_size=None): - return estimate_aten_runtime(fx_node, compute_multiplier=0.5) + return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5) def func(a, b, *, ranks): # Two all_gathers that will be hidden by multiple compute operations diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 681aef9afb35f..55279f393d3aa 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -341,12 +341,58 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: - sz_bytes = 0 - for node in fx_node.all_input_nodes: - if (t := node.meta.get("val")) is not None: - numel = get_size_numel(t.size()) - sz_bytes += numel * get_dtype_size(t.dtype) - return sz_bytes + """Estimate the size of a collective operation in bytes, including inputs and outputs.""" + input_bytes = None + + args, kwargs = fx_node.args, fx_node.kwargs + kwargs = dict(kwargs) + + # dont double count pre-allocated buffer passed in + kwargs.pop("out", None) + + def tensor_bytes(t) -> int: + return get_size_numel(t.size()) * get_dtype_size(t.dtype) + + def add_inp_bytes(inp: torch.fx.Node): + t = inp.meta.get("val", None) + if t is None: + return + + nonlocal input_bytes + if input_bytes is None: + input_bytes = 0 + input_bytes += tensor_bytes(t) + + pytree.tree_map_only( + torch.fx.Node, + add_inp_bytes, + (args, kwargs), + ) + + output_tensor = fx_node.meta.get("val", None) + + if input_bytes is None or output_tensor is None: + return 0 + + output_bytes = ( + get_size_numel(output_tensor.size()) * output_tensor.element_size() + ) # pyre-ignore + + return input_bytes + output_bytes + + +def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int: + """Estimate the memory footprint of a collective operation in bytes. + + This returns the total bytes that need to be live concurrently in memory. + For all_reduce, we divide by 2 since it can be done in-place. + """ + from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor as is_all_reduce, + ) + + size = estimate_fx_collective_size(fx_node) + return size if not is_all_reduce(fx_node) else size // 2 def estimate_nccl_collective_runtime_from_fx_node( diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 0649e36f23361..b7617038f4e6a 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -11,7 +11,7 @@ import torch import torch.fx as fx from torch._dynamo.utils import counters, dynamo_timed -from torch._inductor.comm_analysis import estimate_fx_collective_size +from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor from torch._inductor.fx_passes.memory_estimator import ( _is_releasable, @@ -45,21 +45,26 @@ def get_group_name(n: fx.Node) -> str: def get_custom_estimation( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + override_size: int | None = None, ) -> float | None: if custom_runtime_estimation is None: return None - return custom_runtime_estimation(n) + return custom_runtime_estimation(n, override_size) def estimate_collective_time( n: fx.Node, override_size: int | None = None, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: """Estimate the runtime of a collective operation, optionally with an overridden size.""" - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, override_size) + ) is not None: return est # Use analytical model (benchmarking is handled separately in alignment) @@ -99,7 +104,8 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]: def benchmark_node_with_cache_key( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> tuple[float, str | None]: """Benchmark a compute node and return (runtime, cache_key).""" assert is_compute_node(n) @@ -142,7 +148,9 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: if unbacked_tensor: return 0, key - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, None) + ) is not None: set_cached_node_time(key, est) return est, key @@ -154,7 +162,8 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: def benchmark_node( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0] @@ -236,7 +245,7 @@ def __init__( insert_overlap_deps: bool, compute_overlap_multipler: float, max_coll_distance: int, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None, collective_estimator: Literal["analytical", "benchmark"], ): self.gm = gm @@ -318,7 +327,7 @@ def _identify_collectives(self) -> None: info = CollectiveInfo( start_node=start, wait_node=node, - size_bytes=estimate_fx_collective_size(start), + size_bytes=estimate_fx_collective_memory_footprint(start), estimated_time_ms=coll_time_ms, exposed_time_ms=coll_time_ms, # Initially fully exposed ) @@ -431,7 +440,10 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( # Benchmark CUDA events (non-deterministic, needs alignment) # Skip collectives with custom estimation for n in collective_nodes: - if get_custom_estimation(n, self.custom_runtime_estimation) is not None: + if ( + get_custom_estimation(n, self.custom_runtime_estimation, None) + is not None + ): continue # Benchmark actual size @@ -1000,7 +1012,8 @@ def schedule_overlap_bucketing( insert_overlap_deps: bool = False, compute_overlap_multipler: float = 1.0, max_coll_distance: int = 1000, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, collective_estimator: Literal["analytical", "benchmark"] = "analytical", ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. From 4c5042b3682ae05a287e705743359ef04b9e6fe1 Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 18 Nov 2025 09:03:35 -0800 Subject: [PATCH 010/230] Fix all gather bucketing fusion in of dtype casts (#167853) The all gather bucketing was part of the way to fusing in dtype casts into the bucket. We do this by allocating the group bucket buffer, then viewing each slice of it as the destination dtype. We then foreach_copy_ into the allocated buffer, with each collective copying in to its destination dtype. This logic was causing an issue in a later part of the stack, but not fully firing, so might as well fix it. Note: custom ops dont yet support list[dtype], so i worked around by list[int], but will fix in a follow up. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167853 Approved by: https://github.com/ruisizhang123 ghstack dependencies: #167852 --- .../test_aten_comm_compute_reordering.py | 50 ++++++++++++ torch/_inductor/fx_passes/bucketing.py | 76 ++++++++++++++----- 2 files changed, 108 insertions(+), 18 deletions(-) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 97cb8c02c8b1b..a60d3868e4f82 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -1162,6 +1162,56 @@ def func(a, b, *, ranks): correct = func(a, b, ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_bucketing_with_convert_dtype(self): + """Test that all_gathers with dtype conversion get bucketed and produce correct results.""" + + def func(a, b, c, d, *, ranks): + # Convert inputs to float16 before all_gather + a_fp16 = a.to(torch.float16) + b_fp16 = b.to(torch.float16) + + # Two all_gathers with converted dtypes + ag1 = _functional_collectives.all_gather_tensor(a_fp16, 0, ranks) + ag2 = _functional_collectives.all_gather_tensor(b_fp16, 0, ranks) + + # same dtype + ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + ag4 = _functional_collectives.all_gather_tensor(d, 0, ranks) + + return ag1, ag2, ag3, ag4 + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(4, 4, dtype=torch.float32, device=device_type) + b = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 2 + c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3 + d = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 4 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d) + + # Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together) + FileCheck().check_count( + "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True + ).run(aten_graph_str) + + # Verify convert_element_type ops are removed (dtype conversion handled by _pre_bucket_all_gather) + FileCheck().check_not("torch.ops.prims.convert_element_type").run( + aten_graph_str + ) + + # Verify correctness - this tests that dtype conversion is handled correctly + correct = func(a, b, c, d, ranks=ranks) + self.assertTrue(same(out, correct)) + def get_toy_model(device_type: str): """ diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 5641c4294356f..00737a3b6e3b7 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -489,15 +489,34 @@ def all_reduce_merge_fn_to_trace( return new_outs +# List of all torch dtypes for serialization through custom ops +# TODO: custom ops support list[dtype] input +_ALL_DTYPES = tuple( + [ + getattr(torch, attr) + for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) + ] +) + + @torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={}) def _pre_bucket_all_gather( ag_ins: list[torch.Tensor], group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[ + int + ], # dtype enum values, that inputs are converted to before all_gather rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + # Convert int indices back to torch.dtype + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes @@ -507,8 +526,14 @@ def _pre_bucket_all_gather( new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) - ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins] - torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + # View each destination slice as its output dtype, then copy + # The copy operation handles dtype conversion from input dtype to output dtype + foreach_copy_dsts_typed = [ + dst.view(out_dtype) + for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True) + ] + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened) return new_ag_out @@ -517,9 +542,14 @@ def _pre_bucket_all_gather_fake( group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[int], rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes @@ -541,12 +571,9 @@ def all_gather_merge_fn_to_trace_custom_ops( out_dtypes: list[torch.dtype], # type: ignore[name-defined] rank: int, ) -> list[torch.Tensor]: - ag_ins = [ - torch._prims.convert_element_type(_ag_in, out_dtype) - if _ag_in.dtype != out_dtype - else _ag_in - for _ag_in, out_dtype in zip(_ag_ins, out_dtypes) - ] + # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion + # by viewing destination slices as output dtypes and letting copy do the conversion + ag_ins = _ag_ins ins_sizes = [ag_in.shape for ag_in in ag_ins] ins_split_sizes_bytes = [ ag_in.numel() * out_dtype.itemsize @@ -557,8 +584,13 @@ def all_gather_merge_fn_to_trace_custom_ops( _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes ] ag_input_numel = sum(ins_split_sizes) + + # Convert out_dtypes to indices for custom_op + # TODO: custom ops support list[dtype] input + out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes] + new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( - ag_ins, group_size, group_name, dtype, rank + ag_ins, group_size, group_name, dtype, out_dtype_ints, rank ) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) wait_tensor = torch.ops.c10d_functional.wait_tensor( @@ -721,6 +753,20 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def] return replacements, new_nodes +def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool: + node_in = n.args[0] + return ( + is_all_gather_into_tensor(n) + and isinstance(node_in, torch.fx.Node) + and node_in.op == "call_function" + and ( + node_in.target is torch.ops.prims.convert_element_type.default + or node_in.target is torch.ops.aten._to_copy.default + ) + and len(node_in.users) == 1 + ) + + def process_collective_bucket( g: torch.fx.Graph, bucket_nodes: list[torch.fx.Node], @@ -755,13 +801,7 @@ def process_collective_bucket( # Handle convert_element_type operations (for all_gather) node_in = n.args[0] - if ( - is_all_gather_into_tensor(n) - and isinstance(node_in, torch.fx.Node) # Add type check - and node_in.op == "call_function" - and node_in.target is torch.ops.prims.convert_element_type.default - and len(node_in.users) == 1 - ): + if has_mergeable_all_gather_convert_dtype(n): ag_node_to_pre_nodes[n].append(node_in) node_in = node_in.args[0] From dda2cb3769f6bab6114da5951162c4ec7d705701 Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 18 Nov 2025 09:34:58 -0800 Subject: [PATCH 011/230] Handled erased hiding nodes from dtype bucketing (#167863) The bucketing dtype fusing was causing nodes which had dependencies to be erased. Transfer those deps over to the new nodes, and also add an assertion that none of our deps are erased to catch this type of error in the future. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167863 Approved by: https://github.com/fmassa ghstack dependencies: #167852, #167853 --- .../test_overlap_bucketing_unit.py | 88 +++++++++++++++++++ .../fx_passes/overlap_preserving_bucketer.py | 34 ++++++- 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index de6f2ba612977..c0c4c31cc1a81 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -667,6 +667,94 @@ def func(a, b): str(traced.graph) ) + def test_can_bucket_with_convert_dtype_as_hiding_nodes(self): + """ + Test that all_gathers can bucket when convert_element_type ops ARE the hiding nodes. + + Graph structure: + ag1_start -> convert1 (hides ag1) -> ag1_wait -> ag2_start -> convert2 (hides ag2) -> ag2_wait + + The convert_element_type ops ARE hiding nodes - no matmuls. + This tests that dependencies are transferred correctly when convert nodes are erased. + """ + + def func(a, b, c): + group_name = "0" + group_size = 1 + + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + b = torch.ops.prims.convert_element_type.default(b, torch.float16) + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, group_name + ) + ag3 = torch.ops._c10d_functional.all_gather_into_tensor( + c, group_size, group_name + ) + + mm = ag1_out @ ag1_out + + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + ag3_out = torch.ops._c10d_functional.wait_tensor(ag3) + + return ag1_out, ag2_out, ag3_out, mm + + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device, dtype=torch.float32) + b = torch.ones(4, 4, device=self.device, dtype=torch.float32) + c = torch.ones(4, 4, device=self.device, dtype=torch.float32) + + traced = make_fx(func)(a, b, c) + + # Find nodes + ag1, ag2, ag3 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + convert1 = traced.graph.find_nodes( + op="call_function", + target=torch.ops.prims.convert_element_type.default, + )[0] + mm = traced.graph.find_nodes( + op="call_function", + target=torch.ops.aten.mm.default, + )[0] + + hiding_annotations = { + ag1: convert1, + ag2: mm, + ag3: mm, + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + graph_str = str(traced.graph) + + f = FileCheck() + f.check_count("%all_gather_into_tensor", 1, exactly=True) + f.check("pre_bucket_all_gather").check("wait_tensor").check( + "%all_gather_into_tensor_out" + ).run(graph_str) + if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 4060a29c7c3db..b6cbf32bfba8e 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -3,12 +3,14 @@ from dataclasses import dataclass from typing import Any, Literal, Optional +import torch import torch.fx as fx from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( bucket_key, BucketMode, + has_mergeable_all_gather_convert_dtype, is_all_gather_into_tensor as is_all_gather, is_reduce_scatter_tensor as is_reduce_scatter, is_wait_tensor, @@ -207,6 +209,7 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: prev_event = event position += 1 + return head def _populate_node_to_event(self, pg: str) -> None: @@ -231,7 +234,6 @@ def _add_hiding_interval_constraints(self) -> None: self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn) def bucket_collectives(self) -> None: - """Main entry point for bucketing collectives.""" # Group collectives by PG first pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) for start in self.collective_info: @@ -281,6 +283,15 @@ def bucket_collectives(self) -> None: # Apply topological sort with all dependencies from torch._dynamo.graph_deduplication import _stable_topological_sort + for n, deps in additional_deps.items(): + torch._check( + not n._erased, lambda: f"Erased node deps not transferred: {n}" + ) + for d in deps: + torch._check( + not d._erased, lambda: f"Erased node deps not transferred: {d}" + ) + _stable_topological_sort(self.graph, additional_deps) # After topological sort, preserve dependencies using effect tokens @@ -762,6 +773,11 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: old_starts = list(bucket) old_waits = [self.collective_info[n].wait_node for n in bucket] + fused_convert_dtypes = [] + for n in old_starts: + if has_mergeable_all_gather_convert_dtype(n): + fused_convert_dtypes.append(n.args[0]) + # Find where to place the bucketed operations next_node = bucket[0] while next_node in bucket: @@ -809,6 +825,22 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: for old_wait in old_waits: erased_to_new[old_wait] = new_wait + # Handle convert_element_type nodes that were fused and erased + # The bucketed operation may have a _pre_bucket op that handles dtype conversion + if fused_convert_dtypes: + # all gather bucketing may fuse in dtype conversion into the bucketing + # if so, we need to transfer hiding deps from the old dtype conversion + # to the new bucketing node + new_convert_dtypes_node = new_start.kwargs["out"] + assert isinstance(new_convert_dtypes_node, fx.Node) + assert ( + new_convert_dtypes_node.target + == torch.ops.bucketing._pre_bucket_all_gather.default + ) + + for n in fused_convert_dtypes: + erased_to_new[n] = new_convert_dtypes_node + # Transfer all dependencies from old nodes to new nodes self.aug_graph.transfer_erased_node_deps(erased_to_new) From 7921c0eb0eb79d5d9ea687c7858ba81429ca119f Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:04:20 +0000 Subject: [PATCH 012/230] [ROCm][CI] Limit caching to ROCm jammy docker images (#168088) Since the currently intended workflow on the new MI3xx CI capacity is [trunk-rocm-mi300.yml](https://github.com/pytorch/pytorch/blob/d91269e8ce309437c1f849b5ab3362d69b178ef4/.github/workflows/trunk-rocm-mi300.yml#L54), which only needs the jammy images, limiting those to optimize docker caching times. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168088 Approved by: https://github.com/jeffdaily --- .github/workflows/docker-cache-rocm.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 78d38de3ac69a..8039b4f71087b 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -50,9 +50,10 @@ jobs: matrix: runner: [linux.rocm.gfx942.docker-cache] docker-image: [ - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" + "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}" + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" ] runs-on: "${{ matrix.runner }}" steps: From ae85307512c582bbe073f5ab9c81a032e95fcfba Mon Sep 17 00:00:00 2001 From: Justin Turney Date: Tue, 18 Nov 2025 20:06:26 +0000 Subject: [PATCH 013/230] huber_loss numerical issue (#166952) For GPU: Previously reported that only a single sample could be tested with huber_loss functional. Current snapshot of the code does not appear to suffer from numerical issues as reported before. For CPU: While testing GPU, it was discovered that with Half appears to be numerically unstable. This commit resolves issue with CPU by upcasting Half to float for the computation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166952 Approved by: https://github.com/benjaminglass1, https://github.com/isuruf --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 40 +++++++++++++++++++- test/inductor/test_torchinductor_opinfo.py | 6 --- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 221f621ea1e06..b5f3d91692b9a 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -813,8 +813,43 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) { } void huber_kernel(TensorIterator& iter, double delta) { - AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() { + // Special-case kHalf: compute in float for numerical stability + if (iter.dtype() == kHalf) { + const float delta_val(static_cast(delta)); + const Vectorized delta_vec(static_cast(delta)); + const Vectorized point_five_vec(static_cast(0.5)); + cpu_kernel_vec( + iter, + // scalar lambda: convert half -> float, compute in float, cast back to half + [&delta_val] (at::Half a, at::Half b) -> at::Half { + float af = static_cast(a); + float bf = static_cast(b); + float z = std::abs(af - bf); + float out = z < delta_val + ? 0.5f * z * z + : delta_val * (z - 0.5f * delta_val); + return static_cast(out); + }, + [&delta_vec, &point_five_vec] (Vectorized a, Vectorized b) { + auto [a0, a1] = convert_half_float(a); + auto [b0, b1] = convert_half_float(b); + auto z = (a0 - b0).abs(); + a0 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + z = (a1 - b1).abs(); + a1 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + return convert_float_half(a0, a1); + } + ); + return; + } + else { + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() { using Vec = Vectorized; const scalar_t delta_val(delta); const Vec delta_val_vec(delta_val); @@ -835,6 +870,7 @@ void huber_kernel(TensorIterator& iter, double delta) { z >= delta_val_vec); }); }); + } } void sigmoid_backward_kernel(TensorIteratorBase& iter) { diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 1c9b39a1bd08d..d1b62feed3b41 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -828,9 +828,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, @@ -948,9 +945,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, From ebb2001a489181dfe5c879a5c78cde3b4bc201e4 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Tue, 18 Nov 2025 20:21:48 +0000 Subject: [PATCH 014/230] [codemod][lowrisk] Remove unused exception parameter from caffe2/torch/csrc/Exceptions.h (#168056) Summary: `-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it. This: ``` try { ... } catch (exception& e) { // no use of e } ``` should instead be written as ``` } catch (exception&) { ``` If the code compiles, this is safe to land. Test Plan: Sandcastle Reviewed By: dtolnay Differential Revision: D87273132 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168056 Approved by: https://github.com/malfet, https://github.com/Skylion007 --- torch/csrc/Exceptions.h | 2 +- torch/csrc/distributed/c10d/ProcessGroup.cpp | 2 +- torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp | 4 ++-- torch/csrc/fx/node.cpp | 8 ++++---- torch/csrc/jit/python/pybind.h | 4 ++-- torch/csrc/stable/stableivalue_conversions.h | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index d580809460811..adba98beb2724 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -138,7 +138,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { throw; \ } \ } \ - catch (const std::exception& e) { \ + catch (const std::exception&) { \ torch::translate_exception_to_python(std::current_exception()); \ return retval; \ } diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 9f79a09d236e5..b888e315021ac 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -81,7 +81,7 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED}; try { backendType = deviceTypeToBackendType_.at(deviceType); - } catch (const std::out_of_range& e) { + } catch (const std::out_of_range&) { TORCH_CHECK( false, "No backend type associated with device type ", deviceType); } diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index f3ff9e623043e..7427848b8445b 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -246,7 +246,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } @@ -322,7 +322,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 11659cc24eb89..117324796e7f8 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -353,7 +353,7 @@ static PyObject* NodeBase__update_args_kwargs( Py_CLEAR(node->_kwargs); node->_kwargs = map_aggregate(args[1], visit_fn); Py_RETURN_NONE; - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -397,7 +397,7 @@ static PyObject* NodeBase__replace_input_with( PyObject* update_args[2] = {new_args.get(), new_kwargs.get()}; return NodeBase__update_args_kwargs(self, update_args, 2); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -802,7 +802,7 @@ static PyObject* py_map_aggregate( // args[0]: aggregate, args[1]: callable fn return map_aggregate( args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } @@ -824,7 +824,7 @@ static PyObject* py_map_arg( } return Py_NewRef(a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } diff --git a/torch/csrc/jit/python/pybind.h b/torch/csrc/jit/python/pybind.h index 066ff7f77f56c..845beb540c9f1 100644 --- a/torch/csrc/jit/python/pybind.h +++ b/torch/csrc/jit/python/pybind.h @@ -117,7 +117,7 @@ struct type_caster { try { value = torch::jit::toTypeInferredIValue(src); return true; - } catch (std::exception& e) { + } catch (std::exception&) { return false; } } @@ -142,7 +142,7 @@ struct type_caster { std::string src_str; try { src_str = py::cast(src); - } catch (std::exception& e) { + } catch (std::exception&) { return false; } value = torch::jit::Symbol::fromQualString(src_str); diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 15ac8e539e76b..38fc4fe4cc8bf 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -285,7 +285,7 @@ struct FromImpl> { torch_list_push_back(new_list_handle, from(elem))); } return from(new_list_handle); - } catch (const std::runtime_error& e) { + } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle)); @@ -553,7 +553,7 @@ struct ToImpl> { } TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return result; - } catch (const std::runtime_error& e) { + } catch (const std::runtime_error&) { // clean up memory if an exception is thrown, and rethrow TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); throw; From 41999a579d0fe2f74cc0c34f89441e2073f7dd3e Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Tue, 18 Nov 2025 22:02:02 +0000 Subject: [PATCH 015/230] Fix Tensor use_count check in VariableType.cpp (#168060) Summary: If the Tensor has a PyObject, it's use count will now be two instead of one. Test Plan: `buck test -j 18 fbcode//mode/dev-nosan fbcode//caffe2/test:torch` Differential Revision: D87297965 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168060 Approved by: https://github.com/albanD, https://github.com/Skylion007 --- tools/autograd/gen_variable_type.py | 8 ++++---- tools/autograd/templates/VariableType.cpp | 12 ++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4796153f24f05..e1a518aca6704 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -421,7 +421,7 @@ # inplace or out-variants) # If the function does not modify its arguments, we also check the following properties # pertaining to its output: -# 2) Its TensorImpl has use_count of 1 +# 2) Its TensorImpl has use_count of 1 (or 2 if it has a PyObject) # 3) If the function is a view function, it has the same StorageImpl as that of # the input it is aliased with. Otherwise, its StorageImpl has use_count of 1 # @@ -496,10 +496,10 @@ """ ) -ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate( +ENFORCE_TENSOR_IMPL_USE_COUNT = CodeTemplate( """\ if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) - TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}"); + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() == expected_fresh_use_count(${tensor_name}), "function: ${fn_name}"); """ ) @@ -1664,7 +1664,7 @@ def check_tensorimpl_and_storage( if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: stmts_after_call += [ - ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute( + ENFORCE_TENSOR_IMPL_USE_COUNT.substitute( tensor_name=ret_name, fn_name=type_wrapper_name(f) ) ] diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 23976a48473a3..d1de108283b11 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -47,6 +47,18 @@ namespace{ meta->grad_accumulator_.reset(); } } +[[maybe_unused]] size_t expected_fresh_use_count(const Variable& self) { + if (!self.defined()) { + // An UndefinedTensorImpl always has a use count of 0 + return 0; + } + if (self.unsafeGetTensorImpl()->pyobj_slot()->load_pyobj() != nullptr) { + // A TensorImpl with a Python object has a use count of 2 + return 2; + } + // A fresh TensorImpl (with no PyObject) has a use count of 1 + return 1; +} } namespace { From e8970ba0105fa922bf154156c360da77a2b5bf89 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Nov 2025 13:50:17 -0800 Subject: [PATCH 016/230] [CI] Migrate all gcc9 jobs to gcc11 (#167933) As compiler has not been supported for last 3 years and all manylinux2_28 builds should have at least gcc-11 Prep change for C++20 standard migration Pull Request resolved: https://github.com/pytorch/pytorch/pull/167933 Approved by: https://github.com/yangw-dev, https://github.com/atalman ghstack dependencies: #168090 --- .ci/docker/build.sh | 14 +---- .../workflows/attention_op_microbenchmark.yml | 8 +-- .github/workflows/docker-builds.yml | 3 +- .../workflows/inductor-micro-benchmark.yml | 10 ++-- .github/workflows/inductor-perf-compare.yml | 10 ++-- .github/workflows/inductor-perf-test-b200.yml | 18 +++---- .../inductor-perf-test-nightly-h100.yml | 10 ++-- .../workflows/inductor-perf-test-nightly.yml | 18 +++---- .github/workflows/inductor-periodic.yml | 12 ++--- .github/workflows/inductor-unittest.yml | 6 +-- .github/workflows/inductor.yml | 6 +-- .github/workflows/operator_microbenchmark.yml | 8 +-- .github/workflows/periodic.yml | 51 +++++-------------- .github/workflows/pull.yml | 20 ++++---- .github/workflows/torchbench.yml | 10 ++-- .github/workflows/trunk.yml | 4 +- 16 files changed, 87 insertions(+), 121 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index b7e61115e37d6..748608005e622 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -125,10 +125,10 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks) + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 + GCC_VERSION=11 VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} @@ -146,16 +146,6 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) - CUDA_VERSION=12.8.1 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - ;; pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 diff --git a/.github/workflows/attention_op_microbenchmark.yml b/.github/workflows/attention_op_microbenchmark.yml index e01bc49621dcf..eec4d21fe2616 100644 --- a/.github/workflows/attention_op_microbenchmark.yml +++ b/.github/workflows/attention_op_microbenchmark.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -39,7 +39,7 @@ jobs: needs: attn-microbenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }} secrets: inherit @@ -51,7 +51,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -66,7 +66,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 5700c8e3c74b3..fa1f083800fe0 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -52,8 +52,7 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, pytorch-linux-jammy-py3.11-clang12, diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index a0ae234ab5669..3421e2b9af77d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -30,14 +30,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -46,11 +46,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 628f624240127..764e631819ccc 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -27,14 +27,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -47,11 +47,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} # disable monitor in perf tests for more investigation diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml index 7b59e92386a33..11f5f10a55ad8 100644 --- a/.github/workflows/inductor-perf-test-b200.yml +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -80,7 +80,7 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -90,8 +90,8 @@ jobs: # from trunk. Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '10.0' test-matrix: | { include: [ @@ -104,12 +104,12 @@ jobs: secrets: inherit test-periodically: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -121,12 +121,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -138,11 +138,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 8209bf053a772..1c35fc6794537 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -95,8 +95,8 @@ jobs: # from trunk. Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '9.0' test-matrix: | { include: [ @@ -132,7 +132,7 @@ jobs: needs: build if: github.event.schedule == '15 0 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -149,7 +149,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -168,7 +168,7 @@ jobs: # needs one round of benchmark if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }} with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 19f72ba453414..88a528ba1b075 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -80,15 +80,15 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" # Every bit to make perf run faster helps runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -117,12 +117,12 @@ jobs: secrets: inherit test-nightly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -133,12 +133,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -150,12 +150,12 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index b08d9865d15d3..f3e34d6ecb52f 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -37,8 +37,8 @@ jobs: needs: get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0;8.6' test-matrix: | { include: [ @@ -76,7 +76,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-dynamo-benchmarks-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.periodic-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit @@ -138,8 +138,8 @@ jobs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -153,7 +153,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-smoke-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.inductor-smoke-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index ca9b57cab2ddb..0902026adb8ce 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -33,8 +33,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -52,7 +52,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 8a913c3b36a11..e524ed548b741 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -49,8 +49,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -69,7 +69,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index 89d6d63c72875..dd5cd832570f9 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -25,7 +25,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -41,7 +41,7 @@ jobs: needs: opmicrobenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} secrets: inherit @@ -53,7 +53,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -68,7 +68,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 5a90db9ab5737..325050392a393 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -90,6 +90,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: 8.6 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -97,7 +98,9 @@ jobs: { config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, + { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, ]} secrets: inherit @@ -113,40 +116,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-build: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 - cuda-arch-list: 8.6 - test-matrix: | - { include: [ - { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-test: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-build - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.test-matrix }} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.9 test-matrix: | { include: [ @@ -160,16 +137,16 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-debug-test: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build + - linux-jammy-cuda12_8-py3_10-gcc11-debug-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.test-matrix }} secrets: inherit linux-jammy-cuda13_0-py3_10-gcc11-build: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 51e211a5ad2ad..f2483dff9a94c 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -318,14 +318,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: - name: cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-build: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '7.5' test-matrix: | { include: [ @@ -333,14 +333,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: - name: cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-test: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_10-gcc11-inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit linux-noble-xpu-n-py3_10-build: diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 08fcd33402625..5a0273f0b745e 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -26,14 +26,14 @@ jobs: curr_ref_type: ${{ github.ref_type }} build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -42,11 +42,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 6e775da47fc1e..eeba4c08a0c68 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -231,8 +231,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' secrets: inherit From dc4f3c7505a810322db51e68800b477ed2147947 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 4 Nov 2025 12:18:37 -0600 Subject: [PATCH 017/230] [MPS] Move `elu` impl to Metal (#166903) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166903 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/MetalShaderLibrary.h | 13 ++ aten/src/ATen/native/mps/OperationUtils.h | 145 +++++++++++++ aten/src/ATen/native/mps/kernels/Activation.h | 16 ++ .../native/mps/kernels/ActivationKernel.metal | 54 +++++ .../ATen/native/mps/operations/Activation.mm | 190 ------------------ .../native/mps/operations/ActivationKernel.mm | 28 +++ aten/src/ATen/native/native_functions.yaml | 6 +- 7 files changed, 258 insertions(+), 194 deletions(-) create mode 100644 aten/src/ATen/native/mps/kernels/Activation.h diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index d9f126938b301..fcdf39b8a9f4b 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -147,6 +147,19 @@ class MetalShaderLibrary { const std::optional alpha = std::nullopt, const std::optional scalar_arg_type = std::nullopt); + template + void exec_unary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + template + void exec_binary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + protected: virtual MTLLibrary_t getLibrary(); virtual MTLLibrary_t getLibrary( diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index cb488a3f5f117..5ca0ebe3de9bb 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -7,10 +7,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -630,4 +632,147 @@ inline bool needsGather(const TensorBase& t) { return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); } +template +void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace at::mps; + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_unary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto inputTensor = iter.input(0); + auto outputTensor = iter.output(0); + uint32_t length = iter.numel(); + if (length == 0) { + return; + } + auto kernel_name = fmt::format("{}_{}_{}_{}{}", + name, + iter.is_contiguous() ? "dense" : "strided", + scalarToMetalTypeString(outputTensor), + scalarToMetalTypeString(inputTensor), + fmt::format("_{}", params_type_name)); + @autoreleasepool { + auto cplState = getPipelineStateForFunc(kernel_name); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + auto computeEncoder = mpsStream->commandEncoder(); + + getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor}); + + [computeEncoder setComputePipelineState:cplState]; + bind_iter_tensors(computeEncoder, iter); + if (!iter.is_contiguous()) { + mtl_setArgs<2>(computeEncoder, + outputTensor.sizes(), + inputTensor.strides(), + outputTensor.strides(), + inputTensor.ndimension()); + } + detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6); + mtl_dispatch1DJob(computeEncoder, cplState, length); + + getMPSProfiler().endProfileKernel(cplState); + }); + } +} + +template +void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace mps; + // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?) + // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with + // double as common dtype (because Python floating point are always 64-bit values) + TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS"); + + // Skip for empty iterators + if (iter.numel() == 0) { + return; + } + + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_binary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto convert_double_scalar = [](Tensor& t) { + if (t.dim() != 0) { + return; + } + if (t.scalar_type() == kDouble) { + t = t.to(kFloat); + } else if (t.scalar_type() == kComplexDouble) { + t = t.to(kComplexFloat); + } + }; + + Tensor input = iter.input(0); + Tensor other = iter.input(1); + Tensor out = iter.output(); + + convert_double_scalar(input); + convert_double_scalar(other); + + MPSStream* mpsStream = getCurrentMPSStream(); + const auto cast_needed = input.scalar_type() != other.scalar_type(); + const auto suffix = iter.is_contiguous() ? "dense" : "strided"; + // TODO: Implicitly pass both input and output types to non-cast kernels + const auto kernel_name = cast_needed + ? fmt::format("{}_{}_cast_{}_{}", name, suffix, scalarToMetalTypeString(out), params_type_name) + : fmt::format("{}_{}_{}_{}_{}", + name, + suffix, + scalarToMetalTypeString(out), + scalarToMetalTypeString(input), + params_type_name); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = mpsStream->commandEncoder(); + auto binaryPSO = getPipelineStateForFunc(kernel_name); + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other}); + [computeEncoder setComputePipelineState:binaryPSO]; + // Set input and output tensors + bind_iter_tensors(computeEncoder, iter); + // Iterator is contiguous if all of its elements are dense in storage, + // i.e. it's true for both row-first and column-first tensors + if (iter.is_contiguous()) { + detail::mtl_setArg(computeEncoder, params, 3); + if (cast_needed) { + std::array size_and_types = {static_cast(c10::elementSize(input.scalar_type())), + static_cast(c10::elementSize(other.scalar_type())), + static_cast(input.scalar_type()), + static_cast(other.scalar_type())}; + mtl_setBytes(computeEncoder, size_and_types, 4); + } + } else { + // Please note that shapes and strides of the iterator might be + // different than that of its operands, for example binary op + // between 4x4 tensor and scalar will result in 1D 16 element iterator + std::array ndim_and_types = {iter.ndim(), + static_cast(input.scalar_type()), + static_cast(other.scalar_type()), + static_cast(out.scalar_type())}; + mtl_setArgs<3>( + computeEncoder, params, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types); + } + mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel()); + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); +} + } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/kernels/Activation.h b/aten/src/ATen/native/mps/kernels/Activation.h new file mode 100644 index 0000000000000..34ad90dd7a2a3 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Activation.h @@ -0,0 +1,16 @@ +#pragma once + +template +struct ELUParams { + T alpha; + T scale; + T input_scale; +}; + +template +struct ELUBackwardParams { + T alpha; + T scale; + T input_scale; + bool is_result; +}; diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index ae1fda66c3b38..7d1f3aa5bacf6 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -1,3 +1,4 @@ +#include #include #include #include @@ -99,6 +100,59 @@ REGISTER_BINARY_OP(hardswish_backward, float, float); REGISTER_BINARY_OP(hardswish_backward, half, half); REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +struct elu_functor { + template + inline T operator()(const T self_, const ELUParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto self = static_cast(self_); + auto neg_res = alpha * (::metal::precise::exp(self * input_scale) - 1); + return static_cast(scale * (self < 0 ? neg_res : self)); + } +}; + +struct elu_backward_functor { + template + inline T operator()( + const T grad_output_, + const T self_, + ELUBackwardParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto grad_output = static_cast(grad_output_); + auto self = static_cast(self_); + + if (params.is_result) { + auto neg_coef = input_scale * (self + alpha * scale); + return static_cast(grad_output * (self <= 0 ? neg_coef : scale)); + } else { + auto neg_coef = input_scale * alpha * scale * + ::metal::precise::exp(self * input_scale); + return static_cast(grad_output * (self <= 0 ? neg_coef : scale)); + } + } +}; + +#define REGISTER_ELU_OP(T) \ + typedef ELUParams ELUParams_##T; \ + REGISTER_UNARY_ALPHA_OP(elu, T, ELUParams_##T, T); + +REGISTER_ELU_OP(float); +REGISTER_ELU_OP(half); +REGISTER_ELU_OP(bfloat); + +#define REGISTER_ELU_BACKWARD_OP(T) \ + typedef ELUBackwardParams ELUBackwardParams_##T; \ + REGISTER_BINARY_ALPHA_OP(elu_backward, T, ELUBackwardParams_##T, T); + +REGISTER_ELU_BACKWARD_OP(float); +REGISTER_ELU_BACKWARD_OP(half); +REGISTER_ELU_BACKWARD_OP(bfloat); + struct leaky_relu_functor { template inline T operator()(const T x, const T negative_slope) { diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index e437ea5ed7989..64ef972b55530 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -11,8 +11,6 @@ #include #include #include -#include -#include #include #include #include @@ -698,194 +696,6 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } } -static void elu_variants_out_mps(const Tensor& self, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - const Tensor& result, - std::string func_name) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - auto resultMemFormat = result.suggest_memory_format(); - bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat)); - Tensor out; - if (executeGatherOp) { - out = at::empty_like(result, MemoryFormat::Contiguous); - } - - // Empty output - if (result.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to()) + ":" + - std::to_string(scale.to()) + ":" + std::to_string(input_scale.to()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) - - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - MPSGraphTensor* exponentMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:exponentTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* alphaTimesTensor = [mpsGraph multiplicationWithPrimaryTensor:exponentMinusOneTensor - secondaryTensor:alphaTensor - name:nil]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* fusedOutput = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:inputTensor - falsePredicateTensor:alphaTimesTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:fusedOutput - secondaryTensor:scaleTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - if (out.has_storage()) { - result.copy_(out); - } - } -} - -// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) -TORCH_IMPL_FUNC(elu_out_mps) -(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result) { - elu_variants_out_mps(self, alpha, scale, input_scale, result, "elu_out_mps"); -} - -TORCH_IMPL_FUNC(elu_backward_out_mps) -(const Tensor& grad_output, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - bool is_result, - const Tensor& self_or_result, - const Tensor& grad_input) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - auto gradMemFormat = grad_input.suggest_memory_format(); - bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && - grad_input.is_contiguous(gradMemFormat)); - Tensor out; - if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) { - out = at::empty_like(grad_input, MemoryFormat::Contiguous); - } - - // Empty output - if (grad_input.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - std::to_string(alpha.to()) + ":" + std::to_string(scale.to()) + ":" + - std::to_string(input_scale.to()) + ":" + std::to_string(is_result); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result); - MPSGraphTensor* lessThanZeroGradTensor = nil; - - if (is_result) { - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor - secondaryTensor:alphaTensor - name:nil]; - auto constMul = scale.to() * input_scale.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:resultPlusAlphaTensor - secondaryTensor:constMulTensor - name:nil]; - } else { - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - auto constMul = scale.to() * input_scale.to() * alpha.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor - secondaryTensor:constMulTensor - name:nil]; - } - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:scaleTensor - falsePredicateTensor:lessThanZeroGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = selfOrResultTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp); - Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp); - Placeholder gradInputPlaceholder = - Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false); - - auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfOrResultPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder); - if (out.has_storage()) { - grad_input.copy_(out); - } - } -} - TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm index cec8bfa2312e4..f6d3ad986ade0 100644 --- a/aten/src/ATen/native/mps/operations/ActivationKernel.mm +++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm @@ -1,8 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include #include +#include #include namespace at::native { @@ -41,6 +43,30 @@ static void hardswish_backward_kernel(at::TensorIterator& iter) { lib.exec_binary_kernel(iter, "hardswish_backward"); } +static void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_mps", [&]() { + ELUParams params{alpha.to(), scale.to(), input_scale.to()}; + lib.exec_unary_kernel_with_params( + iter, "elu", params, fmt::format("ELUParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + +static void elu_backward_kernel(TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_backward_mps", [&]() { + ELUBackwardParams params{ + alpha.to(), scale.to(), input_scale.to(), is_result}; + lib.exec_binary_kernel_with_params( + iter, + "elu_backward", + params, + fmt::format("ELUBackwardParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { lib.exec_unary_kernel(iter, "leaky_relu", negative_slope); } @@ -56,6 +82,8 @@ static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& n REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel); REGISTER_DISPATCH(hardswish_stub, hardswish_kernel); REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel); +REGISTER_DISPATCH(elu_stub, elu_kernel); +REGISTER_DISPATCH(elu_backward_stub, elu_backward_kernel); REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel); REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9a1c7c790afaa..fd88794d38f52 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -12064,8 +12064,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: elu_out - MPS: elu_out_mps + CPU, CUDA, MPS: elu_out - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor structured_delegate: elu.out @@ -12078,8 +12077,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward_out - MPS: elu_backward_out_mps + CPU, CUDA, MPS: elu_backward_out - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor structured_delegate: elu_backward.grad_input From 1efc14a50d56701e3ad0639849bc26da59e7a3c3 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:06:09 +0000 Subject: [PATCH 018/230] [ROCm][CI] Update concurrency setting for docker-cache-rocm.yml (#168104) We only want to cache the latest CI docker image for `main` and `release` branches in cases where multiple `docker-builds` workflow runs get triggered in quick succession. This is because the latest run will anyway overwrite the cached images, since we do not maintain a cached image per-SHA, instead it's only one-per-branch (to minimize cache size and docker load times at runner bringup). Also removing `workflow_dispatch` as a trigger since it won't work (needs artifacts from `docker-builds` run) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168104 Approved by: https://github.com/jeffdaily --- .github/workflows/docker-cache-rocm.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 8039b4f71087b..c973656018944 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -6,10 +6,9 @@ on: branches: [main, release] types: - completed - workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }} + group: ${{ github.workflow }}-${{ github.event.workflow_run.head_branch }} cancel-in-progress: true permissions: From a4e0720fe251215235392f186d091f2cb4e4ac0b Mon Sep 17 00:00:00 2001 From: RajeshvShiyal Date: Wed, 19 Nov 2025 00:15:32 +0000 Subject: [PATCH 019/230] typo corrected in type.cpp (#167907) Fixes #167905 Below typo correction has been done. Existing comment: // List of Any can contains heterogenous types Suggested comment: // List of Any can contains heterogeneous types Pull Request resolved: https://github.com/pytorch/pytorch/pull/167907 Approved by: https://github.com/albanD --- aten/src/ATen/core/type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 46dc550b1f37b..35a729ccc9f39 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -680,7 +680,7 @@ TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) { return false; } if (elem_type->kind() == AnyType::Kind) { - // List of Any can contains heterogenous types + // List of Any can contains heterogeneous types return false; } return true; From a369a5672653f9e00e23bf7d000c217339d078c0 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 19 Nov 2025 00:36:47 +0000 Subject: [PATCH 020/230] [ROCm][CI] forward fix libtorch agnostic tests (#168087) Unclear which PR in the ghstack caused the ROCm failure. Stack was (oldest at bottom): - #167962 - #167804 - #167803 - #167802 - #168025 Fixes the following test: PYTORCH_TEST_WITH_ROCM=1 python test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py FunctionVersionCompatibilityTest.test_mv_tensor_accessor_cuda_works_with_2_9 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168087 Approved by: https://github.com/jeffdaily, https://github.com/janeyx99 Co-authored-by: Jeff Daily Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com> --- .../csrc/mv_tensor_accessor_cuda.cu | 4 ++++ .../test_version_compatibility.py | 21 +++++++++++++------ .../libtorch_agnostic_2_9/csrc/cuda_kernel.cu | 4 ++++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu index 7773210a089ee..f8d87f60d9a2e 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu @@ -3,7 +3,11 @@ #include "tensor_accessor_kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py index a094c57f8e614..05027a41b6715 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py @@ -22,9 +22,15 @@ from pathlib import Path from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -from torch.utils.cpp_extension import CUDA_HOME, include_paths as torch_include_paths +from torch.utils.cpp_extension import ( + CUDA_HOME, + include_paths as torch_include_paths, + ROCM_HOME, +) +GPU_HOME = CUDA_HOME or ROCM_HOME + # TODO: Fix this error in Windows: # numba.cuda.cudadrv.driver:driver.py:384 Call to cuInit results in CUDA_ERROR_NO_DEVICE if not IS_WINDOWS: @@ -42,8 +48,8 @@ def setUpClass(cls): f"-I{path}" for path in torch_include_paths(device_type="cpu") ] cls.cuda_includes = [] - if CUDA_HOME: - cuda_include_path = os.path.join(CUDA_HOME, "include") + if GPU_HOME: + cuda_include_path = os.path.join(GPU_HOME, "include") if os.path.exists(cuda_include_path): cls.cuda_includes = [f"-I{cuda_include_path}"] @@ -105,13 +111,13 @@ def _compile_cu_file( Compile a CUDA file with TORCH_TARGET_VERSION=2.9.0. Returns (success, error_message). """ - if not CUDA_HOME: - return False, "CUDA_HOME not set" + if not GPU_HOME: + return False, "one of CUDA_HOME and ROCM_HOME should be set but is not" torch_version_2_9 = "0x0209000000000000" cmd = [ - os.path.join(CUDA_HOME, "bin", "nvcc"), + os.path.join(GPU_HOME, "bin", "nvcc" if CUDA_HOME else "hipcc"), "-c", "-std=c++17", f"-DTORCH_TARGET_VERSION={torch_version_2_9}", @@ -120,6 +126,9 @@ def _compile_cu_file( *self.cuda_includes, ] + if ROCM_HOME: + cmd.extend(["-DUSE_ROCM=1"]) + cmd.extend([str(source_file), "-o", str(output_file)]) result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu index 88c19d0ebf062..1f549630262a6 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu @@ -1,6 +1,10 @@ #include "kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include From 878757cb664da5830631ae72be0f602ed3bbb5bc Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Wed, 19 Nov 2025 00:50:25 +0000 Subject: [PATCH 021/230] [CI][CUDA] Unskip nvshmem triton tests (#167760) Fixes false negative (illusion): "all B200 periodic nvshmem-triton tests passed" Pull Request resolved: https://github.com/pytorch/pytorch/pull/167760 Approved by: https://github.com/ngimel --- test/distributed/test_nvshmem_triton.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 3fec9a01f049c..ad30a7df5d43a 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -12,7 +12,6 @@ import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem from torch._inductor.runtime.triton_compat import triton from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem -from torch.testing._internal.common_cuda import SM100OrLater from torch.testing._internal.common_distributed import MultiProcContinuousTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -265,10 +264,6 @@ def my_reduce_kernel( nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) -@skip_but_pass_in_sandcastle_if( - SM100OrLater, - "Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897", -) @instantiate_parametrized_tests class NVSHMEMTritonTest(MultiProcContinuousTest): def _init_device(self) -> None: From c8d790b56d7bbb4ee266976944ab920f2d513b40 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 18 Nov 2025 21:24:02 +0000 Subject: [PATCH 022/230] [xpu][fix] Fix empty cache on mempool (#168074) # Motivation This is definitely a bug: we were attempting to release cached memory back to the system without proper **synchronization**. Callers must ensure that all accesses to memory blocks allocated by SYCL APIs have completed before invoking `sycl::free`. For a simple example, in the following code: ```python pool = torch.xpu.MemPool() with torch.xpu.use_mem_pool(pool): input = torch.randn(100, device='xpu') sum = input.sum() del pool print(sum) ``` `sum` may exhibit undefined behavior because `input.sum()` might not have finished executing before `del pool` triggers `input`'s memory release. With this fix, we ensure that all kernels on the associated streams complete before the memory pool is destroyed, guaranteeing that `sum` holds the correct value. # Solution Because `c10::xpu::syncStreamsOnDevice` has host overhead, we use a boolean flag `streams_synced` to ensure it is called only once. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168074 Approved by: https://github.com/EikanWang --- c10/xpu/XPUCachingAllocator.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 3bd9eff0fee63..b8838d9b7ee35 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -893,11 +893,13 @@ class DeviceCachingAllocator { } bool release_cached_blocks(MempoolId_t mempool_id) { + bool streams_synced = false; if (mempool_id.first == 0 && mempool_id.second == 0 && captures_underway.empty()) { synchronize_and_free_events(); // See Note [Safe to Free Blocks on BlockPool] c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; release_blocks(large_blocks); release_blocks(small_blocks); @@ -916,6 +918,12 @@ class DeviceCachingAllocator { continue; } } + + if (!streams_synced) { + // See Note [Safe to Free Blocks on BlockPool] + c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; + } TORCH_INTERNAL_ASSERT(it->second->use_count == 0); release_blocks(it->second->small_blocks); release_blocks(it->second->large_blocks); From 8f161997b1e0d8546fec92a9bc88b17104e5bfa5 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 19 Nov 2025 00:40:42 +0200 Subject: [PATCH 023/230] Fix stable ABI to/from deprecation warnings. Add my_shape test. (#167923) As in the title. The my_shape test is added to reproduce https://github.com/pytorch/audio/actions/runs/19395471276/job/55494871226: Pull Request resolved: https://github.com/pytorch/pytorch/pull/167923 Approved by: https://github.com/janeyx99, https://github.com/mikaylagawarecki --- .../libtorch_agnostic_2_10/csrc/my_shape.cpp | 20 +++++++++++++++++++ .../libtorch_agnostic_2_10/ops.py | 12 +++++++++++ test/cpp_extensions/test_libtorch_agnostic.py | 9 +++++++++ torch/csrc/stable/stableivalue_conversions.h | 6 +++--- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp new file mode 100644 index 0000000000000..c560fb0a60af9 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +using torch::stable::Tensor; + +torch::headeronly::HeaderOnlyArrayRef my_shape(Tensor t) { + return t.sizes(); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_shape(Tensor t) -> int[]"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl("my_shape", TORCH_BOX(&my_shape)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index db1a4fd43033c..a740df8c9e25f 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -199,6 +199,18 @@ def my_view(t, size) -> Tensor: return torch.ops.libtorch_agnostic_2_10.my_view.default(t, size) +def my_shape(t) -> tuple[int]: + """ + Returns a shape of the input tensor. + + Args: + t: Tensor - input tensor + + Returns: tuple - shape of the imput tensor. + """ + return torch.ops.libtorch_agnostic_2_10.my_shape.default(t) + + def get_any_data_ptr(t, mutable) -> int: """ Return data pointer value of the tensor. diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 48ede590cecbf..ef92fc316daa7 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -711,6 +711,15 @@ def test_my_view(self, device): expected_flat = t.view([-1]) self.assertEqual(result_flat, expected_flat) + @skipIfTorchVersionLessThan(2, 10) + def test_my_shape(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + expected = (3, 5) + t = torch.rand(*expected, device=device) + shape = libtorch_agnostic.ops.my_shape(t) + self.assertEqual(shape, expected) + def test_mv_tensor_accessor(self, device): import libtorch_agnostic_2_9 as libtorch_agnostic diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 38fc4fe4cc8bf..0e09eeb7f7b14 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -281,10 +281,10 @@ struct FromImpl> { TORCH_ERROR_CODE_CHECK( torch_new_list_reserve_size(val.size(), &new_list_handle)); for (const auto& elem : val) { - TORCH_ERROR_CODE_CHECK( - torch_list_push_back(new_list_handle, from(elem))); + TORCH_ERROR_CODE_CHECK(torch_list_push_back( + new_list_handle, torch::stable::detail::from(elem))); } - return from(new_list_handle); + return torch::stable::detail::from(new_list_handle); } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown From b8a3165d28b672ac6d84128e66265bf471b92a55 Mon Sep 17 00:00:00 2001 From: "Ma, Jing1" Date: Wed, 19 Nov 2025 02:41:36 +0000 Subject: [PATCH 024/230] [2/3][XPU][feature] The implementation of MemPool for XPU (#166833) The implementation plan of MemPool for XPU, which is the dependance of [XPUGraph](https://github.com/pytorch/pytorch/pull/166285), following the [RFC](https://github.com/pytorch/pytorch/issues/162143). - [ ] #166831 - [ ] ->#166833 - [ ] #166843 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166833 Approved by: https://github.com/EikanWang, https://github.com/gujinghui --- c10/xpu/XPUCachingAllocator.cpp | 176 ++++++++++++++++++++++++++++++++ c10/xpu/XPUCachingAllocator.h | 55 ++++++++++ 2 files changed, 231 insertions(+) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index b8838d9b7ee35..d7eeb10caba1b 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1227,6 +1227,63 @@ class DeviceCachingAllocator { allowed_memory_maximum = static_cast(fraction * device_total); set_fraction = true; } + + void createOrIncrefPool( + MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr) { + std::scoped_lock lock(mutex); + create_or_incref_pool(mempool_id, allocator); + } + + int getPoolUseCount(MempoolId_t mempool_id) { + std::scoped_lock lock(mutex); + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + return 0; + } + return it->second->use_count; + } + + // Called by XPUGraph::capture_begin + void beginAllocateToPool( + MempoolId_t mempool_id, + std::function filter) { + std::lock_guard lock(mutex); + create_or_incref_pool(mempool_id); + auto not_found = std::all_of( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first != mempool_id; }); + TORCH_CHECK( + not_found, "beginAllocateToPool: already recording to mempool_id"); + captures_underway.emplace_back(mempool_id, std::move(filter)); + } + + // Called by XPUGraph::capture_end + void endAllocateToPool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + + auto it = std::find_if( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first == mempool_id; }); + TORCH_INTERNAL_ASSERT( + it != captures_underway.end(), + "endAllocatePool: not currently recording to mempool_id"); + captures_underway.erase(it); + } + + // Called by XPUGraph::reset and MemPool::~MemPool() + void releasePool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + auto pp = get_private_pool(mempool_id); + auto uc = --(pp->use_count); + TORCH_INTERNAL_ASSERT(uc >= 0); + if (uc == 0) { + bool inserted = graph_pools_freeable.insert({mempool_id, pp}).second; + TORCH_INTERNAL_ASSERT(inserted); + } + } }; static void local_raw_delete(void* ptr); @@ -1416,6 +1473,39 @@ class XPUAllocator : public DeviceAllocator { ". Please set within (0, 1]."); device_allocators[device]->setMemoryFraction(fraction); } + + void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator) { + assertValidDevice(device); + device_allocators[device]->createOrIncrefPool( + std::move(mempool_id), allocator); + } + + void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + assertValidDevice(device); + device_allocators[device]->beginAllocateToPool( + std::move(mempool_id), std::move(filter)); + } + + void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->endAllocateToPool(mempool_id); + } + + void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->releasePool(std::move(mempool_id)); + } + + int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + return device_allocators[device]->getPoolUseCount(std::move(mempool_id)); + } }; static XPUAllocator allocator; @@ -1472,6 +1562,92 @@ void setMemoryFraction(double fraction, DeviceIndex device) { return allocator.setMemoryFraction(fraction, device); } +void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator_ptr) { + return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr); +} + +void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + return allocator.beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.endAllocateToPool(device, mempool_id); +} + +void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.releasePool(device, mempool_id); +} + +int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.getPoolUseCount(device, mempool_id); +} + REGISTER_ALLOCATOR(kXPU, &allocator) } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +// uid_ is incremented when a user creates a MemPool, +// +// uuid_ is incremented when XPUGraph creates a MemPool +// as a result of a user not providing a pool. + +std::atomic MemPool::uid_{1}; +std::atomic MemPool::uuid_{1}; + +MemPool::MemPool( + XPUCachingAllocator::XPUAllocator* allocator, + bool is_user_created, + bool use_on_oom) + : allocator_(allocator), is_user_created_(is_user_created) { + if (is_user_created_) { + id_ = {0, uid_++}; + } else { + id_ = {uuid_++, 0}; + } + device_ = c10::xpu::current_device(); + XPUCachingAllocator::createOrIncrefPool(device_, id_, allocator); + if (use_on_oom) { + // XPU doesn't support use_on_oom yet + TORCH_WARN( + "XPUCachingAllocator::MemPool: use_on_oom is not supported on XPU"); + } +} + +MemPool::~MemPool() { + TORCH_INTERNAL_ASSERT(use_count() == 1); + XPUCachingAllocator::releasePool(device_, id_); + c10::xpu::XPUCachingAllocator::emptyCache(id_); // release cached blocks +} + +MempoolId_t MemPool::id() { + return id_; +} + +XPUCachingAllocator::XPUAllocator* MemPool::allocator() { + return allocator_; +} + +int MemPool::use_count() { + return XPUCachingAllocator::getPoolUseCount(device_, id_); +} + +c10::DeviceIndex MemPool::device() { + return device_; +} + +MempoolId_t MemPool::graph_pool_handle(bool is_user_created) { + if (is_user_created) { + return {0, uid_++}; + } + return {uuid_++, 0}; +} + +} // namespace c10::xpu diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index bbb20a5b2ecdf..c55de309032e0 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -33,4 +33,59 @@ C10_XPU_API double getMemoryFraction(DeviceIndex device); C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); +class XPUAllocator; + +C10_XPU_API void createOrIncrefPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr); + +C10_XPU_API void beginAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + std::function filter); + +C10_XPU_API void endAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API void releasePool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API int getPoolUseCount( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +using c10::CaptureId_t; +using c10::MempoolId_t; +struct C10_XPU_API MemPool { + MemPool( + XPUCachingAllocator::XPUAllocator* allocator = nullptr, + bool is_user_created = true, + bool use_on_oom = false); + MemPool(const MemPool&) = delete; + MemPool(MemPool&&) = default; + MemPool& operator=(const MemPool&) = delete; + MemPool& operator=(MemPool&&) = default; + ~MemPool(); + + MempoolId_t id(); + XPUCachingAllocator::XPUAllocator* allocator(); + int use_count(); + c10::DeviceIndex device(); + static MempoolId_t graph_pool_handle(bool is_user_created = true); + + private: + static std::atomic uid_; + static std::atomic uuid_; + XPUCachingAllocator::XPUAllocator* allocator_; + bool is_user_created_; + MempoolId_t id_; + c10::DeviceIndex device_; +}; +} // namespace c10::xpu From cdca10b2753909d1eaeb096c4e91c47add3935b9 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 17 Nov 2025 20:00:48 -0800 Subject: [PATCH 025/230] [AOTI] Fix a GPU memory leak caused by reference circle (#168063) Summary: Fix https://github.com/pytorch/pytorch/issues/167630. There was a reference circle between GraphLowering and CppWrapperCpu due to caching, which makes GraphLowering unnecessarily hold some contant tensors causing GPU memory leaks. This PR fixes that by changing the cache to use the object id of GraphLowering as a part of the key. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168063 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 44 ++++++++++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 24 +++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5f0447c32264e..69f5eb92b58ce 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7437,6 +7437,50 @@ def forward(self, x): "RAIIAtenTensorHandle buf0(buf0_handle_restrided);" ).run(code) + def test_codegen_int_array_var_fix_memory_leak(self): + """ + Fix https://github.com/pytorch/pytorch/issues/167630 + """ + if self.device != "cuda": + raise unittest.SkipTest("test is only for cuda") + + def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): + layers = [] + d = in_dim + for _ in range(depth): + layers += [nn.Linear(d, hidden), nn.ReLU()] + d = hidden + layers += [nn.Linear(d, out_dim)] + return nn.Sequential(*layers) + + batch = 32 + in_dim = 2048 + hidden = 512 + out_dim = 10 + depth = 6 + + import gc + + allocated_memory = [] + for _ in range(3): + torch.cuda.reset_peak_memory_stats() + + model = make_mlp(in_dim, hidden, out_dim, depth).to(self.device) + example_inputs = (torch.randn(batch, in_dim, device=self.device),) + ep = torch.export.export( + model, + example_inputs, + ) + torch._inductor.aoti_compile_and_package(ep) + + del model, example_inputs, ep + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + allocated_memory.append(torch.cuda.memory_allocated()) + + self.assertTrue(allocated_memory[1] == allocated_memory[2]) + @unittest.skipIf(IS_MACOS, "might have no readelf on Mac") def test_libtorch_free_so(self): class Model(torch.nn.Module): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 61a97fd740cbc..65d356dce0979 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -96,6 +96,7 @@ def __init__(self): self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] self._include_extra_header ) + self.codegen_int_array_var_cache = {} @staticmethod def create( @@ -1636,14 +1637,33 @@ def codegen_memory_format(self, memory_format): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" - @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, writeline: Callable[..., None], known_statically=False, graph=None, # for per-graph caching - ): + ) -> str: + # Use id(graph) for caching to avoid circular references + cache_key = ( + int_array, + id(writeline), + known_statically, + id(graph) if graph else None, + ) + if cache_key not in self.codegen_int_array_var_cache: + self.codegen_int_array_var_cache[cache_key] = ( + self._codegen_int_array_var_impl(int_array, writeline, known_statically) + ) + + return self.codegen_int_array_var_cache[cache_key] + + def _codegen_int_array_var_impl( + self, + int_array: str, + writeline: Callable[..., None], + known_statically: bool, + ) -> str: # Used for size/stride declaration # # Because the memory planning is done in two passes (see the implementation From cea86781f291c10e72d0aaef5893ec6c823cb9fd Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 13 Nov 2025 13:52:49 -0800 Subject: [PATCH 026/230] [CD] Add `cuda-bindings` dependency to CUDA wheels (#167769) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167769 Approved by: https://github.com/ngimel, https://github.com/leofang --- .../scripts/generate_binary_build_matrix.py | 4 ++ ...linux-aarch64-binary-manywheel-nightly.yml | 56 +++++++++---------- ...nerated-linux-binary-manywheel-nightly.yml | 56 +++++++++---------- 3 files changed, 60 insertions(+), 56 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index f7df4335cb5b6..d69db191b9464 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -50,6 +50,7 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { "12.6": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | " @@ -67,6 +68,7 @@ "nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux'" ), "12.8": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | " @@ -84,6 +86,7 @@ "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" ), "12.9": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " @@ -101,6 +104,7 @@ "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" ), "13.0": ( + "cuda-bindings==13.0.3; platform_system == 'Linux' | " "nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | " "nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | " "nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | " diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index b8a6403faffbd..6a22e14af09b7 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -132,7 +132,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -178,7 +178,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -224,7 +224,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -381,7 +381,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -427,7 +427,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -473,7 +473,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -519,7 +519,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -630,7 +630,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -676,7 +676,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -722,7 +722,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -768,7 +768,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -879,7 +879,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -925,7 +925,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -971,7 +971,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1017,7 +1017,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1128,7 +1128,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1174,7 +1174,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1220,7 +1220,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1266,7 +1266,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1377,7 +1377,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1423,7 +1423,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1469,7 +1469,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1515,7 +1515,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1626,7 +1626,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1672,7 +1672,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1718,7 +1718,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1764,7 +1764,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 21c1d5caa3829..a5f4e85ca58c1 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -127,7 +127,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_6-test: # Testing @@ -193,7 +193,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_8-test: # Testing @@ -259,7 +259,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_9-test: # Testing @@ -325,7 +325,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda13_0-test: # Testing @@ -793,7 +793,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_6-test: # Testing @@ -859,7 +859,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_8-test: # Testing @@ -925,7 +925,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_9-test: # Testing @@ -991,7 +991,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda13_0-test: # Testing @@ -1459,7 +1459,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_6-test: # Testing @@ -1525,7 +1525,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_8-test: # Testing @@ -1591,7 +1591,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_9-test: # Testing @@ -1657,7 +1657,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda13_0-test: # Testing @@ -2125,7 +2125,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_6-test: # Testing @@ -2191,7 +2191,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_8-test: # Testing @@ -2257,7 +2257,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_9-test: # Testing @@ -2323,7 +2323,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda13_0-test: # Testing @@ -2791,7 +2791,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_6-test: # Testing @@ -2857,7 +2857,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_8-test: # Testing @@ -2923,7 +2923,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_9-test: # Testing @@ -2989,7 +2989,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda13_0-test: # Testing @@ -3457,7 +3457,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_6-test: # Testing @@ -3523,7 +3523,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_8-test: # Testing @@ -3589,7 +3589,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_9-test: # Testing @@ -3655,7 +3655,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda13_0-test: # Testing @@ -4123,7 +4123,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_6-test: # Testing @@ -4189,7 +4189,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_8-test: # Testing @@ -4255,7 +4255,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_9-test: # Testing @@ -4321,7 +4321,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda13_0-test: # Testing From 13ec55d15b64e00312183f9b3dac628a9c8cf1be Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Wed, 19 Nov 2025 04:33:07 +0000 Subject: [PATCH 027/230] Update AGENTS.md (#168111) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/168111 Approved by: https://github.com/ezyang --- AGENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AGENTS.md b/AGENTS.md index 3d5436a02a85d..718217d3e663d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -10,6 +10,7 @@ - Do NOT run pre-commit, it is not setup - To run lint, run 'lintrunner -a' (which will autoapply changes) - Do NOT attempt to install dependencies, you do not have Internet access +- Do NOT create summary files unless explicitly asked - When you are ready to make a PR, do exactly these steps: - git stash -u - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch From 65f08eeec1c5de7511688319c059c770a6edf119 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 18 Nov 2025 20:39:09 -0800 Subject: [PATCH 028/230] [MPS][1/N] Fix unsupported dtypes error checking for some MPS ops (#166273) Partially vibe-coded with ClaudeCode, and changes following ops (summary also created by Claude): - **Activation operations**: Added checks rejecting Long, Complex, and Bool types for operations like log_softmax, log_sigmoid, mish, softplus, and silu, as MPS doesn't support exponent operations on these types - **Linear algebra operations**: Restricted linalg_lu_factor, linalg_solve, and linalg_solve_triangular to Float type only (previously only checked for complex types) - **Pooling operations**: Added checks to reject Complex types for avg_pool2d and max_pool2d operations - **Loss functions**: Added type checks for nll_loss (Complex), huber_loss (Long, Complex), and grid_sampler_2d (Complex) - **Reduction operations**: - Fixed NANSUM to handle integral types correctly (can't contain NaN, so just performs regular sum) - Added Long type check for std/var operations - **Other operations**: - softmax: Now explicitly requires floating point types - bincount: Rejects Bool type to prevent crashes All checks use `TORCH_CHECK_NOT_IMPLEMENTED` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166273 Approved by: https://github.com/manuelcandales --- .../ATen/native/mps/operations/Activation.mm | 16 +++++++++++ .../ATen/native/mps/operations/GridSampler.mm | 5 ++++ .../native/mps/operations/LinearAlgebra.mm | 8 +++--- .../src/ATen/native/mps/operations/LossOps.mm | 5 ++++ .../src/ATen/native/mps/operations/Pooling.mm | 5 ++++ .../ATen/native/mps/operations/ReduceOps.mm | 28 +++++++++++-------- .../src/ATen/native/mps/operations/SoftMax.mm | 1 + .../ATen/native/mps/operations/SummaryOps.mm | 4 +++ 8 files changed, 57 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index 64ef972b55530..802c648c888d5 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -117,6 +117,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_mps_out) (const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), + "log_softmax for complex is not supported for MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -160,6 +164,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_backward_mps_out) (const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()), + "log_softmax for complex is not supported for MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryGradCachedGraph; @@ -200,6 +208,7 @@ Tensor relu_mps(const Tensor& self) { } std::tuple log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // NOTE: buffer is only used by CPU dispatch, we just ignore it here using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -706,6 +715,7 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c if (output.numel() == 0) return; + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // this can't pass anyway because a 0-dimensional tensor has "size" 1, which // can't be evenly halved, but give a nicer error message here. TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); @@ -819,6 +829,7 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long"); // Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * // \log(1 + \exp(\beta * x))` element-wise. // For numerical stability the implementation reverts to the linear function @@ -969,6 +980,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); if (result.numel() == 0) return; @@ -1017,6 +1030,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); if (grad_input.numel() == 0) @@ -1206,6 +1221,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // Empty output if (result.numel() == 0) diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 92f2b9c6fbf74..d75456c1ad3f0 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -80,6 +80,11 @@ static void grid_sampler_2d_mps_impl(Tensor& output, MPSGraphTensor* outputTensor_ = nil; }; + // Crashes with + // MPSGraphUtilities.mm:97:0: error: 'mps.sample_grid' op operand #0 must be tensor of mps native type values, but got + // 'tensor<2x3x5x20xcomplex>' + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "grid_sampler_2d is not supported for complex on MPS"); @autoreleasepool { std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index ca19d121bb718..00f9c96b78af8 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -240,7 +240,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, bool check_errors) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS doesn't support complex types."); TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False."); @@ -364,8 +364,7 @@ static void linalg_solve_out_mps_impl(const Tensor& A, const Tensor& info) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), - "linalg.lu_factor(): MPS doesn't support complex types."); + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS only supports floats."); Tensor A_t, B_t; // If 'left' is false, reinterpret the problem so that Ax = B becomes A^T ⋅ (x^T) = B^T // Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output @@ -1058,7 +1057,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const using namespace mps; checkInputsSolver(A, B, left, "linalg.solve_triangular"); - TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!"); + TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat, + "linalg.solve.triangular(); Only float is supported!"); Tensor A_t, B_t; std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr); at::native::resize_output(out, B_t.sizes()); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index f0bbcdabfa5cd..11ee09d6e23f2 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -416,6 +416,8 @@ static void nllnd_loss_forward_impl(Tensor& output, int64_t reduction, int64_t ignore_index, bool is2D) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()), + "nlld_loss for complex is not supported for MPS"); std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); reshapedTarget.push_back(1); @@ -824,6 +826,9 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output, Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) { std::string op_name = __func__; using namespace mps; + TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "huber_loss for complex is not supported for MPS"); TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") TORCH_CHECK(output.is_mps()); diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 2d466f7c79436..ecd5f12df17f8 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -597,6 +597,7 @@ static void avg_pool2d_template(const Tensor& input, bool count_include_pad, const std::optional divisor_override, const std::string& op_name) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex"); const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); const bool is_backward_pass = grad_output.defined(); const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; @@ -915,6 +916,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); if (use_graph) { auto indices_memory_format = indices.suggest_memory_format(); @@ -967,6 +970,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& indices, const Tensor& grad_input) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 3747f314adfa1..e634eefee2058 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -269,17 +269,22 @@ static void reduction_out_mps(const Tensor& input_t, name:nil]; castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil]; } else if (reduction_type == MPSReductionType::NANSUM) { - // Create a 0 tensor of the same shape as inputTensor - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; - // Find NaNs - MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; - // Replace NaNs with 0 - MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:zeros - falsePredicateTensor:castInputTensor - name:nil]; - // Sum - castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + // Integral types cannot contain NaN, so just do regular sum + if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) { + castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; + } else { + // Create a 0 tensor of the same shape as inputTensor + auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; + // Find NaNs + auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; + // Replace NaNs with 0 + auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask + truePredicateTensor:zeros + falsePredicateTensor:castInputTensor + name:nil]; + // Sum + castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + } } MPSGraphTensor* outputTensor = castOutputTensor; @@ -442,6 +447,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, const std::optional& correction, bool keepdim, StdVarType stdVarType) { + TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS"); using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); diff --git a/aten/src/ATen/native/mps/operations/SoftMax.mm b/aten/src/ATen/native/mps/operations/SoftMax.mm index 8f70e216dcae8..8eb24d0cb68bf 100644 --- a/aten/src/ATen/native/mps/operations/SoftMax.mm +++ b/aten/src/ATen/native/mps/operations/SoftMax.mm @@ -39,6 +39,7 @@ static void get_shapes(MPSShape* input_shape_readonly, TORCH_IMPL_FUNC(softmax_mps_out) (const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) { TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); + TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types"); static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); if (input_.numel() == 0) { diff --git a/aten/src/ATen/native/mps/operations/SummaryOps.mm b/aten/src/ATen/native/mps/operations/SummaryOps.mm index e709ec2d4f618..21cae885c3685 100644 --- a/aten/src/ATen/native/mps/operations/SummaryOps.mm +++ b/aten/src/ATen/native/mps/operations/SummaryOps.mm @@ -18,6 +18,10 @@ MPSStream* stream = getCurrentMPSStream(); bool has_weights = weights.defined(); + // Crashes with + // MPSGraphUtilities.mm:190:0: error: 'mps.scatter' op operand #2 must be tensor of int values, but got 'tensor<5xi1>' + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "bincount is not supported for Bool"); + @autoreleasepool { std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { From d48cae96a6815bbee742062b6eadc7ddb87a6ac8 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Wed, 19 Nov 2025 05:11:51 +0000 Subject: [PATCH 029/230] Shrink binary size (#168080) Summary: Shrink binary size to reduce relocation overflows. The most important change is to split `intrusive_ptr::reset_()` into two functions and mark the bigger one as `C10_NOINLINE`. Differential Revision: D87308588 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168080 Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/malfet, https://github.com/ezyang --- c10/core/StorageImpl.cpp | 6 +- c10/core/StorageImpl.h | 6 +- c10/core/TensorImpl.cpp | 6 +- c10/core/TensorImpl.h | 6 +- c10/util/intrusive_ptr.h | 118 +++++++++++++++++++++------------------ 5 files changed, 75 insertions(+), 67 deletions(-) diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 00fc03bbd0fcf..56bc75e01adb1 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -48,7 +48,7 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } -void StorageImpl::incref_pyobject() const { +void StorageImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. @@ -59,12 +59,12 @@ void StorageImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void StorageImpl::decref_pyobject() const { +void StorageImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool StorageImpl::try_incref_pyobject() const { +bool StorageImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index c7dbd5c1f005b..8df32f552c754 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -105,11 +105,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { data_ptr_.clear(); } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; size_t nbytes() const { // OK to do this instead of maybe_as_int as nbytes is guaranteed positive diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 94a7375cc32fb..c890d6d084eb3 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -988,7 +988,7 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } -void TensorImpl::incref_pyobject() const { +void TensorImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. @@ -999,12 +999,12 @@ void TensorImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void TensorImpl::decref_pyobject() const { +void TensorImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool TensorImpl::try_incref_pyobject() const { +bool TensorImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 71a0195dde773..42b6bb1e80d2e 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2178,11 +2178,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; private: // See NOTE [std::optional operator usage in CUDA] diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 0c8f55f5061ab..f3c4ab0dc7cbc 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -68,6 +68,10 @@ inline bool has_pyobject(uint64_t combined_refcount) { return (combined_refcount & kHasPyObject) != 0; } +inline bool is_uniquely_owned(uint64_t combined_refcount) { + return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef; +} + // The only requirement for refcount increment is that it happens-before // decrement, so no additional memory ordering is needed. inline uint64_t atomic_combined_refcount_increment( @@ -287,9 +291,9 @@ class C10_API intrusive_ptr_target { * These two methods are called when the refcount transitions between one * and two and the object has a PyObject wrapper. */ - virtual void incref_pyobject() const {} - virtual void decref_pyobject() const {} - virtual bool try_incref_pyobject() const { + virtual void incref_pyobject() const noexcept {} + virtual void decref_pyobject() const noexcept {} + virtual bool try_incref_pyobject() const noexcept { return false; } @@ -363,7 +367,7 @@ class intrusive_ptr final { template friend class pybind11::class_; - void retain_() { + void retain_() noexcept { if (target_ != NullType::singleton()) { uint64_t combined = detail::atomic_combined_refcount_increment( target_->combined_refcount_, detail::kReferenceCountOne); @@ -377,9 +381,7 @@ class intrusive_ptr final { // PyObject. In other words, we need to ensure that the PyObject stays // alive now that we have a C++ reference to this object in addition to // the PyObject itself. - if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { target_->incref_pyobject(); } } else { @@ -392,51 +394,60 @@ class intrusive_ptr final { void reset_() noexcept { if (target_ != NullType::singleton()) { - if (is_uniquely_owned()) { - // Both counts are 1, so there are no weak references and - // we are releasing the last strong reference. No other - // threads can observe the effects of this target_ deletion - // call (e.g. calling use_count()) without a data race. - target_->combined_refcount_.store(0, std::memory_order_relaxed); - delete target_; + reset_not_null_(target_); + } + } + + // C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here + // to avoid an extra pointer dereference in the call from reset_(). + C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept { + if (detail::is_uniquely_owned( + target->combined_refcount_.load(std::memory_order_acquire))) { + // Both counts are 1, so there are no weak references and + // we are releasing the last strong reference. No other + // threads can observe the effects of this target deletion + // call (e.g. calling use_count()) without a data race. + target->combined_refcount_.store(0, std::memory_order_relaxed); + delete target; + return; + } + + auto combined_refcount = detail::atomic_combined_refcount_decrement( + target->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + if (detail::weakcount(combined_refcount) == 1) { + delete target; return; } - - auto combined_refcount = detail::atomic_combined_refcount_decrement( - target_->combined_refcount_, detail::kReferenceCountOne); - uint32_t new_refcount = detail::refcount(combined_refcount); - bool has_pyobject = detail::has_pyobject(combined_refcount); - if (new_refcount == 0) { - bool should_delete = detail::weakcount(combined_refcount) == 1; - // See comment above about weakcount. As long as refcount>0, - // weakcount is one larger than the actual number of weak references. - // So we need to decrement it here. - if (!should_delete) { - // justification for const_cast: release_resources is basically a - // destructor and a destructor always mutates the object, even for - // const objects. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast*>(target_) - ->release_resources(); - should_delete = detail::atomic_weakcount_decrement( - target_->combined_refcount_) == 0; - } - if (should_delete) { - delete target_; - } - } else if constexpr (detail::TargetTraits::can_have_pyobject) { - // If the refcount transitioned from 2 to 1, we need to decref the - // PyObject. In other words, we don't want to keep the PyObject alive if - // there are no C++ references to this object other than the PyObject - // itself. - if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) { - target_->decref_pyobject(); - } - } else { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !has_pyobject, - "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + // See comment above about weakcount. As long as refcount>0, + // weakcount is one larger than the actual number of weak references. + // So we need to decrement it here. + release_resources_and_decrement_weakrefs_(target); + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (has_pyobject && new_refcount == 1) { + target->decref_pyobject(); } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } + } + + C10_NOINLINE static void release_resources_and_decrement_weakrefs_( + TTarget* target) noexcept { + // justification for const_cast: release_resources is basically a + // destructor and a destructor always mutates the object, even for + // const objects. + const_cast*>(target)->release_resources(); + if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) { + delete target; } } @@ -607,9 +618,8 @@ class intrusive_ptr final { */ bool is_uniquely_owned() const noexcept { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); - uint64_t combined = - target_->combined_refcount_.load(std::memory_order_acquire); - return (combined & ~detail::kHasPyObject) == detail::kUniqueRef; + return detail::is_uniquely_owned( + target_->combined_refcount_.load(std::memory_order_acquire)); } /** @@ -1174,9 +1184,7 @@ inline void incref(intrusive_ptr_target* self) { self->combined_refcount_, detail::kReferenceCountOne); #ifndef C10_MOBILE - if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { self->incref_pyobject(); } #else From 6fc430644b1357fab03a03619576fef7197ac60e Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 6 Nov 2025 14:42:35 +0000 Subject: [PATCH 030/230] Improve build logic in activities for kineto (#167204) # Motivation Thanks to @KarhouTam for finding the issue mentioned in #167172 This PR aims to improve the build logic in activities for kineto. # Additional Context Fix #167172 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167204 Approved by: https://github.com/EikanWang, https://github.com/ezyang --- torch/csrc/autograd/init.cpp | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index a13cc70270ccb..7470344cc05f7 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,31 +390,27 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ torch::profiler::impl::ActivityType::CPU}; -#if defined(USE_KINETO) && \ - (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) - if (at::hasMTIA()) { - activities.insert(torch::profiler::impl::ActivityType::MTIA); - } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#if defined(USE_KINETO) +#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::getNumGPUs() > 0) { activities.insert(torch::profiler::impl::ActivityType::CUDA); } -#elif defined(USE_KINETO) +#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) +#if (!defined(LIBKINETO_NOXPUPTI)) if (at::hasXPU()) { activities.insert(torch::profiler::impl::ActivityType::XPU); } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#endif // (!defined(LIBKINETO_NOXPUPTI)) if (at::hasMTIA()) { activities.insert(torch::profiler::impl::ActivityType::MTIA); } + if (at::hasHPU()) { + activities.insert(torch::profiler::impl::ActivityType::HPU); + } if (c10::get_privateuse1_backend() != "privateuseone") { activities.insert(torch::profiler::impl::ActivityType::PrivateUse1); } -#endif +#endif // defined(USE_KINETO) return activities; }); From 28c7602c902695fd8abf710e9acb95e2bf18367d Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Wed, 19 Nov 2025 06:11:30 +0000 Subject: [PATCH 031/230] [vision hash update] update the pinned vision hash (#168130) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168130 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 64ee992f566b7..c3b209c216014 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc +617079d944b0e72632311c30ae2bbdf1168b901e From f49833de54450b03b808a5b9ad774ce14ff2c8a2 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 17 Nov 2025 16:55:31 -0800 Subject: [PATCH 032/230] [hoo] Invoke subgraph + effect (#167231) This PR adds support for effectful ops within invoke_subgraphs. * Most of the logic is in `invoke_subgraph.py_functionalize_impl`. * In the functionalization metadata collection phase, we note the tokens before going further down the dispatcher, and then note the tokens after coming back from the dispatcher. If there are nodes in the invoke_subgraph subgraph that contain effects, the number of effects should change, or the tokens used for an effect should. * We will store this effect difference in the `InvokeSubgraphCache` where the key is the identifier and value is the effect. For now we only support one effect within a subgraph. * During the tracing part of AOTAutograd, we will then wrap the subgraph to take in and output a token. Before: ``` def forward(self, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x) return invoke_subgraph def repeated_subgraph(self, x): record_memory = torch.ops.mylib.record_memory.default("forward", "N") add = torch.ops.aten.add(x, x) return add ``` After: ``` def forward(self, token, x): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x) getitem = invoke_subgraph[0] # output token getitem_1 = invoke_subgraph[1] return (getitem, getitem_1) def repeated_subgraph(self, token, x): with_effects = torch.ops.higher_order.with_effects(token, torch.ops.mylib.record_memory.default, 'forward', 'N') getitem = with_effects[0] # output token add = torch.ops.aten.add(x, x) return (getitem, add) ``` * Then there is a bunch of logic within `_remove_effect_tokens` to handle removing the effects from the invoke_subgraph subgraph Differential Revision: [D87392741](https://our.internmc.facebook.com/intern/diff/D87392741) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167231 Approved by: https://github.com/anijain2305 --- test/export/test_converter.py | 2 +- test/export/test_passes.py | 15 +- test/export/test_torchbind.py | 12 +- test/higher_order_ops/test_with_effects.py | 98 ++++++++ torch/_guards.py | 18 ++ torch/_higher_order_ops/invoke_subgraph.py | 50 ++++ torch/_library/effects.py | 15 ++ torch/export/_remove_effect_tokens_pass.py | 267 ++++++++++++--------- torch/export/_unlift.py | 24 +- torch/fx/node.py | 4 +- 10 files changed, 370 insertions(+), 135 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index e739e5c346677..5b608503a1168 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841 ) # qnnpack not supported on s390x @xfailIfS390X - def test_ts2ep_convert_quantized_model(self): + def test_ts2ep_convert_quantized_model1(self): class Standalone(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 9cf442c27a2bb..866eeaaee3986 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -640,16 +640,13 @@ def forward(self, x): self.assertExpectedInline( without_token_ep.graph_module.code.strip(), """\ -def forward(self, token, obj_attr, x): - with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None - getitem = with_effects[0] - getitem_1 = with_effects[1] - getitem_2 = with_effects[2]; with_effects = None +def forward(self, obj_attr, x): + takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None + getitem_1 = takes_foo_tuple_return_default[0] + getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None - with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None - getitem_3 = with_effects_1[0] - getitem_4 = with_effects_1[1]; with_effects_1 = None - return (getitem_3, getitem_4)""", # noqa: B950 + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None + return (takes_foo_default,)""", # noqa: B950 ) def test_fakify_script_objects(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 246122433e06c..adf0986811648 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -461,9 +461,9 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr _guards_fn = self._guards_fn(x); _guards_fn = None - takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) - takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None - add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) + takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None + add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 ) self.assertExpectedInline( @@ -1087,10 +1087,12 @@ def forward(self, token, tq, x): str(ep.graph_module.graph).strip(), """\ graph(): + %token : [num_users=1] = placeholder[target=token] %tq : [num_users=2] = placeholder[target=tq] %x : [num_users=1] = placeholder[target=x] - %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {}) - return (tq,)""", # noqa: B950 + %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {}) + %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) + return (getitem, tq)""", # noqa: B950 ) def test_deepcopy(self): diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 2c4cf02bc1c8a..38e38c9e13f01 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -870,6 +870,104 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token): finally: handle.destroy() + @unittest.skipIf(not TEST_CUDA, "triton") + def test_export_invoke_subgraph(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + recorded_list = [] + + @torch.library.custom_op("mylib::record_memory", mutates_args=()) + def record_memory(prefix: str, module_name: str) -> None: + torch.cuda.synchronize() + mem_alloc = torch.cuda.memory_allocated() / 1024**2 + mem_reserved = torch.cuda.memory_reserved() / 1024**2 + memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB" + recorded_list.append(memory_str) + + @record_memory.register_fake + def record_memory_fake(prefix, module_name): + return + + record_memory.register_effect(_EffectType.ORDERED) + + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(1024, 1024) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(1024, 1024) + + @torch.compiler.nested_compile_region + def forward(self, x): + torch.ops.mylib.record_memory("forward", "N") + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod_list = torch.nn.ModuleList(N() for _ in range(3)) + + def forward(self, x): + for m in self.mod_list: + x = m(x) + torch.ops.mylib.record_memory("forward", "N") + return (x,) + + model = M().to("cuda") + torch.cuda.reset_peak_memory_stats() + + x = torch.randn(32, 1024, requires_grad=True, device="cuda") + + ep = torch.export.export(model, (x,)) + ep = ep.run_decompositions() + self.assertEqual(len(list(ep.graph_module.named_modules())), 2) + + self.assertExpectedInline( + ep.graph_module.code.strip(), + """\ +def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x): + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None + getitem = invoke_subgraph[0] + getitem_1 = invoke_subgraph[1]; invoke_subgraph = None + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None + getitem_2 = invoke_subgraph_1[0] + getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None + repeated_subgraph0_2 = self.repeated_subgraph0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None + getitem_4 = invoke_subgraph_2[0] + getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None + with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None + getitem_6 = with_effects[0]; with_effects = None + return (getitem_6, getitem_5)""", + ) + + self.assertExpectedInline( + ep.graph_module.repeated_subgraph0.code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): + with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None + getitem = with_effects[0]; with_effects = None + permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None + addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None + relu = torch.ops.aten.relu.default(addmm); addmm = None + permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None + addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None + return (getitem, addmm_1)""", + ) + + recorded_list.clear() + # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph + # eagerly twice. Once for get_output_metadata and then once for + # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. + with torch.no_grad(): + out2 = ep.module()(x) + self.assertEqual(len(recorded_list), 4) + self.assertTrue(torch.allclose(model(x)[0], out2[0])) + if __name__ == "__main__": run_tests() diff --git a/torch/_guards.py b/torch/_guards.py index 32b796d71eea7..1bd32fc7f08ec 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -713,6 +713,9 @@ def __init__(self) -> None: self.lazy_bwd_cache: dict[ str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) + self.effects_cache: dict[ + str, set + ] = {} # Maps identifier -> set of effect types def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: self.dynamo_installed_submodules[fn_id].append(identifier) @@ -751,6 +754,21 @@ def get_lazy_bwd_entry( return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None)) + def add_effects(self, identifier: str, effects: set) -> None: + """Store the effect types for a given invoke_subgraph identifier.""" + if prev_effects := self.effects_cache.get(identifier, None): + assert effects == prev_effects, ( + "Different number of effects were found for invoke_subgraph " + f"call with identifier {identifier}. \n" + f"Previously we had the following effects: {prev_effects}.\n" + f"But now we have: {effects}." + ) + self.effects_cache[identifier] = effects + + def get_effects(self, identifier: str) -> Optional[set]: + """Retrieve the effect types for a given invoke_subgraph identifier.""" + return self.effects_cache.get(identifier, None) + class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index e22b741631d3f..7d066e132e011 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -80,6 +80,7 @@ def __call__( assert all( isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator)) for o in operands + if o is not None ), ( f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}" ) @@ -562,7 +563,34 @@ def _(ctx, subgraph, identifier, *operands): do_auto_functionalize_v2, ) + # (in the functionalization metadata phase) Capture tokens before + tokens_before = dict(ctx.mode._tokens) + + # Check if this subgraph has effects stored in the cache + invoke_subgraph_cache = get_invoke_subgraph_cache() + effects = None + if invoke_subgraph_cache: + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects: + assert len(effects) == 1, "Multiple effects within a subgraph NYI" + tokens = ctx.mode._tokens + effects = next(iter(effects)) + token_input = tokens[effects] + + operands = (token_input, *operands) + + def wrap_subgraph(subgraph): + def wrapped_subgraph(token, *args): + res = subgraph(*args) + return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res + + return wrapped_subgraph + + subgraph = wrap_subgraph(subgraph) + unwrapped_operands = ctx.unwrap_tensors(operands) + hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) if can_auto_functionalize(hop_instance): # NOTE: [auto_functionalize x invoke_subgraph caching] @@ -587,6 +615,28 @@ def _(ctx, subgraph, identifier, *operands): # of invoke_subgraph ops if input aliasing/mutation is detected. functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph) out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands) + + if effects: + (new_token, *out) = out + ctx.mode._tokens[effects] = new_token + + # (in the functionalization metadata phase) Capture tokens after and see if + # there are any differences (there are new effects or the token value for an + # effect type has changed) + tokens_after = dict(ctx.mode._tokens) + discovered_effects = set() + for effect_type, token in tokens_after.items(): + if effect_type not in tokens_before or tokens_before[effect_type] is not token: + discovered_effects.add(effect_type) + + if discovered_effects: + assert ctx.mode._allow_token_discovery, ( + f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}." + ) + # Store discovered effects in the cache by identifier + if invoke_subgraph_cache: + invoke_subgraph_cache.add_effects(identifier, discovered_effects) + return ctx.wrap_tensors(out) diff --git a/torch/_library/effects.py b/torch/_library/effects.py index 41fbaa4c1c7b4..3f765f380eab1 100644 --- a/torch/_library/effects.py +++ b/torch/_library/effects.py @@ -35,6 +35,18 @@ def _set_default_effect(self) -> None: if namespace == "higher_order": return + # These classes do not have side effects as they just store quantization + # params, so we dont need to mark them as ordered + skip_classes = ( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", + "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", + "__torch__.torch.classes.quantized.LinearPackedParamsBase", + "__torch__.torch.classes.xnnpack.Conv2dOpContext", + "__torch__.torch.classes.xnnpack.LinearOpContext", + "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", + ) + opname = f"{namespace}::{opname}" if torch._C._get_operation_overload(opname, overload) is not None: # Since we call this when destroying the library, sometimes the @@ -42,6 +54,9 @@ def _set_default_effect(self) -> None: schema = torch._C._get_schema(opname, overload) for arg in schema.arguments: if isinstance(arg.type, torch.ClassType): + type_str = arg.type.str() # pyrefly: ignore[missing-attribute] + if type_str in skip_classes: + continue self._effect = EffectType.ORDERED return diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 21930d81fe092..3ebcf6180d660 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -15,113 +15,105 @@ ) -def _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names +def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants): + """Extract the custom object from a node's arguments.""" + custom_obj_node = node + custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + + if custom_obj_meta.fake_val: + return custom_obj_meta.fake_val + elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + + +def _replace_with_effects_node( + node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module ): - inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs - - output_node = None - with_effect_nodes: list[torch.fx.Node] = [] - - # Output node need to check its args against output_token_names (collected from output_spec) - # Therefore, we only need to find the top-levele output node - output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) - for module in ep.graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - for node in module.graph.nodes: - if not (node.op == "call_function" and node.target is with_effects): - continue - - with_effect_nodes.append(node) - - # Remove tokens from outputs - assert output_node is not None - output_args = output_node.args[0] - assert len(output_args) >= num_tokens - out_token_nodes = output_args[:num_tokens] - output_node.args = (tuple(output_args[num_tokens:]),) - for out_token in out_token_nodes: - assert out_token.name in output_token_names - out_token.users.clear() - ep.graph.erase_node(out_token) - - # Replace with_effects(token, func, args) with just func(args) - for node in reversed(with_effect_nodes): - func = node.args[1] - assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) - - if func is torch.ops.higher_order.call_torchbind: - custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] - assert isinstance(custom_obj_meta, CustomObjArgument) - if custom_obj_meta.fake_val: - custom_obj = custom_obj_meta.fake_val - elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] - custom_obj = ep.constants[ - inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] - ] - else: - raise RuntimeError(f"Unable to find custom obj for node {node}") - schema = _get_schema(func, (custom_obj,) + node.args[3:]) - else: - schema = _get_schema(func, node.args[2:]) - - with ep.graph.inserting_before(node): - new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) - for k, v in node.meta.items(): - new_node.meta[k] = v - if k == "unbacked_bindings": - # Remove the extra layer for effect token - old_bindings = new_node.meta[k] - new_bindings = { - k: path[1:] if path else path for k, path in old_bindings.items() - } - new_node.meta[k] = new_bindings - - node.replace_all_uses_with(new_node) - - # Update user getitem nodes - for user in list(new_node.users.keys()): - assert user.target is operator.getitem - # getitem(with_effects, 0) == token - if user.args[1] == 0: - ep.graph.erase_node(user) - - if len(schema.returns) == 1: - # If the function has 1 return then it will just directly return the - # result -- we don't need a getitem. So we can replace all the - # getitem(with_effects, 1) with just the note itself. - for user in list(new_node.users.keys()): - assert user.args[1] == 1 + """Replace a with_effects node with the underlying function call.""" + # Get the input nodes + token_node, func, *node_args = node.args + if token_node.op == "placeholder": + input_tokens.append(token_node) + + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + # Get the schema for the function + if func is torch.ops.higher_order.call_torchbind: + custom_obj = _get_custom_obj_for_node( + node_args[0], inputs_to_lifted_custom_objs, ep.constants + ) + schema = _get_schema(func, [custom_obj] + node_args[1:]) + else: + schema = _get_schema(func, node_args) + + # Create the replacement node + with module.graph.inserting_before(node): + new_node = module.graph.call_function(func, tuple(node_args), node.kwargs) + + # Update getitem nodes that extract outputs from with_effects + for user in list(node.users.keys()): + assert user.target is operator.getitem + # getitem(with_effects, 0) is the token node + if user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) + + # Fix up the getitem nodes based on return count + if len(schema.returns) == 1: + # Single return: replace getitem(with_effects, 1) with the node itself + for user in list(node.users.keys()): + if user.args[1] == 1: user.replace_all_uses_with(new_node) - - new_node.meta["val"] = node.meta["val"][1] - elif len(schema.returns) > 1: - # If the function has more than 1 return then since we got rid of - # the 1st return value (the token), we need to bump all the other - # getitem calls by 1 down - for user in list(new_node.users.keys()): - assert user.args[1] >= 1 - user.args = (user.args[0], user.args[1] - 1) - - new_node.meta["val"] = node.meta["val"][1:] - else: - assert len(schema.returns) == 0 - assert len(new_node.users) == 0 - new_node.meta["val"] = None - - ep.graph.erase_node(node) - - # Remove tokens from inputs - placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] - assert len(placeholders) >= num_tokens - inp_token_nodes = placeholders[:num_tokens] - for inp_token in inp_token_nodes: - assert inp_token.name in input_token_names - ep.graph.erase_node(inp_token) - - ep.graph.eliminate_dead_code() + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # Multiple returns: shift getitem indices down by 1 + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (new_node, user.args[1] - 1) + new_node.meta["val"] = node.meta["val"][1:] + else: + # No returns + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + # Copy metadata from old node to new node + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + +def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens): + """Replace an invoke_subgraph node to remove the token argument.""" + assert node.args[0].op == "get_attr" + submod = getattr(module, node.args[0].target) + if not submod.meta.get("has_with_effects", False): + return + + # Remove token from inputs + subgraph, identifier, token, *operands = node.args + node.args = (subgraph, identifier, *operands) + if token.op == "placeholder": + input_tokens.append(token) + + # Update getitem nodes to account for removed token output + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (node, user.args[1] - 1) + elif user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: @@ -132,6 +124,65 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: This function does an inplace modification on the given ExportedProgram. """ + print("before", ep) + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + # mark submodules with effects as having effects. This will be used in the following pass to remove effects from subgraphs + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + with_effect_nodes = [ + node for node in module.graph.nodes if node.target is with_effects + ] + if len(with_effect_nodes) > 0: + module.meta["has_with_effects"] = True + + # Process each module with the replace hook to ensure graph signature is updated + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + input_tokens = [] + output_tokens = [] + + # Process with_effects and invoke_subgraph nodes + for node in module.graph.nodes: + if node.target is with_effects: + _replace_with_effects_node( + node, + ep, + inputs_to_lifted_custom_objs, + output_tokens, + input_tokens, + module, + ) + elif node.target is torch.ops.higher_order.invoke_subgraph: + _replace_invoke_subgraph_node( + node, module, output_tokens, input_tokens + ) + + # Remove tokens from the output node + if len(output_tokens) > 0: + output_node = next(reversed(module.graph.find_nodes(op="output"))) + output_args = output_node.args[0] + assert len(output_args) >= len(output_tokens), ( + f"{output_args} output arguments found\n" + f"{output_tokens} output tokens found\n" + f"{module.graph}" + ) + output_node.args = (tuple(output_args[len(output_tokens) :]),) + + module.graph.eliminate_dead_code() + + # Remove tokens from the input placeholders + for node in module.graph.nodes: + if node.op == "placeholder" and node in input_tokens: + module.graph.erase_node(node) + + module.recompile() + num_tokens: int = 0 input_token_names: list[str] = [] new_input_specs: list[InputSpec] = [] @@ -159,9 +210,5 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: assert num_tokens == num_out_tokens - with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names - ) - + print("after", ep) return ep diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52d06a294fac1..6239c5899c233 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -748,11 +748,23 @@ def _unlift_exported_program_lifted_states( ) -> torch.fx.GraphModule: check_guards = check_guards and _ok_to_generate_guards_fn() + source_node_dict = { + node.name: node for node in ep.graph.nodes if node.op != "placeholder" + } + # placeholder node name might change after deepcopy + placeholder_source_node_dict = { + node.target: node for node in ep.graph.nodes if node.op == "placeholder" + } + + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + new_gm.meta.update(ep.graph_module.meta) + ep = copy.copy(ep) + ep._graph_module = new_gm + # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) - new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) forward_arg_names = ( sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None @@ -786,19 +798,13 @@ def _unlift_exported_program_lifted_states( for out_spec in ep.graph_signature.output_specs ] - source_node_dict = { - node.name: node for node in ep.graph.nodes if node.op != "placeholder" - } - # placeholder node name might change after deepcopy - placeholder_source_node_dict = { - node.target: node for node in ep.graph.nodes if node.op == "placeholder" - } for node in new_gm.graph.nodes: source_node = None if node.op == "placeholder": source_node = placeholder_source_node_dict.get(node.target) else: - source_node = source_node_dict.get(node.name) + if node.name in source_node_dict: + source_node = source_node_dict.get(node.name) node.meta["from_node"] = [ NodeSource( source_node, diff --git a/torch/fx/node.py b/torch/fx/node.py index 294e15c550235..cb37b6ece75dd 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -753,7 +753,9 @@ def is_impure(self, impure_random: bool = True) -> bool: # between eager and compiled execution, regardless of generator usage return True - return self.target in _side_effectful_functions + from torch._higher_order_ops.effects import has_effects + + return self.target in _side_effectful_functions or has_effects(self.target) # Check if an impure module. if self.op == "call_module": From 789240bae27c957fb59f737a8e171bd221bf4a40 Mon Sep 17 00:00:00 2001 From: angelayi Date: Mon, 17 Nov 2025 16:55:32 -0800 Subject: [PATCH 033/230] [invoke_subgraph] Don't run the graph twice when autograd enabled (#167245) In the [previous PR](https://github.com/pytorch/pytorch/pull/167231/files#diff-e2b74af5d8b538a7d07d18507d27010703742ddad5f819992b55f5abc6d9a502R964-R966) we found that the autograd eager impl of invoke_subgraph calls the subgraph twice. If the subgraph contains effects then effects will be run twice, which is bad. This PR fixes the issue by getting the output metadata from `subgraph`'s `node.meta` if it exists. Differential Revision: [D87392740](https://our.internmc.facebook.com/intern/diff/D87392740) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167245 Approved by: https://github.com/anijain2305 ghstack dependencies: #167231 --- test/higher_order_ops/test_with_effects.py | 6 +- torch/_higher_order_ops/invoke_subgraph.py | 64 ++++++++++++++++++++-- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 38e38c9e13f01..e995959afba47 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -960,11 +960,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): ) recorded_list.clear() - # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph - # eagerly twice. Once for get_output_metadata and then once for - # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. - with torch.no_grad(): - out2 = ep.module()(x) + out2 = ep.module()(x) self.assertEqual(len(recorded_list), 4) self.assertTrue(torch.allclose(model(x)[0], out2[0])) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 7d066e132e011..bb0d6cef3ee6f 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -305,6 +305,62 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): def get_output_metadata(subgraph, *operands): + """ + Extract metadata about the subgraph outputs WITHOUT executing the subgraph. + This avoids running side-effectful operations twice (once here, once in forward). + We analyze the graph structure statically to extract metadata. + """ + # Unwrap FunctionalizeCtxWrapper if present + if isinstance(subgraph, FunctionalizeCtxWrapper): + subgraph = subgraph.subgraph + + # If not a GraphModule, fall back to execution-based metadata extraction + if not isinstance(subgraph, torch.fx.GraphModule): + return _get_output_metadata_by_execution(subgraph, *operands) + + output_metadata = OutputMetadata() + + # Extract output arguments from the output node + # The output node has args=(output_values,) where output_values is a tuple/list + output_node = next(reversed(subgraph.graph.find_nodes(op="output"))) + output_metadata.num_fw_outs = len(output_node.args[0]) + + for idx, output_arg in enumerate(output_node.args[0]): + if not isinstance(output_arg, torch.fx.Node): + if isinstance(output_arg, int): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + continue + + # Check node metadata for type information + if output_arg.meta.get("val") is None: + # If we don't have complete metadata for all outputs, fall back to execution + # This is important for correctness (e.g., detecting SymInts) even though it + # runs side-effectful operations + return _get_output_metadata_by_execution(subgraph, *operands) + + val = output_arg.meta["val"] + if isinstance(val, torch.SymInt): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + elif isinstance(val, torch.Tensor): + # Check if tensor requires grad from metadata + if hasattr(val, "requires_grad") and not val.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + else: + # Non-tensor, non-symint (shouldn't happen but be safe) + output_metadata.indexes_with_no_grad.add(idx) + + return output_metadata + + +def _get_output_metadata_by_execution(subgraph, *operands): + """ + Fallback: Extract metadata by executing the subgraph. + This should only be used when static analysis fails. + WARNING: This will run side-effectful operations! + """ + with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): # args are functional tensors, generate some example tensors @@ -324,19 +380,15 @@ def get_output_metadata(subgraph, *operands): num_fw_outs = len(fw_outs) - # Collect the indexes of none in the output to check that the grad - # is None at the corresponding index in the backward. This check is - # performed in the autograd.Function - InvokeSubgraphAutogradOp. - # Also collect the indexes of no_grad in the output to filter out - # the grad_outs in the `backward` method. output_metadata = OutputMetadata() - output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) + return output_metadata From 9abc9aac38229fdd5e52d21ad0c0d65bd4c01ccb Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Wed, 19 Nov 2025 06:58:43 +0000 Subject: [PATCH 034/230] fix: use grad div factor when fsdp_degree=1 (#167178) `fully_shard`'s `gradient_divide_factor` isn't currently respected when the sharding degree = 1. This PR ensures the division factor applies also in this case. This is a bit of an edge case, but it arises in `torchtitan`, e.g. with expert parallelism and `ep_degree=world_size` we still wrap the routed experts in `fully_shard` because: 1) It lets us take advantage of its mixed-precision mechanisms. 2) [A specific gradient_divide_factor is needed for correctness](https://github.com/pytorch/torchtitan/blob/176498cd4edd4d80e95959a618279681f8295f4c/torchtitan/models/llama4/infra/parallelize.py?plain=1#L364-L369) This PR ensures correctness in the `reduce_scatter_group.size()==1` case. Reproducer and sample failures are in the [gist here](https://gist.github.ibm.com/goon/f67e7559284cc2d322faff1ac59fe382). The net effect is that the EP grads are too-large by a factor of the world size in the case described above. I checked that the proposed fix makes these tests pass. I guess I should add a test for this, too? Pull Request resolved: https://github.com/pytorch/pytorch/pull/167178 Approved by: https://github.com/weifengpy --- .../_composable/fsdp/test_fully_shard_comm.py | 40 ++++++++++++++----- .../fsdp/_fully_shard/_fsdp_collectives.py | 19 +++++---- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index ad3064608960d..076c4de69f44f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -428,7 +428,14 @@ def test_manual_reshard_with_reshard_after_forward_false(self): @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1571 def test_set_reduce_scatter_divide_factor(self): self.run_subtests( - {"divide_factor": [self.world_size * 2, self.world_size]}, + { + "divide_factor": [self.world_size * 2, self.world_size], + "mesh_shape": [ + (self.world_size,), + (self.world_size // 2, 2), + (self.world_size, 1), + ], + }, self._test_set_reduce_scatter_divide_factor, ) self.run_subtests( @@ -436,18 +443,31 @@ def test_set_reduce_scatter_divide_factor(self): self._test_set_reduce_scatter_divide_factor_mixed_prevision, ) - def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): + def _test_set_reduce_scatter_divide_factor( + self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, int] + ): torch.manual_seed(42) model_args = ModelArgs(dropout_p=0.0, weight_tying=False) model = Transformer(model_args) ref_model = copy.deepcopy(model).to(device_type) ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner") + mesh = init_device_mesh( + device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names + ) for module in model.modules(): if isinstance(module, TransformerBlock): - fully_shard(module, reshard_after_forward=False) - model = fully_shard(model, reshard_after_forward=False) + fully_shard(module, reshard_after_forward=False, mesh=mesh) + model = fully_shard(model, reshard_after_forward=False, mesh=mesh) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) + + # Get ref_model params which should have the specific division factor applied + block_params = set() + for ref_mod in ref_model.modules(): + if isinstance(ref_mod, TransformerBlock): + block_params.update(ref_mod.parameters()) + non_block_params = set(ref_model.parameters()) - block_params torch.manual_seed(42 + self.rank) inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type) @@ -456,16 +476,18 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): ref_loss = ref_model(inp).sum() ref_loss.backward() for param in ref_model.parameters(): - param.grad.mul_(1.0 / divide_factor) + factor = divide_factor if param in non_block_params else self.world_size + param.grad.mul_(1.0 / factor) dist.all_reduce(param.grad) loss = model(inp).sum() loss.backward() ref_optim.step() optim.step() - ref_optim.zero_grad() - optim.zero_grad() self.assertEqual(ref_loss, loss) + # Check parity before calling zero_grad so that grads are also checked check_sharded_parity(self, ref_model, model) + ref_optim.zero_grad() + optim.zero_grad() def _test_set_reduce_scatter_divide_factor_mixed_prevision( self, divide_factor: float @@ -484,7 +506,7 @@ def _test_set_reduce_scatter_divide_factor_mixed_prevision( fully_shard(mlp, mp_policy=mp_policy) model = fully_shard(model, mp_policy=mp_policy) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) torch.manual_seed(42 + self.rank) inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 794b755b1f64d..2bd7d24cd7d3f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -547,8 +547,12 @@ def foreach_reduce( op=reduce_scatter_op, ) else: - # For single GPU, just copy the input to output (no actual reduce-scatter needed) - reduce_output.copy_(reduce_scatter_input) + # For single GPU, just copy the input to output (no actual reduce-scatter needed), and + # account for a possible gradient_divide_factor. + if gradient_divide_factor is not None: + reduce_output.copy_(reduce_scatter_input / gradient_divide_factor) + else: + reduce_output.copy_(reduce_scatter_input) reduce_scatter_event = reduce_scatter_stream.record_event() post_reduce_stream = reduce_scatter_stream if all_reduce_group is not None: # HSDP or DDP/replicate @@ -721,20 +725,21 @@ def _get_gradient_divide_factors( if all_reduce_group is not None: data_parallel_size *= all_reduce_group.size() - if factor is None: - factor = float(data_parallel_size) - if not overflow_risk and not force_sum_reduction_for_comms: - if factor == data_parallel_size: + if factor is None: # Warning: NCCL ReduceOp.AVG may produce incorrect results with # world size 1. if data_parallel_size == 1: return None, None, ReduceOp.SUM, ReduceOp.SUM return None, None, ReduceOp.AVG, ReduceOp.AVG + if reduce_scatter_group is not None and factor == reduce_scatter_group.size(): + reduce_scatter_op = ReduceOp.AVG else: reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) - return None, None, reduce_scatter_op, ReduceOp.SUM + return None, None, reduce_scatter_op, ReduceOp.SUM + if factor is None: + factor = float(data_parallel_size) pre_factor: Optional[float] if overflow_risk: # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid From 1c0bf2a0bbec25fa0dffa678748da58e919229c2 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 19 Nov 2025 07:01:11 +0000 Subject: [PATCH 035/230] [CUDA][Complex] Bump tolerances for `TestFFTCUDA.test_reference_nd__refs_fft_irfftn_cuda_complex64` (#168016) Otherwise we see e.g., ``` Mismatched elements: 1 / 40320 (0.0%) Greatest absolute difference: 0.0001373291015625 at index (0, 4, 0, 2, 3, 5) (up to 0.0001 allowed) Greatest relative difference: 1.633889951335732e-05 at index (0, 4, 0, 2, 3, 5) (up to 1.3e-06 allowed) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168016 Approved by: https://github.com/nWEIdia, https://github.com/ezyang --- test/test_spectral_ops.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 6284be2aebe9e..522a82cf9a222 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -357,6 +357,9 @@ def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], allowed_dtypes=(torch.cfloat, torch.cdouble)) + @toleranceOverride({ + torch.cfloat : tol(2e-4, 1.3e-6), + }) def test_reference_nd(self, device, dtype, op): if op.ref is None: raise unittest.SkipTest("No reference implementation") From a5f36a8fda588eca2b15edefbf416cd92f081f30 Mon Sep 17 00:00:00 2001 From: zpcore Date: Wed, 19 Nov 2025 07:52:32 +0000 Subject: [PATCH 036/230] [DTensor] Fix deadlock after fast cache clear (#168069) This is the necessary fix for https://github.com/meta-pytorch/autoparallel/issues/256. ### Issue: when we call `_clear_fast_path_sharding_prop_cache()`, and then `get_thread_local_native_sharding_propagator_cache()`, the code will stuck due to deadlock. ### Cause: When you assign to a Python dict key that already exists: ```C++ thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = old_capsule // capsule #1 stored ... clear_DTensor_sharding_propagator_cache() // call to clean up the cache ... get_thread_local_native_sharding_propagator_cache() { std::lock_guard lock( native_sharding_propagator_cache_cleanup_mutex); // FIRST claims the lock! if (!native_sharding_propagator_cache_DO_NOT_USE.has_value()) { // enter this again because we have cleared the cache. ... // Destroys old_capsule FIRST then stores new_capsule. However, where we destroy the old_capsule, // it will trigger the destructor to claim `native_sharding_propagator_cache_cleanup_mutex` again! thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = new_capsule // SECOND claims the lock before FIRST releases } } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168069 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_op_strategy.py | 31 +++++++++++++++-- torch/csrc/autograd/python_variable.cpp | 38 +++++++++++---------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 139f5fb61fac8..72d95efcfa8c9 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -34,7 +34,11 @@ register_op_strategy, replicate_op_strategy, ) -from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, + CommDebugMode, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -479,7 +483,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): del propagator.op_to_schema_info[op_overload] else: propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema - propagator.propagate_op_sharding.cache.cache_clear() + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool: @@ -645,6 +650,28 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) +class TestStrategyOperation(DTensorTestBase): + @property + def world_size(self): + return 2 + + @with_comms + def test_cache_clean(self): + mesh = self.build_device_mesh() + test_op = torch.ops.mylib.numpy_sin + x = torch.randn(2, device=self.device_type) + y = torch.randn(2, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Shard(0)]) + y_dt = distribute_tensor(y, mesh, [Shard(0)]) + with op_strategy_context(test_op.default, replicate_op_strategy): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + with self.assertRaisesRegex( + NotImplementedError, + f"Operator {test_op.default} does not have a sharding strategy registered", + ): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + + DistTensorReplicateStrategyRegistrationTestWithLocalTensor = ( create_local_tensor_test_class( DistTensorReplicateStrategyRegistrationTest, diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 6d0bf5d0a8579..de7f3dc53c323 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1200,25 +1200,27 @@ get_thread_local_native_sharding_propagator_cache() { py::reinterpret_borrow(PyThreadState_GetDict()); // We need to clean up before Python detaches from the thread if // the thread is being destroyed. - thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = - py::capsule(new std::thread::id(this_thread_id), [](void* p) { - auto* ptid = reinterpret_cast(p); - { - std::lock_guard inner_lock( - native_sharding_propagator_cache_cleanup_mutex); - auto it = all_thread_caches.find(*ptid); - if (it != all_thread_caches.end()) { - // We need to both: - // 1) free python objects, and - it->second->reset(); - // 2) make sure we don't try to come back and mess with - // a destroyed thread-local at module unload (e.g., - // process exit) time. - all_thread_caches.erase(it); + if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) { + thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = + py::capsule(new std::thread::id(this_thread_id), [](void* p) { + auto* ptid = reinterpret_cast(p); + { + std::lock_guard inner_lock( + native_sharding_propagator_cache_cleanup_mutex); + auto it = all_thread_caches.find(*ptid); + if (it != all_thread_caches.end()) { + // We need to both: + // 1) free python objects, and + it->second->reset(); + // 2) make sure we don't try to come back and mess with + // a destroyed thread-local at module unload (e.g., + // process exit) time. + all_thread_caches.erase(it); + } } - } - delete ptid; - }); + delete ptid; + }); + } } return native_sharding_propagator_cache_DO_NOT_USE.value(); } From e5a766ece41d9591a7a5c95cbbf6af3b5aceed0e Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 17 Nov 2025 22:11:28 -0800 Subject: [PATCH 037/230] [user-streams] Insert backward syncs (#167747) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167747 Approved by: https://github.com/soulitzer --- test/dynamo/test_streams.py | 4 ++ .../_functorch/_aot_autograd/graph_capture.py | 4 +- torch/_functorch/_aot_autograd/streams.py | 61 +++++++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 3b4aff724eee4..967bedb9ebaae 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -585,6 +585,10 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): # Annotation: {'stream': 1} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None + wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None + # Annotation: {'stream': 0} add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None return (add_3, add_2) diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py index b6ea08a802240..f17a516183975 100644 --- a/torch/_functorch/_aot_autograd/graph_capture.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -33,7 +33,7 @@ handle_effect_tokens_fn, ) from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta -from .streams import assign_backward_streams +from .streams import assign_backward_streams, insert_backward_syncs from .utils import ( call_and_expect_output_descs, copy_fwd_metadata_to_bw_nodes, @@ -477,6 +477,8 @@ def aot_dispatch_autograd_graph( # After copying metadata, assign streams to gradient accumulation nodes assign_backward_streams(fx_g) + insert_backward_syncs(fx_g) + fx_g.graph.eliminate_dead_code() if not aot_config.disable_functionalization: # There should be *NO* mutating ops in the graph at this point. diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index f78a2c6cad1de..1b4f5ded051e3 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -3,6 +3,7 @@ import torch.fx import torch.fx.traceback from torch._dynamo.graph_utils import _get_flat_args +from torch._dynamo.variables.streams import get_current_stream, new_event Node: TypeAlias = torch.fx.Node @@ -12,6 +13,14 @@ def is_gradient_acc(node: Node) -> bool: return node.meta.get("is_gradient_acc", False) +def is_bwd_node(node: Node) -> bool: + return node.meta.get("partitioner_tag") == "is_backward" + + +def get_device(node: Node) -> torch.device: + return node.meta["val"].device + + def get_stream(node: Node) -> Optional[int]: maybe_annotation = node.meta.get("custom", None) if maybe_annotation is not None: @@ -20,6 +29,13 @@ def get_stream(node: Node) -> Optional[int]: return None +def get_stream_or_current_stream(node: Node) -> int: + ind = get_stream(node) + if ind is None: + ind = get_current_stream(get_device(node)) + return ind + + def set_stream(node: Node, ind: int) -> None: if "custom" in node.meta: node.meta["custom"].update({"stream": ind}) @@ -27,6 +43,36 @@ def set_stream(node: Node, ind: int) -> None: node.meta["custom"] = {"stream": ind} +def insert_sync( + graph: torch.fx.Graph, + consumer: Node, + producer: Node, + node_to_wait_event_ind: dict[Node, int], +) -> None: + if producer not in node_to_wait_event_ind: + node_to_wait_event_ind[producer] = new_event() + + with graph.inserting_after(producer): + node = graph.call_function( + torch.ops.streams.record_event.default, + ( + node_to_wait_event_ind[producer], + get_stream_or_current_stream(producer), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + with graph.inserting_before(consumer): + node = graph.call_function( + torch.ops.streams.wait_event.default, + ( + node_to_wait_event_ind[producer], + get_stream_or_current_stream(consumer), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + def assign_backward_streams(gm: torch.fx.GraphModule) -> None: """Assigns backward streams to gradient accumulation nodes""" @@ -51,3 +97,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None: if ind is not None: set_stream(node, ind) break + + +def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: + """Inserts stream syncs for backward nodes if consumer and producer are on different streams""" + node_to_wait_event_ind = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + flat_args = _get_flat_args(node, {}) + cur_node_stream = get_stream(node) + + for arg in flat_args: + if is_bwd_node(arg): + arg_stream = get_stream(arg) + if arg_stream != cur_node_stream and get_device(arg).type != "cpu": + insert_sync(gm.graph, node, arg, node_to_wait_event_ind) From 9f94c7b8ee6077594399a9d8e89f7a50e2f02413 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 19 Nov 2025 15:20:46 +0000 Subject: [PATCH 038/230] [fix] Assign CUDAEvent external member properly (#167711) # Motivation This PR aims to fix the bug that the moved-to object's `external_` member is not assigned correctly. # Additional Context It's not fine to swap the valid value and the invalid value. We'd just need to prevent double-free. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167711 Approved by: https://github.com/albanD --- aten/src/ATen/cuda/CUDAEvent.h | 17 ++++++++---- aten/src/ATen/test/CMakeLists.txt | 1 + aten/src/ATen/test/cuda_event_test.cpp | 36 ++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 aten/src/ATen/test/cuda_event_test.cpp diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 81b4643ac0418..7a650b9cbcf35 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -238,11 +238,18 @@ struct TORCH_CUDA_CPP_API CUDAEvent { } void moveHelper(CUDAEvent&& other) { - std::swap(flags_, other.flags_); - std::swap(is_created_, other.is_created_); - std::swap(was_recorded_, other.was_recorded_); - std::swap(device_index_, other.device_index_); - std::swap(event_, other.event_); + // Transfer ownership of all state from other to this + flags_ = other.flags_; + is_created_ = other.is_created_; + was_recorded_ = other.was_recorded_; + external_ = other.external_; + device_index_ = other.device_index_; + event_ = other.event_; + + // Reset other to a valid empty state to prevent double-free + // The moved-from object must not attempt to destroy the event + other.is_created_ = false; + other.event_ = cudaEvent_t{}; } }; diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index a522e7ab76cf4..923b7119a42fc 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -65,6 +65,7 @@ list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cuda_event_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu diff --git a/aten/src/ATen/test/cuda_event_test.cpp b/aten/src/ATen/test/cuda_event_test.cpp new file mode 100644 index 0000000000000..7c58688e1ef9d --- /dev/null +++ b/aten/src/ATen/test/cuda_event_test.cpp @@ -0,0 +1,36 @@ +#include + +#include +#include +#include + +TEST(CUDAEventTest, testCUDAExternalEvent) { + if (!at::cuda::is_available()) { + return; + } + + // Create two external CUDA events + unsigned int flags = cudaEventDefault | cudaEventExternal; + auto event1 = at::cuda::CUDAEvent(flags); + auto event2 = at::cuda::CUDAEvent(flags); + // Ensure external CUDAEvent remain valid and functional after being moved. + auto start_event = std::move(event1); + auto end_event = std::move(event2); + + auto stream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(stream); + + auto graph = at::cuda::CUDAGraph(); + graph.capture_begin(); + start_event.record(); + at::cuda::sleep(100000); + end_event.record(); + graph.capture_end(); + + // External events should correctly record timestamps even when used inside + // CUDA graphs, and elapsed_time() between them should be positive. + stream.synchronize(); + graph.replay(); + at::cuda::device_synchronize(); + EXPECT_TRUE(start_event.elapsed_time(end_event) > 0); +} From 7a963ffc0b5d4ed0ffb4ac2f88a89fff1dccda40 Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:11:35 +0000 Subject: [PATCH 039/230] LocalTensor for random_ops tests (#166540) adds support for randomness in localtensor. tl;dr it needs to be able to handle RNG the same way (i.e., rng tracking/syncing across shards, user-defined seeds, user-defined generators, etc. we extend the existing OffsetBasedRNGTracker to play nicely with localtensor's setup, creating a few small subclasses and patching the core RNG logic to manage the per-rank seeds and offsets correctly. i still haven't done the per-rank generator support (since the existing tests imply a globally-seeded generator), but that it something that should be done. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166540 Approved by: https://github.com/dzmitry-huba --- test/distributed/tensor/test_random_ops.py | 287 +++++---- torch/distributed/_local_tensor/__init__.py | 575 +++++++++++++++++- torch/distributed/tensor/_dispatch.py | 2 + torch/distributed/tensor/_random.py | 13 + torch/random.py | 4 + .../distributed/_tensor/common_dtensor.py | 3 + 6 files changed, 765 insertions(+), 119 deletions(-) diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 61b88ee169e2e..4ff470511f2ad 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._random as random +from torch.distributed._local_tensor import LocalTensor, maybe_run_for_local_tensor from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.distributed_c10d import broadcast_object_list from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import ( DeviceMesh, @@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, skip_if_lt_x_gpu, skip_unless_torch_gpu, @@ -34,9 +35,12 @@ from torch.utils._typing_utils import not_none -def get_generator_seed_for_device_type(device_type: str) -> int: - device_module = torch.get_device_module(device_type) - return device_module.get_rng_state()[:8].view(torch.int64).item() +def get_generator_seed_for_device_type(device_type: str): + from torch.distributed._local_tensor import ( + get_generator_seed_for_device_type as _get_seed, + ) + + return _get_seed(device_type) class DistTensorRandomInitTest(DTensorTestBase): @@ -134,9 +138,6 @@ def test_meta_tensor_init(self): torch.empty(*size, device="meta"), device_mesh, [Replicate()] ) - # the tensor slice on the current rank - self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024) - # Test 1: enable the distribute region for RNG (by default) self.assertTrue(meta_dtensor.is_meta) # Tensor meta init @@ -150,16 +151,23 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same because they're replicated - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - gathered_local_tensors[self_slice, :], - gathered_local_tensors[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(gathered_local_tensors, rank): + # the tensor slice on the current rank + self_slice = slice(1024 * rank, 1024 * rank + 1024) + + # compare with local tensors from other ranks + for other_rank in range(self.world_size): + # the RNG result on each rank are the same because they're replicated + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) + self.assertEqual( + gathered_local_tensors[self_slice, :], + gathered_local_tensors[other_slice, :], + ) + + compute_rankwise_if_local_tensor(gathered_local_tensors.wait(), self.rank) # Test 2: disable the distribute region for RNG self.assertTrue(meta_dtensor.is_meta) @@ -175,15 +183,7 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same even without the help of DTensor's RNG infra, - # since the default RNG is the same across ranks. - if self.rank != other_rank: - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - local_tensor[self_slice, :], local_tensor[other_slice, :] - ) + compute_rankwise_if_local_tensor(local_tensor.wait(), self.rank) @with_comms @skip_unless_torch_gpu @@ -224,13 +224,17 @@ def test_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -277,13 +281,17 @@ def test_fsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) class DistTensorRandomOpTest(DTensorTestBase): @@ -291,9 +299,14 @@ class DistTensorRandomOpTest(DTensorTestBase): @skip_unless_torch_gpu def test_rng_tracker_init(self): torch.manual_seed(self.rank) - object_list = [torch.initial_seed()] - broadcast_object_list(object_list) - seed_from_rank_0 = int(object_list[0]) + seed_local = ( + torch.zeros_like(torch.empty(1), device=self.device_type) + + torch.initial_seed() + ) + torch.distributed.broadcast(seed_local, src=0) + # if localtensor, it should automaticall reconcile after the broadcast + # since all virtual ranks should have rank 0's initial_seed() + seed_from_rank_0 = seed_local device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) # seed synchronization now does NOT happen after the first `distribute_tensor` @@ -344,15 +357,19 @@ def test_manual_seed(self): @with_comms @skip_unless_torch_gpu def test_manual_seed_submesh(self): - # the current rank is not a part of the mesh - single_rank_device_mesh = DeviceMesh( - self.device_type, [(self.rank + 1) % self.world_size] - ) - with self.assertRaisesRegex( - RuntimeError, - "manual_seed requires the current rank to be a part of the device mesh", - ): - manual_seed(self.rank, single_rank_device_mesh) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(rank): + # the current rank is not a part of the mesh + single_rank_device_mesh = DeviceMesh( + self.device_type, [(rank + 1) % self.world_size], _rank=rank + ) + with self.assertRaisesRegex( + RuntimeError, + "manual_seed requires the current rank to be a part of the device mesh", + ): + manual_seed(rank, single_rank_device_mesh) + + compute_rankwise_if_local_tensor(self.rank) @with_comms @skip_unless_torch_gpu @@ -394,7 +411,7 @@ def test_pipeline_parallel_manual_seed(self): for other_rank in range(self.world_size): if self.rank != other_rank: self.assertNotEqual( - spmd_dtensor.to_local(), + spmd_dtensor, tensor_gather[2 * other_rank : 2 * (other_rank + 1), :], ) @@ -428,16 +445,20 @@ def test_deterministic_dropout_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_unless_torch_gpu @@ -454,16 +475,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have a different local tensor for shard placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertNotEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor for replicate placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertNotEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) # we should set manual seed to the same value on all SPMD ranks torch.manual_seed(0) @@ -472,16 +497,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor for replicate placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor for replicate placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -539,7 +568,12 @@ def test_deterministic_uniform_2d(self): shard_linear_idx = random._rng_tracker._calc_shard_linear_idx( shard_coord, shard_size ) - self.assertEqual(shard_linear_idx, shard_index[self.rank]) + + @maybe_run_for_local_tensor + def check_shard_index(shard_linear_idx, rank): + self.assertEqual(shard_linear_idx, shard_index[rank]) + + check_shard_index(shard_linear_idx, self.rank) # compute local size and offset _, local_shard_offset = compute_local_shape_and_global_offset( @@ -578,16 +612,27 @@ def test_deterministic_uniform_2d(self): # allgather the local tensors full_tensor = dtensor.full_tensor() - # compare local tensor with each other shard - for other_local_shard in local_shard_comb: - other_local_shard_offset, _ = zip(*other_local_shard) - slice_idx = [ - slice(offset, offset + size) for offset, size in other_local_shard - ] - if local_shard_offset == other_local_shard_offset: - self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) - else: - self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + full_tensor = ( + full_tensor.reconcile() + if isinstance(full_tensor, LocalTensor) + else full_tensor + ) + + @maybe_run_for_local_tensor + def blockwise_iter_if_localtensor(local_tensor, local_shard_offset): + # compare local tensor with each other shard + for other_local_shard in local_shard_comb: + other_local_shard_offset, _ = zip(*other_local_shard) + slice_idx = [ + slice(offset, offset + size) + for offset, size in other_local_shard + ] + if local_shard_offset == other_local_shard_offset: + self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) + else: + self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + + blockwise_iter_if_localtensor(local_tensor, local_shard_offset) class DistTensorRandomOpsTest3D(DTensorTestBase): @@ -641,22 +686,46 @@ def test_hsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - shard_dim_0_len = self.world_size // 4 - for other_rank in range(self.world_size): - other_rank_dim_0_start = other_rank * shard_dim_0_len - other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len - if self.rank % 4 != other_rank % 4: - self.assertNotEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) - else: - self.assertEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) + weight_gather = weight_gather.wait() + + weight_gather = ( + weight_gather.reconcile() + if isinstance(weight_gather, LocalTensor) + else weight_gather + ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, rank): + # verify the weights are initialized differently on all ranks + shard_dim_0_len = self.world_size // 4 + for other_rank in range(self.world_size): + other_rank_dim_0_start = other_rank * shard_dim_0_len + other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len + if rank % 4 != other_rank % 4: + self.assertNotEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + else: + self.assertEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + + compute_rankwise_if_local_tensor(weight_local, self.rank) + + +DistTensorRandomInitTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomInitTest, +) + +DistTensorRandomOpTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpTest, +) + +DistTensorRandomOpsTest3DWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpsTest3D, +) if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index c186694df94e7..db03d26227911 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -76,7 +76,11 @@ from torch.nested._internal.nested_int import NestedIntNode from torch.utils import _pytree as pytree from torch.utils._mode_utils import no_dispatch -from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + return_and_correct_aliasing, + TorchDispatchMode, +) from torch.utils.checkpoint import get_device_states, set_device_states @@ -86,6 +90,12 @@ from . import _c10d +def _is_in_fake_tensor_mode() -> bool: + return any( + isinstance(mode, FakeTensorMode) for mode in _get_current_dispatch_mode_stack() + ) + + def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: return ( isinstance(op, OpOverload) @@ -256,21 +266,31 @@ def _for_each_rank_run_func( a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args ] - # NB: Before invoking an op we are collecting rng states from CPU and - # CUDA devices such that we can reset to the same before invoking op - # for each rank. This is not very efficient and will likely be revisited - # to support per rank rng state. - rng_state = _get_rng_state() + lm = enabled_local_tensor_mode() + use_per_rank_rng = lm is not None and len(lm._per_rank_rng_states) > 0 + + global_rng_state = None if use_per_rank_rng else _get_rng_state() + flat_rank_rets = {} default_value: Tensor | None = None for r in sorted(ranks): - _set_rng_state(*rng_state) + if use_per_rank_rng: + assert lm is not None + _set_rng_state(*lm._per_rank_rng_states[r]) + else: + assert global_rng_state is not None + _set_rng_state(*global_rng_state) + rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) rank_ret = func(*rank_args, **rank_kwargs) flat_rank_rets[r] = rank_ret + if use_per_rank_rng: + assert lm is not None + lm._per_rank_rng_states[r] = _get_rng_state() + if default_value is None and func is torch.ops.aten.split.Tensor: # If split happens over the dimension smaller than the number of chunks # it is possible that some ranks will produce shorter lists of chunks. @@ -437,6 +457,247 @@ def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": return ConstantIntNode(num) +class _LocalDeviceHandle: + """ + Wrapper around device module (e.g., torch.cuda) with automatic LocalTensor semantics. + + This class wraps device modules and automatically handles per-rank operations in + LocalTensor mode: + - get_rng_state() returns a LocalTensor with per-rank states + - set_rng_state(LocalTensor) sets per-rank states + + When not in LocalTensor mode, it delegates directly to the underlying device handle. + """ + + def __init__(self, device_handle, device_type: str): + """ + Initialize the local device handle wrapper. + + Args: + device_handle: The underlying device module (e.g., torch.cuda) + device_type: Device type string (e.g., "cuda", "cpu") + """ + self._device_handle = device_handle + self._device_type = device_type + + def get_rng_state(self): + """ + Get RNG state, automatically returning LocalTensor in LocalTensor mode. + + Returns: + LocalTensor in LocalTensor mode, regular Tensor otherwise + """ + lm = enabled_local_tensor_mode() + if not lm: + return self._device_handle.get_rng_state() + + original_state = _get_rng_state() + per_rank_states = {} + + try: + for rank in lm.ranks: + # We need to set-then-get instead of directly copying lm._per_rank_rng_states[rank] + # because they have different structures: + # - lm._per_rank_rng_states[rank] is a tuple: (cpu_state, {device_idx: cuda_state}) + # - self._device_handle.get_rng_state() returns just the device-specific tensor + # So we temporarily restore the full RNG state (CPU + all CUDA devices) for this rank, + # then extract only the specific device's state tensor that we need. + if rank in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[rank]) + + per_rank_states[rank] = self._device_handle.get_rng_state() + finally: + _set_rng_state(*original_state) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(per_rank_states) + + def set_rng_state(self, state): + """ + Set RNG state, automatically handling LocalTensor input. + + Args: + state: Regular Tensor or LocalTensor with per-rank states + """ + if isinstance(state, LocalTensor): + lm = enabled_local_tensor_mode() + assert lm is not None + + # Similar to get_rng_state but in reverse: we need to convert from + # device-specific tensor format to full state tuple format. + # - state._local_tensors[rank] contains just the device-specific RNG state tensor + # - lm._per_rank_rng_states[rank] needs a tuple: (cpu_state, {device_idx: cuda_state}) + # So we set the device's state with the rank-specific tensor, then _get_rng_state() + # captures both CPU and CUDA states into the tuple format that _per_rank_rng_states expects. + for rank, rank_state in state._local_tensors.items(): + self._device_handle.set_rng_state(rank_state.to("cpu")) + lm._per_rank_rng_states[rank] = _get_rng_state() + else: + self._device_handle.set_rng_state(state.to("cpu")) + + def __getattr__(self, name): + """Delegate all other attributes to the underlying device module.""" + return getattr(self._device_handle, name) + + +class _LocalOffsetBasedRNGTracker: + """ + LocalTensor-specific RNG tracker for DTensor random operations. + + This class manages per-rank RNG states when running in LocalTensor mode, + using _LocalPhiloxState to track different offsets for each virtual rank. + It is instantiated and used by OffsetBasedRNGTracker when in LocalTensor mode. + + Much of this is derived from OffsetBasedRNGTracker: + https://github.com/pytorch/pytorch/blob/402c46503002f98ccfc023a733081fb0719223a1/torch/distributed/tensor/_random.py#L182 + """ + + def __init__(self, device_type: str = "cuda"): + """Initialize the LocalTensor RNG tracker.""" + from torch.distributed.device_mesh import _get_device_handle + + self._device_type = device_type + self._device_handle = _LocalDeviceHandle( + _get_device_handle(device_type), device_type + ) + self.distribute_region_enabled = True + self._device_mesh = None + + @property + def _device(self): + return torch.device(self._device_type, torch.cuda.current_device()) + + def _set_pre_op_offset(self, state, spec) -> None: + """Compute and set per-rank offsets before the random operation.""" + from torch.distributed.tensor._ops.utils import prod + from torch.distributed.tensor._utils import ( + _compute_local_shape_and_global_offset, + ) + from torch.distributed.tensor.placement_types import Shard + + lm = enabled_local_tensor_mode() + assert lm is not None + + state._per_rank_offsets = {} + + for rank in lm.ranks: + # compute this rank's coordinate in the mesh + mesh_coords = [] + for mesh_dim_idx in range(spec.mesh.ndim): + mesh_dim_size = spec.mesh.size(mesh_dim_idx) + # calculate rank's coordinate in this mesh dimension + num_chunks_after = 1 + for j in range(mesh_dim_idx + 1, spec.mesh.ndim): + num_chunks_after *= spec.mesh.size(j) + coord = (rank // num_chunks_after) % mesh_dim_size + mesh_coords.append(coord) + + # compute local shape and global offset for this rank + local_shape, global_offset = _compute_local_shape_and_global_offset( + spec.shape, spec.mesh.shape, mesh_coords, spec.placements + ) + + # compute shard offset based on placements + shard_offset = 1 + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + shard_offset *= global_offset[shard_dim] + 1 + + # get current offset for this rank + current_offset = int( + state._per_rank_states[rank][8:].view(dtype=torch.int64).item() + ) + + # compute local size + local_size = prod(local_shape) + + # compute new offset (must be multiple of 4) + shard_linear_idx = shard_offset - 1 + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state._per_rank_offsets[rank] = current_offset + offset_incr + + def _set_post_op_offset(self, state, spec, old_offset) -> None: + """Set per-rank offsets after the random operation.""" + from torch.distributed.tensor._ops.utils import prod + + lm = enabled_local_tensor_mode() + assert lm is not None + + dtensor_shape = spec.shape + numel = prod(dtensor_shape) + # offset must be multiple of 4 + numel = (numel + 3) // 4 * 4 + + if not hasattr(state, "_per_rank_offsets"): + state._per_rank_offsets = {} + + # handle LocalIntNode old_offset (different values per rank) + if isinstance(old_offset, SymInt) and isinstance(old_offset.node, LocalIntNode): + for rank in lm.ranks: + rank_old_offset = old_offset.node._local_ints[rank] + state._per_rank_offsets[rank] = rank_old_offset + numel + else: + # same old_offset for all ranks + old_offset_int = ( + int(old_offset) if isinstance(old_offset, SymInt) else old_offset + ) + for rank in lm.ranks: + state._per_rank_offsets[rank] = old_offset_int + numel + + @contextlib.contextmanager + def _distribute_region(self, spec, generator=None): + """Context manager for LocalTensor mode distribute region.""" + lm = enabled_local_tensor_mode() + assert lm is not None + + # get base state + if generator is not None: + base_state_tensor = generator.get_state() + per_rank_states = {rank: base_state_tensor.clone() for rank in lm.ranks} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + base_state_tensor = LocalTensor(per_rank_states) + else: + base_state_tensor = self._device_handle.get_rng_state() + + state = _LocalPhiloxState(base_state_tensor) + + if self.distribute_region_enabled: + # sync to rank 0's state if no explicit generator + if generator is None: + rank_0_state = lm._per_rank_rng_states[0] + rank_0_cpu, rank_0_cuda = rank_0_state + + if self._device.type == "cuda": + assert self._device.index in rank_0_cuda + rank_0_device_state = rank_0_cuda[self._device.index] + else: + rank_0_device_state = rank_0_cpu + + from torch.distributed.tensor._random import _PhiloxState + + rank_0_philox = _PhiloxState(rank_0_device_state) + state.seed = rank_0_philox.seed + state.offset = rank_0_philox.offset + + old_offset = state.offset + self._set_pre_op_offset(state, spec) + state.apply_to_local_tensor_mode(self._device_handle) + + try: + yield + finally: + self._set_post_op_offset(state, spec, old_offset) + state.apply_to_local_tensor_mode(self._device_handle) + else: + yield + + # maybe reset generator to rank 0's state + if generator is not None: + rank_0_state = state._per_rank_states[0] + generator.set_state(rank_0_state) + + _LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" @@ -597,6 +858,7 @@ def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": local_tensors_copy = { r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors_copy, self.requires_grad) def __repr__(self) -> str: # type: ignore[override] @@ -636,6 +898,7 @@ def __tensor_unflatten__( local_tensors = { _from_local_tensor_attr(a): t for a, t in inner_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors) @classmethod @@ -774,12 +1037,28 @@ def __init__(self, ranks: Union[int, frozenset[int]]): self.ranks = ranks self._disable = False self._old_get_coordinate = None + self._old_torch_manual_seed: Any = None + self._old_torch_initial_seed: Any = None + self._per_rank_rng_states: dict[ + int, tuple[torch.Tensor, dict[int, torch.Tensor]] + ] = {} def __enter__(self) -> "LocalTensorMode": self._disable = False self._patch_device_mesh() + self._patch_random_functions() _LOCAL_TENSOR_MODE.append(self) + # _distribute_region will compute correct per-shard offsets + # but we want all ranks to start with the same state + if not _is_in_fake_tensor_mode(): + cpu_state, cuda_states = _get_rng_state() + for rank in self.ranks: + self._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + return super().__enter__() def __exit__( @@ -790,6 +1069,7 @@ def __exit__( ) -> None: self._disable = True self._unpatch_device_mesh() + self._unpatch_random_functions() _LOCAL_TENSOR_MODE.pop() super().__exit__(exc_type, exc_val, exc_tb) @@ -936,6 +1216,7 @@ def tensor_map( m = cb(r, tensor._local_tensors[r]) if m is not None: results[r] = m + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(results) def _patch_device_mesh(self) -> None: @@ -949,6 +1230,87 @@ def _unpatch_device_mesh(self) -> None: # pyrefly: ignore [bad-assignment] self._old_get_coordinate = None + def _patch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is None: + self._old_torch_manual_seed = torch.random.manual_seed + torch.random.manual_seed = _LocalRandom.torch_manual_seed + torch.manual_seed = _LocalRandom.torch_manual_seed + + if self._old_torch_initial_seed is None: + self._old_torch_initial_seed = torch.random.initial_seed + torch.random.initial_seed = _LocalRandom.torch_initial_seed + torch.initial_seed = _LocalRandom.torch_initial_seed + + def _unpatch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is not None: + torch.random.manual_seed = self._old_torch_manual_seed + torch.manual_seed = self._old_torch_manual_seed + self._old_torch_manual_seed = None + + if self._old_torch_initial_seed is not None: + torch.random.initial_seed = self._old_torch_initial_seed + torch.initial_seed = self._old_torch_initial_seed + self._old_torch_initial_seed = None + + +class _LocalRandom: + """ + Holds implementations of random functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def torch_manual_seed(seed) -> torch._C.Generator: + """LocalTensor-aware version of torch.random.manual_seed.""" + if ( + (lm := enabled_local_tensor_mode()) + and isinstance(seed, torch.SymInt) + and isinstance(seed.node, LocalIntNode) + ): + from torch.random import _manual_seed_impl + + for rank in sorted(lm.ranks): + rank_seed = seed.node._local_ints[rank] + _manual_seed_impl(rank_seed, update_local_tensor_states=False) + lm._per_rank_rng_states[rank] = _get_rng_state() + return torch.random.default_generator + from torch.random import _manual_seed_impl + + result = _manual_seed_impl(seed, update_local_tensor_states=False) + + if lm is not None and len(lm._per_rank_rng_states) > 0: + cpu_state, cuda_states = _get_rng_state() + for rank in lm.ranks: + lm._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return result + + @staticmethod + def torch_initial_seed(): + """LocalTensor-aware version of torch.random.initial_seed.""" + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + return torch.random.default_generator.initial_seed() + rank_seeds = {} + + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = torch.random.default_generator.initial_seed() + + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + + return torch.random.default_generator.initial_seed() + class _LocalDeviceMesh: """ @@ -963,7 +1325,7 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]: # doing this because when submesh is created it is created for a particular # rank (therefore below we are patching get_rank method). We are trying to # limit the invasiveness of local tensor. - lm = local_tensor_mode() + lm = enabled_local_tensor_mode() assert lm is not None, "Unexpectedly not in LocalTensorMode" coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] @@ -1024,6 +1386,22 @@ def local_tensor_mode() -> Optional[LocalTensorMode]: return None +def enabled_local_tensor_mode() -> Optional[LocalTensorMode]: + """ + Returns the current active LocalTensorMode only if it's enabled. + + This is a convenience function that combines the common pattern of checking + if local_tensor_mode() is not None and not disabled. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active and enabled, else None. + """ + lm = local_tensor_mode() + if lm is not None and not lm._disable: + return lm + return None + + def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: """ Decorator that ensures a function is executed for each local tensor shard @@ -1048,8 +1426,7 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] - lm = local_tensor_mode() - if lm is None or lm._disable: + if not (lm := enabled_local_tensor_mode()): return func(*args, **kwargs) ret = None with lm.disable(): @@ -1068,6 +1445,73 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: return lm.disable() if lm is not None else contextlib.nullcontext() +def maybe_enable_local_tracker( + device_type: str, distribute_region_enabled: bool, spec, generator +): + """ + Returns a context manager for LocalTensor-mode RNG tracking if local tensor mode is enabled. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + distribute_region_enabled: Whether distribute region is enabled + spec: The DTensorSpec + generator: Optional torch.Generator + + Returns: + Context manager from local_tracker._distribute_region if local tensor mode is enabled, + otherwise None. + """ + if enabled_local_tensor_mode(): + local_tracker = _LocalOffsetBasedRNGTracker(device_type) + local_tracker.distribute_region_enabled = distribute_region_enabled + return local_tracker._distribute_region(spec, generator) + + return None + + +def get_generator_seed_for_device_type(device_type: str): + """ + Gets the generator seed for a specific device type, handling LocalTensor mode appropriately. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + + Returns: + If in LocalTensor mode with per-rank RNG states: + - Returns int if all ranks have the same seed + - Returns SymInt(LocalIntNode) if ranks have different seeds + Otherwise: + - Returns int seed from the device's RNG state + """ + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + device_module = torch.get_device_module(device_type) + + original_state = _get_rng_state() + + rank_seeds = {} + try: + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = int( + device_module.get_rng_state()[:8].view(torch.int64).item() + ) + finally: + # restore original state + _set_rng_state(*original_state) + + unique_seeds = set(rank_seeds.values()) + if len(unique_seeds) == 1: + return next(iter(unique_seeds)) + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + else: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + + import threading from queue import Queue @@ -1183,3 +1627,114 @@ def current() -> "LocalRunnerMode": global _LOCAL_RUNNER_MODE assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" return _LOCAL_RUNNER_MODE + + +class _LocalPhiloxState: + """ + LocalTensor-aware version of _PhiloxState that manages per-rank RNG states. + This class handles the case where the generator state is a LocalTensor, allowing + different offsets and seeds for different virtual ranks. + + Note: This is designed to be used as a drop-in replacement for _PhiloxState + when working with LocalTensors in the DTensor random ops implementation. + """ + + def __init__(self, state: torch.Tensor): + assert isinstance(state, LocalTensor), ( + "_LocalPhiloxState requires a LocalTensor" + ) + self._local_tensor = state + self._per_rank_states = { + rank: local_state.to("cpu") + for rank, local_state in state._local_tensors.items() + } + + @property + def state(self): + return LocalTensor(self._per_rank_states) # type: ignore[name-defined] + + @property + def offset(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + offsets = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + offsets[rank] = rank_philox.offset + + if len(set(offsets.values())) == 1: + return next(iter(offsets.values())) + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return SymInt(LocalIntNode(offsets)) + + @offset.setter + def offset(self, offset: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(offset, SymInt) and isinstance(offset.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_offset = offset.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.offset = rank_offset + else: + offset_int = int(offset) if isinstance(offset, SymInt) else offset + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.offset = offset_int + + @property + def seed(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + seeds = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + seeds[rank] = rank_philox.seed + + if len(set(seeds.values())) == 1: + return next(iter(seeds.values())) + return SymInt(LocalIntNode(seeds)) + + @seed.setter + def seed(self, seed: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(seed, SymInt) and isinstance(seed.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_seed = seed.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.seed = rank_seed + else: + seed_int = int(seed) if isinstance(seed, SymInt) else seed + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.seed = seed_int + + def apply_to_local_tensor_mode(self, device_handle) -> None: + """ + Apply per-rank RNG states to the LocalTensorMode's tracked states. + This updates both the device RNG state and the LocalTensorMode's _per_rank_rng_states. + + Args: + device_handle: The device handle to use for setting RNG state (_LocalDeviceHandle) + """ + if not enabled_local_tensor_mode(): + return + + assert hasattr(self, "_per_rank_offsets") + + for rank in sorted(self._per_rank_states.keys()): + offset_value = self._per_rank_offsets[rank] + if isinstance(offset_value, SymInt): + if isinstance(offset_value.node, LocalIntNode): + offset_value = offset_value.node._local_ints[rank] + else: + offset_value = int(offset_value) + + offset_tensor = torch.tensor( + [offset_value], dtype=torch.uint64, device="cpu" + ).view(torch.uint8) + self._per_rank_states[rank][8:] = offset_tensor + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + device_handle.set_rng_state(LocalTensor(self._per_rank_states)) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index cbd817a8bde37..630f327add3d7 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -135,7 +135,9 @@ def __init__(self) -> None: self._random_ops = { aten.native_dropout.default, aten.normal_.default, + aten.rand.default, aten.rand_like.default, + aten.randn.default, aten.randn_like.default, aten.randint_like.default, aten.randint_like.low_dtype, diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index f8325c83d55e4..42bf1ebeebf0e 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -101,6 +101,9 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing # as manual seed on torch. + # + # torch.manual_seed will handle LocalTensor mode correctly by + # iterating through all ranks if seed is a LocalIntNode. torch.manual_seed(seed) @@ -239,6 +242,16 @@ def _set_device_state(self, state: torch.Tensor): def _distribute_region( self, spec: DTensorSpec, generator: Optional[torch.Generator] = None ): + from torch.distributed._local_tensor import maybe_enable_local_tracker + + if local_tracker_context := maybe_enable_local_tracker( + self._device.type, self.distribute_region_enabled, spec, generator + ): + with local_tracker_context: + yield + return + + # regular (non-LocalTensor) mode if generator is not None: # This is a little hacky, but for any user-passed generator, we store its state under a unique key, # not because we need to keep a copy of it but because its the easiest way to make it work with the diff --git a/torch/random.py b/torch/random.py index cf23e52db320e..f86d7349019dc 100644 --- a/torch/random.py +++ b/torch/random.py @@ -39,6 +39,10 @@ def manual_seed(seed) -> torch._C.Generator: is raised. Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. """ + return _manual_seed_impl(seed, update_local_tensor_states=True) + + +def _manual_seed_impl(seed, update_local_tensor_states) -> torch._C.Generator: seed = int(seed) import torch.cuda diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 9666765b01e71..1f6c4aece1e80 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -724,6 +724,9 @@ def setUp(self) -> None: torch.autograd._enable_record_function(False) def tearDown(self) -> None: + from torch.distributed.tensor import _random as random + + random._rng_tracker = None super().tearDown() torch.autograd._enable_record_function(True) From be33b7faf685560bb618561b44b751713a660337 Mon Sep 17 00:00:00 2001 From: lichuyang Date: Wed, 19 Nov 2025 13:10:12 +0000 Subject: [PATCH 040/230] [DeviceMemory] Add Basic Statistics to Device Memory in OpenReg (#166395) Implement a complete OpenRegDeviceAllocator with the following enhancements: - Implement memory statistics tracking (allocated/reserved bytes, allocation count) - Track allocation sizes for accurate memory statistics - Refactor DeviceAllocator's inheritance relationship from c10::DeviceAllocator - This change is for further improvement of adding a memory caching function to DeviceMemory Add comprehensive test coverage: - Memory allocation/deallocation tests with statistics validation - Storage operations and tensor-from-blob tests - Multithreading safety tests for concurrent allocations - Gradient tracking and requires_grad compatibility tests Fixes #166157 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166395 Approved by: https://github.com/fffrog --- .../csrc/runtime/OpenRegDeviceAllocator.cpp | 271 +++++++++- .../csrc/runtime/OpenRegDeviceAllocator.h | 103 ++-- .../torch_openreg/tests/test_memory.py | 488 ++++++++++++++++++ 3 files changed, 826 insertions(+), 36 deletions(-) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp index 3d35b677cd208..3a6f2945d903c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -1,8 +1,275 @@ #include "OpenRegDeviceAllocator.h" +#include "OpenRegFunctions.h" + +#include +#include + +using namespace c10::CachingAllocator; namespace c10::openreg { -static OpenRegDeviceAllocator global_openreg_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); +constexpr size_t kAggregate = static_cast(StatType::AGGREGATE); + + +DeviceMemoryAllocator::DeviceMemoryAllocator(c10::DeviceIndex device_index) + : device_index_(device_index) {} + +void* DeviceMemoryAllocator::malloc(size_t nbytes) { + if (nbytes == 0) { + return nullptr; + } + + std::lock_guard lock(mutex_); + + void* data = nullptr; + auto ret = orMalloc(&data, nbytes); + + TORCH_CHECK( + ret == orSuccess && data != nullptr, + "Failed to allocate ", + nbytes, + " bytes on openreg device ", + device_index_, + ". ", + "Allocated: ", + stats_.allocated_bytes[0].current, + " bytes, ", + "Reserved: ", + stats_.reserved_bytes[0].current, + " bytes"); + + // Track allocation size for proper deallocation statistics + allocation_sizes_[data] = nbytes; + + // Update statistics + stats_.allocated_bytes[kAggregate].increase(nbytes); + stats_.reserved_bytes[kAggregate].increase(nbytes); + stats_.num_device_alloc++; + + return data; +} + +void DeviceMemoryAllocator::free(void* ptr) { + if (!ptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto ret = orFree(ptr); + + if (ret == orSuccess) { + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + size_t nbytes = it->second; + + stats_.allocated_bytes[kAggregate].decrease(nbytes); + stats_.reserved_bytes[kAggregate].decrease(nbytes); + stats_.num_device_free++; + + allocation_sizes_.erase(it); + } else { + TORCH_WARN( + "Successfully freed OpenReg memory pointer ", + ptr, + " on device ", + device_index_, + " that was not tracked by the allocator. " + "Statistics may be inaccurate."); + } + } else { + // orFree failed + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + TORCH_WARN( + "orFree failed for tracked pointer ", + ptr, + " with size ", + it->second, + " bytes on device ", + device_index_, + ". Return code: ", + ret, + ". Keeping tracking record - this may indicate a double-free or invalid pointer."); + } else { + TORCH_WARN( + "orFree failed for untracked pointer ", + ptr, + " on device ", + device_index_, + ". Return code: ", + ret, + ". This likely indicates a double-free or invalid pointer."); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats DeviceMemoryAllocator::getStats() { + std::lock_guard lock(mutex_); + return stats_; +} + +void DeviceMemoryAllocator::resetAccumulatedStats() { + std::lock_guard lock(mutex_); + + // Reset accumulated statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_accumulated(); + stats_.reserved_bytes[stat_type].reset_accumulated(); + stats_.active_bytes[stat_type].reset_accumulated(); + stats_.inactive_split_bytes[stat_type].reset_accumulated(); + stats_.requested_bytes[stat_type].reset_accumulated(); + } + + stats_.num_alloc_retries = 0; + stats_.num_ooms = 0; + stats_.num_sync_all_streams = 0; + stats_.num_device_alloc = 0; + stats_.num_device_free = 0; +} + +void DeviceMemoryAllocator::resetPeakStats() { + std::lock_guard lock(mutex_); + + // Reset peak statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_peak(); + stats_.reserved_bytes[stat_type].reset_peak(); + stats_.active_bytes[stat_type].reset_peak(); + stats_.inactive_split_bytes[stat_type].reset_peak(); + stats_.requested_bytes[stat_type].reset_peak(); + } + + stats_.oversize_allocations.reset_peak(); + stats_.oversize_segments.reset_peak(); +} + +namespace { + +OpenRegDeviceAllocator g_allocator; + +void deleteOpenRegMemory(void* ptr) { + g_allocator.freeMemory(ptr); +} + +} + +OpenRegDeviceAllocator::OpenRegDeviceAllocator() { + std::lock_guard lock(mutex_); + const auto device_count = c10::openreg::device_count(); + device_allocators_.resize(device_count); + for (const auto i : c10::irange(device_count)) { + device_allocators_[i] = std::make_unique(i); + } +} + + +at::DataPtr OpenRegDeviceAllocator::allocate(size_t nbytes) { + int current_device_index = -1; + auto ret = orGetDevice(¤t_device_index); + TORCH_CHECK(ret == orSuccess, "Failed to get current OpenReg device"); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + + void* data = nullptr; + if (nbytes > 0) { + // Allocate memory via device-specific allocator + data = device_allocators_[current_device_index]->malloc(nbytes); + + // Track which device owns this pointer + std::lock_guard lock(mutex_); + allocated_blocks_[data] = current_device_index; + } + + return {data, data, &deleteOpenRegMemory, curr_device}; +} + +at::DeleterFnPtr OpenRegDeviceAllocator::raw_deleter() const { + return &deleteOpenRegMemory; +} + +void OpenRegDeviceAllocator::copy_data( + void* dest, + const void* src, + std::size_t count) const { + auto ret = orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + TORCH_CHECK( + ret == orSuccess, "Failed to copy ", count, " bytes on openreg device"); +} + +bool OpenRegDeviceAllocator::initialized() { + std::lock_guard lock(mutex_); + return !device_allocators_.empty(); +} + +void OpenRegDeviceAllocator::freeMemory(void* ptr) { + if (!ptr) { + return; + } + + // Try to find which device owns this pointer + c10::DeviceIndex device_index = -1; + bool found_in_map = false; + + { + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(ptr); + if (it != allocated_blocks_.end()) { + device_index = it->second; + allocated_blocks_.erase(it); + found_in_map = true; + } + } + + if (found_in_map) { + // Pointer was tracked - free via device-specific allocator with stats + device_allocators_[device_index]->free(ptr); + } else { + // Pointer not tracked - might be already freed by storage or other path + // Try to free it directly via orFree without updating statistics + auto ret = orFree(ptr); + + // Only warn if orFree actually failed (not just "not found") + // In OpenReg's case, orFree returns orErrorUnknown if pointer not in registry + // which is expected for already-freed memory + if (ret != orSuccess && ret != orErrorUnknown) { + TORCH_WARN( + "orFree failed for untracked OpenReg memory pointer ", + ptr, + ". Error code: ", ret); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats OpenRegDeviceAllocator:: + getDeviceStats(c10::DeviceIndex device) { + return device_allocators_[device]->getStats(); +} + +void OpenRegDeviceAllocator::resetAccumulatedStats(c10::DeviceIndex device) { + device_allocators_[device]->resetAccumulatedStats(); +} + +void OpenRegDeviceAllocator::resetPeakStats(c10::DeviceIndex device) { + device_allocators_[device]->resetPeakStats(); +} + +void OpenRegDeviceAllocator::emptyCache(MempoolId_t mempool_id) { + // OpenReg doesn't implement caching yet + // TODO: When caching is implemented, release all free blocks here +} + +void OpenRegDeviceAllocator::recordStream( + const DataPtr& ptr, + c10::Stream stream) { + // OpenReg doesn't track stream usage yet + // TODO: When stream support is added, track which streams are using this pointer +} +// ============ Global Registration ============ + +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &g_allocator); } // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h index c9aea4a913427..777926e02b18c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h @@ -1,43 +1,78 @@ -#include +#pragma once #include +#include #include +#include #include +#include +#include +#include +#include + namespace c10::openreg { -struct OpenRegDeviceAllocator final : at::Allocator { - OpenRegDeviceAllocator() = default; - - static void ReportAndDelete(void* ptr) { - if (!ptr) { - return; - } - orFreeHost(ptr); - } - - at::DataPtr allocate(size_t nbytes) override { - int current_device_index = -1; - orGetDevice(¤t_device_index); - - auto curr_device = - c10::Device(c10::DeviceType::PrivateUse1, current_device_index); - void* data = nullptr; - if (nbytes > 0) { - orMalloc(&data, nbytes); - TORCH_CHECK( - data, "Failed to allocator ", nbytes, " bytes on openreg device."); - } - return {data, data, &ReportAndDelete, curr_device}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - orMemcpy(dest, src, count, orMemcpyDeviceToDevice); - } + +class DeviceMemoryAllocator { + public: + explicit DeviceMemoryAllocator(c10::DeviceIndex device_index); + + DeviceMemoryAllocator(const DeviceMemoryAllocator&) = delete; + DeviceMemoryAllocator& operator=(const DeviceMemoryAllocator&) = delete; + + void* malloc(size_t nbytes); + + void free(void* ptr); + + c10::CachingDeviceAllocator::DeviceStats getStats(); + + void resetAccumulatedStats(); + + void resetPeakStats(); + + private: + c10::DeviceIndex device_index_; + + c10::CachingDeviceAllocator::DeviceStats stats_; + + std::unordered_map allocation_sizes_; + + std::recursive_mutex mutex_; +}; + + +class OpenRegDeviceAllocator final : public c10::DeviceAllocator { + public: + OpenRegDeviceAllocator(); + + at::DataPtr allocate(size_t nbytes) override; + at::DeleterFnPtr raw_deleter() const override; + void copy_data(void* dest, const void* src, std::size_t count) const final; + + + bool initialized() override; + void emptyCache(MempoolId_t mempool_id = {0, 0}) override; + void recordStream(const DataPtr& ptr, c10::Stream stream) override; + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) override; + void resetAccumulatedStats(c10::DeviceIndex device) override; + void resetPeakStats(c10::DeviceIndex device) override; + + + void freeMemory(void* ptr); + + private: + + // Per-device allocators + std::vector> device_allocators_; + + // Global mapping from pointer to device index + std::recursive_mutex mutex_; + ska::flat_hash_map allocated_blocks_; }; -} // namespace c10::openreg + + + +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py index 3d67e16a0f503..b4a64eedc5bfc 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py @@ -1,9 +1,392 @@ # Owner(s): ["module: PrivateUse1"] +import gc +import time + import torch + +import torch_openreg # noqa: F401 from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +class TestDeviceAllocator(TestCase): + """Test cases for OpenRegDeviceAllocator functionality.""" + + def setUp(self): + """Reset memory state before each test.""" + # Force garbage collection to ensure clean state + gc.collect() + # Note: We can't directly reset allocator stats without C++ API, + # but we can ensure tensors are properly released + + def test_basic_allocation(self): + """Test basic memory allocation with various sizes.""" + # Small allocation + x = torch.empty(100, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), 100) + # Large allocation + z = torch.empty(10000, device="openreg") + self.assertEqual(z.device.type, "openreg") + self.assertEqual(z.numel(), 10000) + # Multi-dimensional allocation + w = torch.empty(10, 20, 30, device="openreg") + self.assertEqual(w.device.type, "openreg") + self.assertEqual(w.shape, torch.Size([10, 20, 30])) + + def test_memory_lifecycle(self): + """Test complete memory allocation and deallocation lifecycle.""" + # Allocate tensor + x = torch.empty(1000, device="openreg") + self.assertEqual(x.device.type, "openreg") + + # Explicitly delete tensor + del x + gc.collect() + + # Allocate again to ensure memory was freed + y = torch.empty(1000, device="openreg") + self.assertEqual(y.device.type, "openreg") + del y + gc.collect() + + def test_tensor_copy_operations(self): + """Test memory operations during tensor copies.""" + # CPU to OpenReg + cpu_tensor = torch.randn(100) + openreg_tensor = cpu_tensor.to("openreg") + self.assertEqual(openreg_tensor.device.type, "openreg") + self.assertEqual(cpu_tensor.shape, openreg_tensor.shape) + + # OpenReg to CPU + back_to_cpu = openreg_tensor.to("cpu") + self.assertEqual(back_to_cpu.device.type, "cpu") + self.assertTrue(torch.allclose(cpu_tensor, back_to_cpu)) + + # OpenReg to OpenReg (clone) + cloned = openreg_tensor.clone() + self.assertEqual(cloned.device.type, "openreg") + self.assertTrue(torch.allclose(openreg_tensor.cpu(), cloned.cpu())) + + def test_inplace_operations(self): + """Test memory stability during inplace operations.""" + x = torch.ones(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Inplace addition + x.add_(1) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 2)) + + # Inplace multiplication + x.mul_(2) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 4)) + + def test_view_operations(self): + """Test that views share memory correctly.""" + x = torch.randn(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Reshape view + y = x.view(10, 10) + self.assertEqual(y.data_ptr(), original_data_ptr) + self.assertEqual(y.shape, torch.Size([10, 10])) + + # Slice view + z = x[10:20] + # Slices may have different data_ptr but should share storage + self.assertEqual(z.numel(), 10) + + def test_different_dtypes(self): + """Test allocation with different data types.""" + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + x = torch.empty(100, dtype=dtype, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.dtype, dtype) + self.assertEqual(x.numel(), 100) + + def test_tensor_resize(self): + """Test tensor resizing operations.""" + x = torch.empty(100, device="openreg") + _ = x.data_ptr() + + # Resize to smaller size (should reuse storage) + x.resize_(50) + self.assertEqual(x.numel(), 50) + # Storage should still be available + + # Resize to original size + x.resize_(100) + self.assertEqual(x.numel(), 100) + + def test_empty_cache_operation(self): + """Test empty cache functionality.""" + # Allocate some tensors + x = torch.empty(1000, device="openreg") + y = torch.empty(2000, device="openreg") + + # Delete tensors + del x, y + gc.collect() + + # Note: OpenRegDeviceAllocator.emptyCache is currently a no-op + # This test ensures it doesn't crash + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + def test_memory_format_allocation(self): + """Test allocation with different memory formats.""" + # Channels last format + x = torch.empty(2, 3, 4, 4, device="openreg", memory_format=torch.channels_last) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) + + # Contiguous format (default) + y = torch.empty( + 2, 3, 4, 4, device="openreg", memory_format=torch.contiguous_format + ) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.is_contiguous()) + + def test_large_allocation(self): + """Test large memory allocation.""" + # Allocate a large tensor (10MB approximately) + size = 10 * 1024 * 1024 // 4 # 10MB in float32 + x = torch.empty(size, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), size) + + def test_sequential_allocations_and_deallocations(self): + """Test sequential allocation and deallocation patterns.""" + for i in range(10): + x = torch.empty(1000 + i * 100, device="openreg") + self.assertEqual(x.device.type, "openreg") + # Let tensor go out of scope + del x + gc.collect() + + def test_allocation_with_requires_grad(self): + """Test allocation of tensors with gradient tracking.""" + x = torch.empty(100, device="openreg", requires_grad=True) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.requires_grad) + + y = torch.randn(100, device="openreg", requires_grad=True) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.requires_grad) + + def test_storage_operations(self): + """Test storage-level operations.""" + x = torch.randn(100, device="openreg") + storage = x.storage() + + # Verify storage is on correct device + self.assertTrue(storage.device.type == "openreg") + + # Verify storage size + self.assertGreaterEqual(storage.size(), x.numel()) + + def test_tensor_from_blob(self): + """Test creating tensors that reference existing memory.""" + x = torch.randn(100, device="openreg") + + # Create a view that references the same data + y = x.view_as(x) + + # They should share the same underlying storage + self.assertEqual(x.data_ptr(), y.data_ptr()) + + # Modifying one should affect the other + x.fill_(5.0) + self.assertTrue(torch.all(y == 5.0)) + + +class TestMemoryLeaks(TestCase): + """Test cases for detecting memory leaks in OpenRegDeviceAllocator.""" + + def setUp(self): + """Reset memory state before each test.""" + gc.collect() + time.sleep(0.1) # Allow time for cleanup + + def test_no_leak_simple_allocations(self): + """Test that simple allocations don't leak memory.""" + # Warm-up + for _ in range(10): + x = torch.empty(1000, device="openreg") + del x + gc.collect() + time.sleep(0.1) + + # Perform many allocations and deallocations + iterations = 1000 + for i in range(iterations): + x = torch.empty(1000, device="openreg") + del x + + if i % 100 == 0: + gc.collect() + + # Final cleanup + gc.collect() + time.sleep(0.1) + + # If there were leaks, this would have accumulated significant memory + # The test passes if no exception/crash occurred + + def test_no_leak_varying_sizes(self): + """Test that allocations of varying sizes don't leak.""" + iterations = 500 + sizes = [100, 500, 1000, 5000, 10000] + + for i in range(iterations): + size = sizes[i % len(sizes)] + x = torch.empty(size, device="openreg") + del x + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_copies(self): + """Test that tensor copies don't leak memory.""" + iterations = 300 + + for i in range(iterations): + # Create tensor + x = torch.randn(500, device="openreg") + + # Copy to CPU + cpu_copy = x.cpu() + + # Copy back to device + device_copy = cpu_copy.to("openreg") + + # Clone + cloned = device_copy.clone() + + # Delete all + del x, cpu_copy, device_copy, cloned + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_views(self): + """Test that tensor views don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.randn(1000, device="openreg") + + # Create various views + view1 = x.view(10, 100) + view2 = x[100:200] + view3 = x.reshape(20, 50) + + # Delete views and original + del view1, view2, view3, x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_inplace_operations(self): + """Test that inplace operations don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.ones(1000, device="openreg") + + # Multiple inplace operations + x.add_(1) + x.mul_(2) + x.div_(2) + x.sub_(1) + + del x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_gradients(self): + """Test that tensors with gradients don't leak.""" + iterations = 300 + + for i in range(iterations): + x = torch.randn(100, device="openreg", requires_grad=True) + y = torch.randn(100, device="openreg", requires_grad=True) + + # Operation that creates computation graph + z = x + y + + # Delete all + del x, y, z + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_repeated_large_allocations(self): + """Test repeated large allocations for memory leaks.""" + # Large tensor size (50MB) + size = 50 * 1024 * 1024 // 4 + iterations = 50 + + for i in range(iterations): + x = torch.empty(size, device="openreg") + del x + gc.collect() + time.sleep(0.05) # Allow time for cleanup + + # Final cleanup + gc.collect() + time.sleep(0.1) + + def test_leak_detection_with_statistics(self): + """Test memory leak detection using allocation patterns.""" + # This test verifies that after many alloc/dealloc cycles, + # the allocator properly frees memory + + num_cycles = 10 + allocations_per_cycle = 100 + + for cycle in range(num_cycles): + tensors = [] + + # Allocate many tensors + for i in range(allocations_per_cycle): + t = torch.empty(1000, device="openreg") + tensors.append(t) + + # Verify all allocated + self.assertEqual(len(tensors), allocations_per_cycle) + + # Delete all + tensors.clear() + gc.collect() + time.sleep(0.05) + + # Final verification - if there were leaks, memory would be exhausted + # The test passes if we can still allocate + final_tensor = torch.empty(10000, device="openreg") + self.assertEqual(final_tensor.device.type, "openreg") + del final_tensor + + class TestPinMemory(TestCase): @skipIfTorchDynamo("unsupported aten.is_pinned.default") def test_pin_memory(self): @@ -27,5 +410,110 @@ def test_pin_memory(self): self.assertTrue(pinned_untyped_storage.is_pinned("openreg")) +class TestMultiDeviceAllocation(TestCase): + """Test basic multi-device allocation functionality.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + self.assertEqual(self.device_count, 2, "This test requires 2 OpenReg devices") + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_allocation_on_device_1(self): + torch.openreg.set_device(1) + x = torch.empty(100, device="openreg:1") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.device.index, 1) + + def test_simultaneous_device_allocations(self): + """Test allocations on both devices simultaneously.""" + x = torch.empty(100, device="openreg:0") + y = torch.empty(200, device="openreg:1") + + self.assertEqual(x.device.index, 0) + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + def test_memory_isolation_between_devices(self): + """Test that memory allocations are isolated between devices.""" + + tensors_dev0 = [torch.empty(1000, device="openreg:0") for _ in range(10)] + tensors_dev1 = [torch.empty(1000, device="openreg:1") for _ in range(10)] + + # Verify all device 0 tensors are on device 0 + for t in tensors_dev0: + self.assertEqual(t.device.index, 0) + + # Verify all device 1 tensors are on device 1 + for t in tensors_dev1: + self.assertEqual(t.device.index, 1) + + # Pointers should be different + ptrs_dev0 = {t.data_ptr() for t in tensors_dev0} + ptrs_dev1 = {t.data_ptr() for t in tensors_dev1} + self.assertEqual( + len(ptrs_dev0 & ptrs_dev1), 0, "Devices should not share pointers" + ) + + def test_alternating_device_allocations(self): + """Test alternating allocations between devices.""" + tensors = [] + for i in range(20): + device_idx = i % 2 + t = torch.empty(100 + i, device=f"openreg:{device_idx}") + self.assertEqual(t.device.index, device_idx) + tensors.append(t) + + # Verify all tensors retained correct device assignment + for i, t in enumerate(tensors): + expected_device = i % 2 + self.assertEqual(t.device.index, expected_device) + + +class TestCrossDeviceOperations(TestCase): + """Test cross-device tensor operations.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + self.assertEqual(self.device_count, 2) + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_tensor_to_different_device(self): + """Test moving tensor from one device to another.""" + # Create on device 0 + x = torch.randn(100, device="openreg:0") + self.assertEqual(x.device.index, 0) + + # Move to device 1 + y = x.to("openreg:1") + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + # Values should be the same + self.assertTrue(torch.allclose(x.cpu(), y.cpu())) + + def test_bidirectional_device_transfer(self): + """Test transferring tensor back and forth between devices.""" + original = torch.randn(100, device="openreg:0") + original_cpu = original.cpu() + + # 0 -> 1 + on_dev1 = original.to("openreg:1") + self.assertTrue(torch.allclose(original_cpu, on_dev1.cpu())) + + # 1 -> 0 + back_to_dev0 = on_dev1.to("openreg:0") + self.assertTrue(torch.allclose(original_cpu, back_to_dev0.cpu())) + + if __name__ == "__main__": run_tests() From 8f4dc304534529c6abf9da0b7154d49fb907d167 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 18 Nov 2025 23:24:54 +0000 Subject: [PATCH 041/230] Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496) Fixes https://github.com/pytorch/pytorch/issues/161660 This extends the `TORCH_STABLE_ONLY` stopgap added in https://github.com/pytorch/pytorch/pull/161658 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167496 Approved by: https://github.com/janeyx99, https://github.com/malfet --- .../smoke_test/check_binary_symbols.py | 338 ++++++++++++++++++ setup.py | 47 +++ .../libtorch_agnostic_2_10_extension/setup.py | 1 - .../torch_stable_test_extension/setup.py | 67 ---- .../torch_stable_test/__init__.py | 0 .../torch_stable_test/csrc/test_extension.cpp | 1 - .../torch_stable_test/test_torch_stable.py | 22 -- torch/csrc/inductor/aoti_torch/c/shim.h | 6 +- 8 files changed, 388 insertions(+), 94 deletions(-) delete mode 100644 test/cpp_extensions/torch_stable_test_extension/setup.py delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp delete mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index b0c607659c72d..51d5174e77912 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -100,6 +100,337 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: ) +def _compile_and_extract_symbols( + cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None +) -> list[str]: + """ + Helper to compile a C++ file and extract all symbols. + + Args: + cpp_content: C++ source code to compile + compile_flags: Compilation flags + exclude_list: List of symbol names to exclude. Defaults to ["main"]. + + Returns: + List of all symbols found in the object file (excluding those in exclude_list). + """ + import subprocess + import tempfile + + if exclude_list is None: + exclude_list = ["main"] + + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode != 0: + raise RuntimeError(f"Compilation failed: {result.stderr}") + + symbols = get_symbols(str(obj_file)) + + # Return all symbol names, excluding those in the exclude list + return [name for _addr, _stype, name in symbols if name not in exclude_list] + + +def check_stable_only_symbols(install_root: Path) -> None: + """ + Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts. + + This approach tests: + 1. WITHOUT macros -> many torch symbols exposed + 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden) + 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden) + 4. WITH both macros -> zero torch symbols (all hidden) + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + test_cpp_content = """ +// Main torch C++ API headers +#include +#include + +// ATen tensor library +#include + +// Core c10 headers (commonly used) +#include +#include +#include +#include +#include + +int main() { return 0; } +""" + + base_compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", # Compile only, don't link + ] + + # Compile WITHOUT any macros + symbols_without = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=base_compile_flags, + ) + + # We expect constexpr symbols, inline functions used by other headers etc. + # to produce symbols + num_symbols_without = len(symbols_without) + print(f"Found {num_symbols_without} symbols without any macros defined") + assert num_symbols_without != 0, ( + "Expected a non-zero number of symbols without any macros" + ) + + # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols) + compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] + + symbols_with_stable_only = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_stable_only, + ) + + num_symbols_with_stable_only = len(symbols_with_stable_only) + assert num_symbols_with_stable_only == 0, ( + f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}" + ) + + # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols) + compile_flags_with_target_version = base_compile_flags + [ + "-DTORCH_TARGET_VERSION=1" + ] + + symbols_with_target_version = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_target_version, + ) + + num_symbols_with_target_version = len(symbols_with_target_version) + assert num_symbols_with_target_version == 0, ( + f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}" + ) + + # Compile WITH both macros (expect 0 symbols) + compile_flags_with_both = base_compile_flags + [ + "-DTORCH_STABLE_ONLY", + "-DTORCH_TARGET_VERSION=1", + ] + + symbols_with_both = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_both, + ) + + num_symbols_with_both = len(symbols_with_both) + assert num_symbols_with_both == 0, ( + f"Expected no symbols with both macros, but found {num_symbols_with_both}" + ) + + +def check_stable_api_symbols(install_root: Path) -> None: + """ + Test that stable API headers still expose symbols with TORCH_STABLE_ONLY. + The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + stable_dir = include_dir / "torch" / "csrc" / "stable" + assert stable_dir.exists(), f"Expected {stable_dir} to be present" + + stable_headers = list(stable_dir.rglob("*.h")) + if not stable_headers: + raise RuntimeError("Could not find any stable headers") + + includes = [] + for header in stable_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_stable_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable = _compile_and_extract_symbols( + cpp_content=test_stable_content, + compile_flags=compile_flags, + ) + num_symbols_stable = len(symbols_stable) + print(f"Found {num_symbols_stable} symbols in torch/csrc/stable") + assert num_symbols_stable > 0, ( + f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable} symbols" + ) + + +def check_headeronly_symbols(install_root: Path) -> None: + """ + Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Find all headers in torch/headeronly + headeronly_dir = include_dir / "torch" / "headeronly" + assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present" + headeronly_headers = list(headeronly_dir.rglob("*.h")) + if not headeronly_headers: + raise RuntimeError("Could not find any headeronly headers") + + # Filter out platform-specific headers that may not compile everywhere + platform_specific_keywords = [ + "cpu/vec", + ] + + filtered_headers = [] + for header in headeronly_headers: + rel_path = header.relative_to(include_dir).as_posix() + if not any( + keyword in rel_path.lower() for keyword in platform_specific_keywords + ): + filtered_headers.append(header) + + includes = [] + for header in filtered_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_headeronly_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_headeronly = _compile_and_extract_symbols( + cpp_content=test_headeronly_content, + compile_flags=compile_flags, + ) + num_symbols_headeronly = len(symbols_headeronly) + print(f"Found {num_symbols_headeronly} symbols in torch/headeronly") + assert num_symbols_headeronly > 0, ( + f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_headeronly} symbols" + ) + + +def check_aoti_shim_symbols(install_root: Path) -> None: + """ + Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. + test_shim_content = """ +#include +int main() { + int32_t (*fp1)() = &aoti_torch_device_type_cpu; + int32_t (*fp2)() = &aoti_torch_dtype_float32; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_shim = _compile_and_extract_symbols( + cpp_content=test_shim_content, + compile_flags=compile_flags, + ) + num_symbols_shim = len(symbols_shim) + assert num_symbols_shim > 0, ( + f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_shim} symbols" + ) + + +def check_stable_c_shim_symbols(install_root: Path) -> None: + """ + Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Check if the stable C shim exists + stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h" + if not stable_shim.exists(): + raise RuntimeError("Could not find stable c shim") + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. + test_stable_shim_content = """ +#include +int main() { + // Reference stable C API functions to create undefined symbols + AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string; + AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable_shim = _compile_and_extract_symbols( + cpp_content=test_stable_shim_content, + compile_flags=compile_flags, + ) + num_symbols_stable_shim = len(symbols_stable_shim) + assert num_symbols_stable_shim > 0, ( + f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable_shim} symbols" + ) + + def check_lib_symbols_for_abi_correctness(lib: str) -> None: print(f"lib: {lib}") cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS) @@ -129,6 +460,13 @@ def main() -> None: check_lib_symbols_for_abi_correctness(libtorch_cpu_path) check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path) + # Check symbols when TORCH_STABLE_ONLY is defined + check_stable_only_symbols(install_root) + check_stable_api_symbols(install_root) + check_headeronly_symbols(install_root) + check_aoti_shim_symbols(install_root) + check_stable_c_shim_symbols(install_root) + if __name__ == "__main__": main() diff --git a/setup.py b/setup.py index 314f719ea67f0..ef584cefdd6dd 100644 --- a/setup.py +++ b/setup.py @@ -1358,6 +1358,45 @@ def __exit__(self, *exc_info: object) -> None: # Need to create the proper LICENSE.txt for the wheel class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel): + def _wrap_headers_with_macro(self, bdist_dir: Path) -> None: + """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). + + Excludes: + - torch/include/torch/headeronly/* + - torch/include/torch/csrc/stable/* + - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers) + - torch/include/torch/csrc/inductor/aoti_torch/generated/ + """ + header_extensions = (".h", ".hpp", ".cuh") + header_files = [ + f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}") + ] + + # Paths to exclude from wrapping + exclude_dir_patterns = [ + "torch/include/torch/headeronly/", + "torch/include/torch/csrc/stable/", + "torch/include/torch/csrc/inductor/aoti_torch/c/", + "torch/include/torch/csrc/inductor/aoti_torch/generated/", + ] + + for header_file in header_files: + rel_path = header_file.relative_to(bdist_dir).as_posix() + + if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): + report(f"Skipping header: {rel_path}") + continue + + original_content = header_file.read_text(encoding="utf-8") + wrapped_content = ( + "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + f"{original_content}" + "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + header_file.write_text(wrapped_content, encoding="utf-8") + report(f"Wrapped header: {rel_path}") + def run(self) -> None: with concat_license_files(include_files=True): super().run() @@ -1380,6 +1419,14 @@ def write_wheelfile(self, *args: Any, **kwargs: Any) -> None: # need an __init__.py file otherwise we wouldn't have a package (bdist_dir / "torch" / "__init__.py").touch() + # Wrap all header files with TORCH_STABLE_ONLY macro + assert self.bdist_dir is not None, "bdist_dir should be set during wheel build" + bdist_dir = Path(self.bdist_dir) + report( + "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" + ) + self._wrap_headers_with_macro(bdist_dir) + class clean(Command): user_options: ClassVar[list[tuple[str, str | None, str]]] = [] diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index ff2aeff5e932b..405944bc0f9bf 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -35,7 +35,6 @@ def get_extension(): extra_compile_args = { "cxx": [ "-fdiagnostics-color=always", - "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=0x020a000000000000", ], } diff --git a/test/cpp_extensions/torch_stable_test_extension/setup.py b/test/cpp_extensions/torch_stable_test_extension/setup.py deleted file mode 100644 index 062d466e7ae98..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/setup.py +++ /dev/null @@ -1,67 +0,0 @@ -import distutils.command.clean -import shutil -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -ROOT_DIR = Path(__file__).parent -CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove extension - for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"): - path.unlink() - # Remove build and dist and egg-info directories - dirs = [ - ROOT_DIR / "build", - ROOT_DIR / "dist", - ROOT_DIR / "torch_stable_test.egg-info", - ] - for path in dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -def get_extension(): - extra_compile_args = { - "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"], - } - - sources = list(CSRC_DIR.glob("**/*.cpp")) - - return [ - CppExtension( - "torch_stable_test._C", - sources=sorted(str(s) for s in sources), - py_limited_api=True, - extra_compile_args=extra_compile_args, - extra_link_args=[], - ) - ] - - -setup( - name="torch_stable_test", - version="0.0", - author="PyTorch Core Team", - description="Test extension to verify TORCH_STABLE_ONLY flag", - packages=find_packages(exclude=("test",)), - package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=get_extension(), - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -) diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp deleted file mode 100644 index c92d56da11ba3..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp +++ /dev/null @@ -1 +0,0 @@ -#include // This should trigger the TORCH_STABLE_ONLY error diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py deleted file mode 100644 index 5c5613bb5484e..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py +++ /dev/null @@ -1,22 +0,0 @@ -# Owner(s): ["module: cpp"] - -from pathlib import Path - -from torch.testing._internal.common_utils import ( - install_cpp_extension, - IS_WINDOWS, - run_tests, - TestCase, -) - - -if not IS_WINDOWS: - - class TestTorchStable(TestCase): - def test_setup_fails(self): - with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"): - install_cpp_extension(extension_root=Path(__file__).parent.parent) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 4fb746ea15271..2eda2b218e705 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -38,9 +38,9 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check -#include -#include -#include +#include +#include +#include #ifdef __cplusplus extern "C" { From a0ccd3e5ffacf2e3b44718008ab04ec47d51d7b1 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Tue, 18 Nov 2025 23:24:54 +0000 Subject: [PATCH 042/230] Error when non stable/headeronly/shim headers are included by stable extension (#167855) Address Nikita's offline comment on #167496 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167855 Approved by: https://github.com/janeyx99 ghstack dependencies: #167496 --- .../smoke_test/check_binary_symbols.py | 86 ++++++++------ setup.py | 110 ++++++++++-------- 2 files changed, 111 insertions(+), 85 deletions(-) diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index 51d5174e77912..7ad10ca946215 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -145,14 +145,17 @@ def _compile_and_extract_symbols( def check_stable_only_symbols(install_root: Path) -> None: """ - Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts. + Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code. This approach tests: - 1. WITHOUT macros -> many torch symbols exposed - 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden) - 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden) - 4. WITH both macros -> zero torch symbols (all hidden) + 1. WITHOUT macros -> many torch symbols exposed (compilation succeeds) + 2. WITH TORCH_STABLE_ONLY -> compilation fails with #error directive + 3. WITH TORCH_TARGET_VERSION -> compilation fails with #error directive + 4. WITH both macros -> compilation fails with #error directive """ + import subprocess + import tempfile + include_dir = install_root / "include" assert include_dir.exists(), f"Expected {include_dir} to be present" @@ -182,7 +185,7 @@ def check_stable_only_symbols(install_root: Path) -> None: "-c", # Compile only, don't link ] - # Compile WITHOUT any macros + # Compile WITHOUT any macros - should succeed symbols_without = _compile_and_extract_symbols( cpp_content=test_cpp_content, compile_flags=base_compile_flags, @@ -196,49 +199,56 @@ def check_stable_only_symbols(install_root: Path) -> None: "Expected a non-zero number of symbols without any macros" ) - # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols) - compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] + # Helper to verify compilation fails with expected error + def _expect_compilation_failure(compile_flags: list[str], macro_name: str) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(test_cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode == 0: + raise RuntimeError( + f"Expected compilation to fail with {macro_name} defined, but it succeeded" + ) + + stderr = result.stderr + expected_error_msg = ( + "This file should not be included when either TORCH_STABLE_ONLY " + "or TORCH_TARGET_VERSION is defined." + ) + + if expected_error_msg not in stderr: + raise RuntimeError( + f"Expected error message to contain:\n '{expected_error_msg}'\n" + f"but got:\n{stderr[:1000]}" + ) + + print(f"Compilation correctly failed with {macro_name} defined") - symbols_with_stable_only = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_stable_only, - ) - - num_symbols_with_stable_only = len(symbols_with_stable_only) - assert num_symbols_with_stable_only == 0, ( - f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}" - ) + compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] + _expect_compilation_failure(compile_flags_with_stable_only, "TORCH_STABLE_ONLY") - # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols) compile_flags_with_target_version = base_compile_flags + [ "-DTORCH_TARGET_VERSION=1" ] - - symbols_with_target_version = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_target_version, + _expect_compilation_failure( + compile_flags_with_target_version, "TORCH_TARGET_VERSION" ) - num_symbols_with_target_version = len(symbols_with_target_version) - assert num_symbols_with_target_version == 0, ( - f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}" - ) - - # Compile WITH both macros (expect 0 symbols) compile_flags_with_both = base_compile_flags + [ "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=1", ] - - symbols_with_both = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_both, - ) - - num_symbols_with_both = len(symbols_with_both) - assert num_symbols_with_both == 0, ( - f"Expected no symbols with both macros, but found {num_symbols_with_both}" - ) + _expect_compilation_failure(compile_flags_with_both, "both macros") def check_stable_api_symbols(install_root: Path) -> None: diff --git a/setup.py b/setup.py index ef584cefdd6dd..f15e7bbdd0ac4 100644 --- a/setup.py +++ b/setup.py @@ -1089,6 +1089,60 @@ def check_pydep(importname: str, module: str) -> None: class build_ext(setuptools.command.build_ext.build_ext): + def _wrap_headers_with_macro(self, include_dir: Path) -> None: + """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). + + Excludes: + - torch/headeronly/* + - torch/csrc/stable/* + - torch/csrc/inductor/aoti_torch/c/ (only shim headers) + - torch/csrc/inductor/aoti_torch/generated/ + + This method is idempotent - it will not wrap headers that are already wrapped. + """ + header_extensions = (".h", ".hpp", ".cuh") + header_files = [ + f for ext in header_extensions for f in include_dir.rglob(f"*{ext}") + ] + + # Paths to exclude from wrapping (relative to include_dir) + exclude_dir_patterns = [ + "torch/headeronly/", + "torch/csrc/stable/", + "torch/csrc/inductor/aoti_torch/c/", + "torch/csrc/inductor/aoti_torch/generated/", + ] + + # Marker to detect if a header is already wrapped + wrap_start_marker = ( + "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + for header_file in header_files: + rel_path = header_file.relative_to(include_dir).as_posix() + + if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): + report(f"Skipping header: {rel_path}") + continue + + original_content = header_file.read_text(encoding="utf-8") + + # Check if already wrapped (idempotency check) + if original_content.startswith(wrap_start_marker): + report(f"Already wrapped, skipping: {rel_path}") + continue + + wrapped_content = ( + wrap_start_marker + + f"{original_content}" + + "\n#else\n" + + '#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n' + + "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + header_file.write_text(wrapped_content, encoding="utf-8") + report(f"Wrapped header: {rel_path}") + def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS build_lib = Path(self.build_lib) @@ -1256,6 +1310,15 @@ def run(self) -> None: super().run() + # Wrap headers with TORCH_STABLE_ONLY and TORCH_TARGET_VERSION guards + build_lib = Path(self.build_lib) + build_torch_include_dir = build_lib / "torch" / "include" + if build_torch_include_dir.exists(): + report( + "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" + ) + self._wrap_headers_with_macro(build_torch_include_dir) + if IS_DARWIN: self._embed_libomp() @@ -1358,45 +1421,6 @@ def __exit__(self, *exc_info: object) -> None: # Need to create the proper LICENSE.txt for the wheel class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel): - def _wrap_headers_with_macro(self, bdist_dir: Path) -> None: - """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). - - Excludes: - - torch/include/torch/headeronly/* - - torch/include/torch/csrc/stable/* - - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers) - - torch/include/torch/csrc/inductor/aoti_torch/generated/ - """ - header_extensions = (".h", ".hpp", ".cuh") - header_files = [ - f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}") - ] - - # Paths to exclude from wrapping - exclude_dir_patterns = [ - "torch/include/torch/headeronly/", - "torch/include/torch/csrc/stable/", - "torch/include/torch/csrc/inductor/aoti_torch/c/", - "torch/include/torch/csrc/inductor/aoti_torch/generated/", - ] - - for header_file in header_files: - rel_path = header_file.relative_to(bdist_dir).as_posix() - - if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): - report(f"Skipping header: {rel_path}") - continue - - original_content = header_file.read_text(encoding="utf-8") - wrapped_content = ( - "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - f"{original_content}" - "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - ) - - header_file.write_text(wrapped_content, encoding="utf-8") - report(f"Wrapped header: {rel_path}") - def run(self) -> None: with concat_license_files(include_files=True): super().run() @@ -1419,14 +1443,6 @@ def write_wheelfile(self, *args: Any, **kwargs: Any) -> None: # need an __init__.py file otherwise we wouldn't have a package (bdist_dir / "torch" / "__init__.py").touch() - # Wrap all header files with TORCH_STABLE_ONLY macro - assert self.bdist_dir is not None, "bdist_dir should be set during wheel build" - bdist_dir = Path(self.bdist_dir) - report( - "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" - ) - self._wrap_headers_with_macro(bdist_dir) - class clean(Command): user_options: ClassVar[list[tuple[str, str | None, str]]] = [] From 5abb7bf8fee800e92028e57ebbb41e2e9f62d499 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Nov 2025 14:26:53 +0000 Subject: [PATCH 043/230] Revert "[SymmMem] Skip multicast init if any CUDA call fails (#168049)" This reverts commit 8cb8b6cbbdbfc790be2921c768dab403157671ef. Reverted https://github.com/pytorch/pytorch/pull/168049 on behalf of https://github.com/yangw-dev due to D87346992 internal error that conflict the main branch, please rebase and try to merge again These changes have conflicts when merging with master branch. Rebase this diff. ([comment](https://github.com/pytorch/pytorch/pull/168049#issuecomment-3552985895)) --- c10/cuda/driver_api.h | 16 --- .../c10d/symm_mem/CUDASymmetricMemory.cu | 115 ++++++++---------- 2 files changed, 48 insertions(+), 83 deletions(-) diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 1ff0c9a12ac78..380e7939ff76c 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,22 +20,6 @@ } \ } while (0) -#define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT) \ - do { \ - CUresult __err = EXPR; \ - if (__err != CUDA_SUCCESS) { \ - const char* err_str; \ - CUresult get_error_str_err [[maybe_unused]] = \ - c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ - if (get_error_str_err != CUDA_SUCCESS) { \ - TORCH_WARN("CUDA driver error: unknown error"); \ - } else { \ - TORCH_WARN("CUDA driver error: ", err_str); \ - } \ - goto NEXT; \ - } \ - } while (0) - // The integer in the second column specifies the requested CUDA Driver API // version. The dynamic loader will accept a driver with a newer version, but it // ensures that the requested symbol exists in *at least* the specified version diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index f83d42df4ac68..6352330c3872c 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -517,11 +517,6 @@ static void init_multicast_for_block( using McHandleType = std::conditional_t; - McHandleType invalidator; - std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType)); - - // Phase 1: export handle (rank 0 only) - McHandleType mc_exported_handle{}; if (rank == 0) { CUmulticastObjectProp mc_prop{}; mc_prop.numDevices = world_size; @@ -530,82 +525,68 @@ static void init_multicast_for_block( // create a multicast object, which acts as a handle that allows multiple // devices or processes to access the same memory allocation coherently. - try { - C10_CUDA_DRIVER_CHECK( - driver_api->cuMulticastCreate_(&mc_handle, &mc_prop)); - // using the CUDA Driver API to export a multicast object into a POSIX file - // descriptor. - C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &mc_exported_handle, mc_handle, handleType, 0)); - } catch (const std::exception& e) { - // Allow peers gracefully skip multicast initialization by sending -1 - mc_exported_handle = invalidator; + auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop); + if (err != CUDA_SUCCESS) { + const char* err_str; + CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str); + if (get_error_str_err != CUDA_SUCCESS) { + err_str = "unknown cuda driver error"; + } LOG(WARNING) - << "SymmetricMemory: fail to export multicast handle.\n" - << e.what(); + << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str + << "\". Gracefully skipping multicast initialization. " + << "However, this is unexpected. Please report the issue on GitHub."; + // Allow peers gracefully skip multicast initialization by sending -1 + // TODO: allow graceful skip for fabric + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, -1); + } + return; } - } - - // Phase 2: Exchange handle - McHandleType recv_handle; - if constexpr (!use_fabric_handle) { - recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); - } else { - // TODO implement storeExchange.broadcast - auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle); - recv_handle = std::move(gathered_handles[0]); - } - - // Check exchange result - if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } - // Flip to true after all CUDA steps finish - bool success_end = false; + McHandleType mc_exported_handle; + // using the CUDA Driver API to export a multicast object into a POSIX file + // descriptor. + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &mc_exported_handle, mc_handle, handleType, 0)); + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); + // Ref count is incremented as soon as SCM_RIGHTS send happens + close(mc_exported_handle); + } else { + // TODO implement storeExchange.broadcast + storeExchange.all_gather(store, rank, world_size, mc_exported_handle); + } - // Phase 3: Import handle (non-0 ranks only) - if (rank != 0) { + } else { if constexpr (!use_fabric_handle) { + int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1); + if (mc_fd == -1) { + return; + } // Convert back to a handle from the broadcasted POSIX file descriptor. - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( &mc_handle, - (void*)(uintptr_t)recv_handle, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all); + (void*)(uintptr_t)mc_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(mc_fd); } else { - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( - &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all); + CUmemFabricHandle null_handle{}; + auto mc_handles = + storeExchange.all_gather(store, rank, world_size, null_handle); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC)); } } - // Phase 4: Bind memory // All rank adds their physical allocation to the multicast object - C10_CUDA_DRIVER_CHECK_GOTO( - driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all); - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_( - mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all); - - success_end = true; - -check_all: - // Whether all ranks have succeeded - bool all_succeed = true; - auto rank_successes = storeExchange.all_gather(store, rank, world_size, success_end); - for (int r = 0; r < world_size; ++r) { - all_succeed &= rank_successes[r]; - } - // Close the file descriptor before exit - if constexpr (!use_fabric_handle) { - close(recv_handle); - } - if (!all_succeed) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_( + mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0)); - // Phase 5: Map to virtual memory map_block(&mc_addr, mc_handle, block->block_size, block->device_idx); + storeExchange.barrier(store, rank, world_size); #endif } From c7cf3fb12504a3b7f40b2543bf4d511d64f29a11 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Nov 2025 14:33:26 +0000 Subject: [PATCH 044/230] Revert "[pytree][compile] Slightly faster TreeSpec init (#168024)" This reverts commit db1551bafa129c1eea4ece6c51e2a1f604596e43. Reverted https://github.com/pytorch/pytorch/pull/168024 on behalf of https://github.com/yangw-dev due to Internal merge fail, These changes have conflicts when merging with master branch. Rebase this diff. please rebase the pr and try merge again ([comment](https://github.com/pytorch/pytorch/pull/168024#issuecomment-3553015987)) --- torch/_dynamo/polyfills/pytree.py | 7 ++----- torch/utils/_pytree.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 63a72afa43a6d..1c6283e8a038f 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -201,11 +201,8 @@ def __post_init__(self, /) -> None: num_children = 0 else: assert callable(self._unflatten_func) - num_nodes = 1 - num_leaves = 0 - for child in self._children: - num_nodes += child.num_nodes - num_leaves += child.num_leaves + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 16877719718af..3d2e4d110b6b2 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1113,11 +1113,8 @@ def __post_init__(self) -> None: num_leaves = 1 num_children = 0 else: - num_nodes = 1 - num_leaves = 0 - for child in self._children: - num_nodes += child.num_nodes - num_leaves += child.num_leaves + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) From eefc0f87001327f375321159cdc28744154ec1db Mon Sep 17 00:00:00 2001 From: albanD Date: Wed, 19 Nov 2025 15:07:25 +0000 Subject: [PATCH 045/230] Fix link for core maintainers request form (#168089) As per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/168089 Approved by: https://github.com/ezyang, https://github.com/svekars --- docs/source/community/governance.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/community/governance.rst b/docs/source/community/governance.rst index cea24593dca83..ebfadf4e0f69b 100644 --- a/docs/source/community/governance.rst +++ b/docs/source/community/governance.rst @@ -132,7 +132,7 @@ The Process for Nomination * Each module has its own process. Please contact module maintainers for more information. However, if there is no process identified, you can file a request to the core - maintainers by submitting `this form `__. + maintainers by submitting `this form `__. Core maintainers are meeting every three months. * If you are submitting a request to the core maintainers, the information in your request must include the following items: From 962f13f9a54ae2f5df80434b2803a01957050328 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 19 Nov 2025 00:22:56 -0800 Subject: [PATCH 046/230] [compile][to_local] Support Sequence-like placement user defined objects (#168149) grad_placements is a sequence like data structure and therefore can be a UserDefinedObject. In that case, we can extract the tuple and pass along. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168149 Approved by: https://github.com/bdhirsh --- .../tensor/test_dtensor_compile.py | 79 +++++++++++++++++++ torch/_dynamo/variables/tensor.py | 13 +++ 2 files changed, 92 insertions(+) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 22493c4451d63..e58b6dda658f3 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -63,6 +63,54 @@ dev_type = torch.device(get_devtype()) +class PytreeTuple: + """ + Tuple-like values that are treated as leaves of a PyTree. + """ + + def __init__(self, *values): + self._values = tuple(values) + + def __repr__(self): + pr = repr(self._values)[1:-1] + return f"{type(self).__name__}({pr})" + + def __getitem__(self, i): + return self._values[i] + + def __iter__(self): + return iter(self._values) + + def __len__(self): + return len(self._values) + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self._values == other._values + elif isinstance(other, tuple): + return self._values == other + return False + + def __hash__(self) -> int: + return hash(self._values) + + def __add__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*self, *other) + raise NotImplementedError(type(other)) + + def __radd__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*other, *self) + raise NotImplementedError(type(other)) + + def index(self, value): + return self._values.index(value) + + def count(self, value): + return self._values.count(value) + + class SimpleModel(nn.Module): def __init__(self, device): super().__init__() @@ -767,6 +815,37 @@ def fn(x): # this fails with an inductor stride assert out_dt.to_local().sum().backward() + def test_dynamo_to_local_grad_placements_sequence(self): + placements = PytreeTuple([Shard(0)]) + + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + + def test_dynamo_to_local_grad_placements_sequence_intermediate(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + placements = PytreeTuple([Shard(0)]) + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + def test_dynamo_to_local_kwargs(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 326178ef00874..16fa0997c7f83 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1266,6 +1266,19 @@ def method_to_local(self, *args, **kwargs): tx = InstructionTranslator.current_tx() # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function + + grad_placements_vt = kwargs.get( + "grad_placements", ConstantVariable.create(None) + ) + if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable): + # grad_placement is a sequence-like structure, iterate over the value + grad_placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [grad_placements_vt], {} + ) + + if kwargs.get("grad_placements") is not None: + kwargs["grad_placements"] = grad_placements_vt + args_as_value = [x.as_python_constant() for x in args] kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} From fb6af11af9c762a18810a9d8ad00ffb3bb257c07 Mon Sep 17 00:00:00 2001 From: Abhishek Nandy Date: Wed, 19 Nov 2025 15:58:29 +0000 Subject: [PATCH 047/230] GroupNorm: include offending values in error message; add test (#167925) # Description Improves the error message in `GroupNorm` to include the actual values of `num_channels` and `num_groups` when validation fails. ## Problem The current error message doesn't show the actual values that caused the error, making debugging harder, especially when values come from variables or calculations. **Before:** ```python ValueError: num_channels must be divisible by num_groups ``` **After:** ```python ValueError: num_channels (10) must be divisible by num_groups (3) ``` ## Solution Include the actual values in the error message using f-string formatting. ## Changes Made - `torch/nn/modules/normalization.py`: Updated error message to include `num_channels` and `num_groups` values - `test/test_nn.py`: Added test to verify error message includes values ## Testing - Manual verification confirms error message includes actual values - New test `test_GroupNorm_error_message_includes_values` verifies the fix - All existing GroupNorm tests continue to pass - Error message is now consistent with other PyTorch error messages Tested it locally all tests are passing ## Example ```python import torch.nn as nn # Before: ValueError: num_channels must be divisible by num_groups # After: ValueError: num_channels (10) must be divisible by num_groups (3) try: model = nn.GroupNorm(num_groups=3, num_channels=10) except ValueError as e: print(e) ``` # Solution Include the actual values in the error message using f-string formatting. ## Changes Made - `torch/nn/modules/normalization.py`: Updated error message to include `num_channels` and `num_groups` values - `test/test_nn.py`: Added test `test_GroupNorm_error_message_includes_values` to verify error message includes values ## Testing - Manual verification confirms error message includes actual values - New test `test_GroupNorm_error_message_includes_values` verifies the fix - All existing GroupNorm tests continue to pass (6/6 tests passing) - Error message is now consistent with other PyTorch error messages ## Change This is a small, focused improvement that enhances error messages without changing any functionality.Thanks for reviewing ,open to any changes if required. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167925 Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet --- torch/nn/modules/normalization.py | 4 +++- torch/testing/_internal/common_modules.py | 27 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 60bd561bfd0e4..4a7302d5cae33 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -301,7 +301,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if num_channels % num_groups != 0: - raise ValueError("num_channels must be divisible by num_groups") + raise ValueError( + f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})" + ) self.num_groups = num_groups self.num_channels = num_channels diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 9571cc1209ed6..83fca0b973856 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1769,6 +1769,32 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, ] +def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): + """ + Error inputs for GroupNorm that test error messages include actual values. + """ + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(3, 10), # num_groups=3, num_channels=10 + forward_input=FunctionInput(), # Not needed for construction error + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex=r"num_channels \(10\) must be divisible by num_groups \(3\)" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(5, 13), # num_groups=5, num_channels=13 + forward_input=FunctionInput(), # Not needed for construction error + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex=r"num_channels \(13\) must be divisible by num_groups \(5\)" + ), + ] + + def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3958,6 +3984,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.GroupNorm, module_inputs_func=module_inputs_torch_nn_GroupNorm, + module_error_inputs_func=module_error_inputs_torch_nn_GroupNorm, dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True), skips=( # Tracking at https://github.com/pytorch/pytorch/issues/98089 From 0d7ba9714ac77b2b4a446a9eff913a6ff9dfc782 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Nov 2025 22:11:30 -0800 Subject: [PATCH 048/230] [dynamo][compile time] Special case for torch.utils._pytree._get_node_type (#168054) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/168054 Approved by: https://github.com/XuehaiPan, https://github.com/zou3519, https://github.com/mlazos --- test/dynamo/test_repros.py | 59 ++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/functions.py | 39 ++++++++++++++++++ 4 files changed, 101 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 10342f56d55d1..aab7d5268fcdc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8184,6 +8184,65 @@ def fn(x): self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1) + def test_pytree_get_node_type_not_traced(self): + # Test that torch.utils._pytree._get_node_type is not traced into + # and doesn't cause excessive trace time overhead + from torch.utils._pytree import _get_node_type + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Call _get_node_type which is used internally by pytree operations + node_type = _get_node_type([x, y]) + assert node_type is list + # Do some work with pytree structures + data = {"a": x, "b": y} + flat, spec = pytree.tree_flatten(data) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_get_node_type_with_namedtuple(self): + # Test that torch.utils._pytree._get_node_type handles namedtuples correctly + # without being traced into, even when is_namedtuple_class is True + from collections import namedtuple + + from torch.utils._pytree import _get_node_type + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Create a namedtuple + point = Point(a, b) + # Call _get_node_type with a namedtuple instance + node_type = _get_node_type(point) + assert node_type is namedtuple + # Use pytree operations with namedtuples + flat, spec = pytree.tree_flatten(point) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 97a3946b48bde..36093b042002e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -64,6 +64,7 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -378,6 +379,7 @@ f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 74165b30bb2f0..ac0be3e5888be 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -64,6 +64,7 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index e30eeeb2c2fde..459b8e0bf6230 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -29,6 +29,7 @@ import sys import traceback import types +from collections import namedtuple from collections.abc import Callable, Sequence from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar @@ -38,6 +39,7 @@ import torch from torch._dynamo.exc import get_stack_above_dynamo from torch._guards import Source +from torch.utils._pytree import is_namedtuple_class from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -63,6 +65,7 @@ DefaultsSource, GetItemSource, SkipGuardSource, + TypeSource, ) from ..utils import ( check_constant_args, @@ -2717,3 +2720,39 @@ def call_function( tensor=tensor, # type: ignore[arg-type] block_shape=block_shape, # type: ignore[arg-type] ) + + +class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. + + def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if len(args) != 1: + raise_type_error_exc( + tx, + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + ) + type_source = None + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) + type_source = TypeSource(args[0].source) + python_type = args[0].python_type() + if is_namedtuple_class(python_type): + return VariableTracker.build(tx, namedtuple) + return VariableTracker.build(tx, python_type, source=type_source) From 7a928397cda89b71c24b0efe9db6df7fb04a46cb Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Wed, 19 Nov 2025 17:24:13 +0000 Subject: [PATCH 049/230] [MPS] permute op for sparse tensors (#168154) permute op for sparse tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/168154 Approved by: https://github.com/malfet --- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_sparse.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index fd88794d38f52..4fa24ff378d72 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4617,7 +4617,7 @@ dispatch: CompositeExplicitAutograd: permute MPS: permute_mps - SparseCPU, SparseCUDA: permute_sparse_coo + SparseCPU, SparseCUDA, SparseMPS: permute_sparse_coo tags: core - func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) diff --git a/test/test_sparse.py b/test/test_sparse.py index 65d624fc9dd8f..21530352cef9a 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -995,7 +995,6 @@ def test_shape(sparse_dims, nnz, with_size): @coalescedonoff @dtypes(torch.double, torch.cdouble) @dtypesIfMPS(torch.float32, torch.complex64) - @expectedFailureMPS @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") @gradcheck_semantics() def test_permute(self, device, dtype, coalesced, gradcheck): @@ -1035,7 +1034,8 @@ def test_shape(sparse_dims, nnz, with_size): else: self.assertFalse(s_permuted.is_coalesced()) - gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_()) + kwargs = {"eps": 1e-4} if device == "mps:0" else {} + gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_(), **kwargs) else: # otherwise check if exception is thrown fail_message = "transpositions between sparse and dense dimensions are not allowed" From a097e166db7077f1e8da94757ccd91a6a521550e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Nov 2025 17:59:50 +0000 Subject: [PATCH 050/230] Revert "Error when non stable/headeronly/shim headers are included by stable extension (#167855)" This reverts commit a0ccd3e5ffacf2e3b44718008ab04ec47d51d7b1. Reverted https://github.com/pytorch/pytorch/pull/167855 on behalf of https://github.com/atalman due to Failing validations ([comment](https://github.com/pytorch/pytorch/pull/167855#issuecomment-3553987894)) --- .../smoke_test/check_binary_symbols.py | 86 ++++++-------- setup.py | 110 ++++++++---------- 2 files changed, 85 insertions(+), 111 deletions(-) diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index 7ad10ca946215..51d5174e77912 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -145,17 +145,14 @@ def _compile_and_extract_symbols( def check_stable_only_symbols(install_root: Path) -> None: """ - Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code. + Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts. This approach tests: - 1. WITHOUT macros -> many torch symbols exposed (compilation succeeds) - 2. WITH TORCH_STABLE_ONLY -> compilation fails with #error directive - 3. WITH TORCH_TARGET_VERSION -> compilation fails with #error directive - 4. WITH both macros -> compilation fails with #error directive + 1. WITHOUT macros -> many torch symbols exposed + 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden) + 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden) + 4. WITH both macros -> zero torch symbols (all hidden) """ - import subprocess - import tempfile - include_dir = install_root / "include" assert include_dir.exists(), f"Expected {include_dir} to be present" @@ -185,7 +182,7 @@ def check_stable_only_symbols(install_root: Path) -> None: "-c", # Compile only, don't link ] - # Compile WITHOUT any macros - should succeed + # Compile WITHOUT any macros symbols_without = _compile_and_extract_symbols( cpp_content=test_cpp_content, compile_flags=base_compile_flags, @@ -199,56 +196,49 @@ def check_stable_only_symbols(install_root: Path) -> None: "Expected a non-zero number of symbols without any macros" ) - # Helper to verify compilation fails with expected error - def _expect_compilation_failure(compile_flags: list[str], macro_name: str) -> None: - with tempfile.TemporaryDirectory() as tmpdir: - tmppath = Path(tmpdir) - cpp_file = tmppath / "test.cpp" - obj_file = tmppath / "test.o" - - cpp_file.write_text(test_cpp_content) - - result = subprocess.run( - compile_flags + [str(cpp_file), "-o", str(obj_file)], - capture_output=True, - text=True, - timeout=60, - ) - - if result.returncode == 0: - raise RuntimeError( - f"Expected compilation to fail with {macro_name} defined, but it succeeded" - ) - - stderr = result.stderr - expected_error_msg = ( - "This file should not be included when either TORCH_STABLE_ONLY " - "or TORCH_TARGET_VERSION is defined." - ) - - if expected_error_msg not in stderr: - raise RuntimeError( - f"Expected error message to contain:\n '{expected_error_msg}'\n" - f"but got:\n{stderr[:1000]}" - ) - - print(f"Compilation correctly failed with {macro_name} defined") - + # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols) compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] - _expect_compilation_failure(compile_flags_with_stable_only, "TORCH_STABLE_ONLY") + symbols_with_stable_only = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_stable_only, + ) + + num_symbols_with_stable_only = len(symbols_with_stable_only) + assert num_symbols_with_stable_only == 0, ( + f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}" + ) + + # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols) compile_flags_with_target_version = base_compile_flags + [ "-DTORCH_TARGET_VERSION=1" ] - _expect_compilation_failure( - compile_flags_with_target_version, "TORCH_TARGET_VERSION" + + symbols_with_target_version = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_target_version, ) + num_symbols_with_target_version = len(symbols_with_target_version) + assert num_symbols_with_target_version == 0, ( + f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}" + ) + + # Compile WITH both macros (expect 0 symbols) compile_flags_with_both = base_compile_flags + [ "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=1", ] - _expect_compilation_failure(compile_flags_with_both, "both macros") + + symbols_with_both = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=compile_flags_with_both, + ) + + num_symbols_with_both = len(symbols_with_both) + assert num_symbols_with_both == 0, ( + f"Expected no symbols with both macros, but found {num_symbols_with_both}" + ) def check_stable_api_symbols(install_root: Path) -> None: diff --git a/setup.py b/setup.py index f15e7bbdd0ac4..ef584cefdd6dd 100644 --- a/setup.py +++ b/setup.py @@ -1089,60 +1089,6 @@ def check_pydep(importname: str, module: str) -> None: class build_ext(setuptools.command.build_ext.build_ext): - def _wrap_headers_with_macro(self, include_dir: Path) -> None: - """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). - - Excludes: - - torch/headeronly/* - - torch/csrc/stable/* - - torch/csrc/inductor/aoti_torch/c/ (only shim headers) - - torch/csrc/inductor/aoti_torch/generated/ - - This method is idempotent - it will not wrap headers that are already wrapped. - """ - header_extensions = (".h", ".hpp", ".cuh") - header_files = [ - f for ext in header_extensions for f in include_dir.rglob(f"*{ext}") - ] - - # Paths to exclude from wrapping (relative to include_dir) - exclude_dir_patterns = [ - "torch/headeronly/", - "torch/csrc/stable/", - "torch/csrc/inductor/aoti_torch/c/", - "torch/csrc/inductor/aoti_torch/generated/", - ] - - # Marker to detect if a header is already wrapped - wrap_start_marker = ( - "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - ) - - for header_file in header_files: - rel_path = header_file.relative_to(include_dir).as_posix() - - if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): - report(f"Skipping header: {rel_path}") - continue - - original_content = header_file.read_text(encoding="utf-8") - - # Check if already wrapped (idempotency check) - if original_content.startswith(wrap_start_marker): - report(f"Already wrapped, skipping: {rel_path}") - continue - - wrapped_content = ( - wrap_start_marker - + f"{original_content}" - + "\n#else\n" - + '#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n' - + "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - ) - - header_file.write_text(wrapped_content, encoding="utf-8") - report(f"Wrapped header: {rel_path}") - def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS build_lib = Path(self.build_lib) @@ -1310,15 +1256,6 @@ def run(self) -> None: super().run() - # Wrap headers with TORCH_STABLE_ONLY and TORCH_TARGET_VERSION guards - build_lib = Path(self.build_lib) - build_torch_include_dir = build_lib / "torch" / "include" - if build_torch_include_dir.exists(): - report( - "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" - ) - self._wrap_headers_with_macro(build_torch_include_dir) - if IS_DARWIN: self._embed_libomp() @@ -1421,6 +1358,45 @@ def __exit__(self, *exc_info: object) -> None: # Need to create the proper LICENSE.txt for the wheel class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel): + def _wrap_headers_with_macro(self, bdist_dir: Path) -> None: + """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). + + Excludes: + - torch/include/torch/headeronly/* + - torch/include/torch/csrc/stable/* + - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers) + - torch/include/torch/csrc/inductor/aoti_torch/generated/ + """ + header_extensions = (".h", ".hpp", ".cuh") + header_files = [ + f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}") + ] + + # Paths to exclude from wrapping + exclude_dir_patterns = [ + "torch/include/torch/headeronly/", + "torch/include/torch/csrc/stable/", + "torch/include/torch/csrc/inductor/aoti_torch/c/", + "torch/include/torch/csrc/inductor/aoti_torch/generated/", + ] + + for header_file in header_files: + rel_path = header_file.relative_to(bdist_dir).as_posix() + + if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): + report(f"Skipping header: {rel_path}") + continue + + original_content = header_file.read_text(encoding="utf-8") + wrapped_content = ( + "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + f"{original_content}" + "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + header_file.write_text(wrapped_content, encoding="utf-8") + report(f"Wrapped header: {rel_path}") + def run(self) -> None: with concat_license_files(include_files=True): super().run() @@ -1443,6 +1419,14 @@ def write_wheelfile(self, *args: Any, **kwargs: Any) -> None: # need an __init__.py file otherwise we wouldn't have a package (bdist_dir / "torch" / "__init__.py").touch() + # Wrap all header files with TORCH_STABLE_ONLY macro + assert self.bdist_dir is not None, "bdist_dir should be set during wheel build" + bdist_dir = Path(self.bdist_dir) + report( + "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" + ) + self._wrap_headers_with_macro(bdist_dir) + class clean(Command): user_options: ClassVar[list[tuple[str, str | None, str]]] = [] From ce9377d8447d5c6d0a9e00e1a93e5f50e76887be Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 19 Nov 2025 09:04:10 -0800 Subject: [PATCH 051/230] [BE] Remove erroneous `const_cast` (#168165) Not sure what was the purpose of using `TensorBase::const_data_ptr` template, but 1st argument of `memcpy` should be a mutable pointer, therefore replacing it with `TensorBase::data_ptr` Introduced in https://github.com/pytorch/pytorch/pull/134712 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168165 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/Nonzero.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 8811f8dc5117e..ed32e9ac45b30 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -212,7 +212,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { std::nullopt /* memory format */ ); at::cuda::memcpy_and_sync( - (void*)pinned_num_nonzeros_h.const_data_ptr(), + pinned_num_nonzeros_h.data_ptr(), num_nonzeros.get(), sizeof(int) * num_chunks, cudaMemcpyDeviceToHost, From a8ccc4e84f8f99192cf94cb6ef9ea08f295ba881 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 18 Nov 2025 22:11:30 -0800 Subject: [PATCH 052/230] [dynamo][pytree][compile time] Specialize tree_is_leaf (#168070) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168070 Approved by: https://github.com/XuehaiPan, https://github.com/fxdawnn, https://github.com/mlazos, https://github.com/zou3519, https://github.com/williamwen42 ghstack dependencies: #168054 --- test/dynamo/test_repros.py | 65 ++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/functions.py | 65 ++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index aab7d5268fcdc..24b8f4c48aa32 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8243,6 +8243,71 @@ def fn(a, b): # Should compile successfully with fullgraph=True self.assertEqual(cnt.frame_count, 1) + def test_pytree_tree_is_leaf_not_traced(self): + # Test that torch.utils._pytree.tree_is_leaf is not traced into + # when is_leaf parameter is None (the common case) + from torch.utils._pytree import tree_is_leaf + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Test with various types + # Tensors are leaves + is_leaf_tensor = tree_is_leaf(x) + assert is_leaf_tensor is True + + # Lists are not leaves (they're in SUPPORTED_NODES) + is_leaf_list = tree_is_leaf([x, y]) + assert is_leaf_list is False + + # Dicts are not leaves + is_leaf_dict = tree_is_leaf({"a": x, "b": y}) + assert is_leaf_dict is False + + return x + y + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_tree_is_leaf_with_namedtuple(self): + # Test that torch.utils._pytree.tree_is_leaf handles namedtuples correctly + from collections import namedtuple + + from torch.utils._pytree import tree_is_leaf + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Namedtuples are not leaves (they're in SUPPORTED_NODES) + point = Point(a, b) + is_leaf_namedtuple = tree_is_leaf(point) + assert is_leaf_namedtuple is False + + # But individual tensors are leaves + is_leaf_tensor = tree_is_leaf(a) + assert is_leaf_tensor is True + + return a + b + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 36093b042002e..083c8b1f93807 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -65,6 +65,7 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -380,6 +381,7 @@ "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, + "torch.utils._pytree.tree_is_leaf": PyTreeTreeIsLeafFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index ac0be3e5888be..439ce274b7ce6 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -65,6 +65,7 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 459b8e0bf6230..7916187193bae 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -65,6 +65,7 @@ DefaultsSource, GetItemSource, SkipGuardSource, + TorchSource, TypeSource, ) from ..utils import ( @@ -118,6 +119,13 @@ _spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() +@functools.lru_cache +def get_pytree_SUPPORTED_NODES_source(): + return AttrSource( + AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" + ) + + class FunctionSpec: def __init__(self, func: FunctionType): code = func.__code__ @@ -2756,3 +2764,60 @@ def call_function( if is_namedtuple_class(python_type): return VariableTracker.build(tx, namedtuple) return VariableTracker.build(tx, python_type, source=type_source) + + +class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. + + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + When is_leaf is None (the common case), we can optimize by not tracing into the function. + When is_leaf is not None, we fall back to regular tracing since it requires executing user code. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # tree_is_leaf(tree, is_leaf=None) + if len(args) < 1 or len(args) > 2: + raise_type_error_exc( + tx, + f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + ) + + # Check if is_leaf parameter is provided + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) + if len(args) == 2: + is_leaf = args[1] + + if not ( + isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None + ): + return super().call_function(tx, args, kwargs) + + # Optimize the case where is_leaf is None + # return _get_node_type(tree) not in SUPPORTED_NODES + tree = args[0] + node_type_var = PyTreeGetNodeTypeFunctionVariable( + torch.utils._pytree._get_node_type + ).call_function(tx, [tree], {}) + + # If the SUPPORTED_NODES was seen earlier and mutated, there would be a + # source and that will give us the mutated SUPPORTED_NODES. + supported_nodes_var = VariableTracker.build( + tx, + torch.utils._pytree.SUPPORTED_NODES, + source=get_pytree_SUPPORTED_NODES_source(), + ) + out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) + return ConstantVariable.create(not out.value) From 2e1821bfda3602044657e0edb33d5700c9b86671 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Mon, 17 Nov 2025 20:30:34 -0800 Subject: [PATCH 053/230] Support AC in default partitioner when functionalization is enabled (#166610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166610 Approved by: https://github.com/SherlockNoMad ghstack dependencies: #166536, #168055 --- .../distributed/tensor/test_dtensor_export.py | 2 - test/dynamo/test_activation_checkpointing.py | 267 ++++++++++--- test/functorch/test_aotdispatch.py | 15 +- test/higher_order_ops/test_local_map.py | 4 +- .../_aot_autograd/functional_utils.py | 20 +- .../_aot_autograd/graph_capture_wrappers.py | 5 + torch/_functorch/partitioners.py | 370 +++++++++++------- 7 files changed, 479 insertions(+), 204 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index bd75668ab4856..4a88cf9a6e0b1 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import contextlib -import unittest import torch import torch.distributed as dist @@ -357,7 +356,6 @@ def test_export_parallelize_module_with_dtensor_input( # aot_export_joint_with_descriptors on strict-exported exported_program.module() # is producing a joint graph with backward region missing - @unittest.expectedFailure def test_strict_export_parallelize_module_with_dtensor_input(self): self._run_test(strict_export_and_aot_export_joint_with_descriptors) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 0d32a9e4917f5..768555efd1d4c 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -15,7 +15,7 @@ import torch.distributed as dist import torch.nn as nn import torch.utils.checkpoint -from functorch.compile import min_cut_rematerialization_partition +from functorch.compile import default_partition, min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -24,7 +24,7 @@ ) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu +from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -281,7 +281,14 @@ def runtime_wrapper(*runtime_args): run(export_compiler) - def test_tags_function(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -297,11 +304,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_via_global_checkpoint(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_via_global_checkpoint(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -316,17 +334,28 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_with_kwargs(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_with_kwargs(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + gn, torch.sin(x), y, use_reentrant=False ) x = torch.randn(4, 4, device=device, requires_grad=True) @@ -336,11 +365,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_sequential_layers(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_sequential_layers(self, device, partition_fn): def gn(x): x = x.cos() for _ in range(3): @@ -361,11 +401,22 @@ def fn(x): freqs=[2, 18], ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_multiple_checkpoints(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_multiple_checkpoints(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -383,11 +434,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=6, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_module(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_module(self, device, partition_fn): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -411,11 +473,22 @@ def fn(x): bw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_decomps(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_decomps(self, device, partition_fn): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -443,6 +516,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, + partition_fn=partition_fn, decompositions=lambda: import_module( "torch._inductor.compile_fx" ).select_decomp_table(), @@ -702,7 +776,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_recompute(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): def context_fn_must_recompute_mm(): must_recompute_list = [ torch.ops.aten.mm.default, @@ -723,9 +804,9 @@ def context_fn_no_recompute_mm(): ), ) - def _test(context_fn, bw_compiler): + def _test(context_fn, bw_compiler, partition_fn): def gn(x): - return torch.sigmoid(torch.matmul(x, x)) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) def fn(x): return torch.utils.checkpoint.checkpoint( @@ -739,14 +820,14 @@ def fn(x): fw_compiler = functools.partial( count_ops, - freq=1, + freq=2, op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x) @@ -754,17 +835,19 @@ def fn(x): context_fn=context_fn_must_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) + freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6) op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) _test( context_fn=context_fn_no_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=2, # 2 bwd mm ops per fwd matmul + freq=4, # 2 bwd mm ops per fwd matmul op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) def test_sac_with_partial_context_fn(self): @@ -801,7 +884,16 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm( + self, device, partition_fn + ): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -841,15 +933,22 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( - self, device + self, device, partition_fn ): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -889,7 +988,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, disable_functionalization=True, ) self._validate(fn, backend, x, y) @@ -897,7 +996,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_triton_kernel(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn): # Copy of the above test, but make sure that having a triton kernel in the # region does not error. def add_one(x): @@ -957,14 +1063,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_tensor_subclass(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1007,14 +1120,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_custom_rule(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): def _get_custom_policy(meta): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1072,14 +1192,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_partial_ctx_fn(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): def selective_checkpointing_context_fn(no_recompute_list): return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) @@ -1118,14 +1245,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_outplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1163,14 +1297,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_list_ops(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_list_ops(self, device, partition_fn): def selective_checkpointing_context_fn(): # recompute everything no_recompute_list = [] @@ -1206,7 +1347,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1217,7 +1358,14 @@ def fn(x, y): "requires TorchDispatchMode + torch.compile work to complete" ) @requires_cuda_and_triton - def test_compile_selective_checkpoint_inplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1257,7 +1405,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1265,7 +1413,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) - def test_compile_selective_checkpoint_random_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_random_op(self, device, partition_fn): for preserve_rng_state in [True, False]: def selective_checkpointing_context_fn(): @@ -1312,7 +1467,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, @@ -1324,7 +1479,14 @@ def fn(x): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_invalid_context(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_invalid_context(self, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) * y @@ -1353,7 +1515,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) with self.assertRaisesRegex( Exception, "must generate a tuple of two `TorchDispatchMode`s" @@ -1362,7 +1524,14 @@ def fn(x, y): @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - def test_compile_selective_checkpoint_parametrization(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_parametrization(self, partition_fn): def sac_policy(): def _recomp_policy(): def _custom_policy(ctx, func, *args, **kwargs): @@ -1425,7 +1594,9 @@ def reset_parameters(self): bw_compiler = functools.partial( count_ops, freqs=[ - 2, # 1 from mul recompute, 1 from mul backward + # 1 from mul recompute, 1 from mul backward + # w/o CSE, we have one extra mul + 3 if partition_fn is default_partition else 2, 1, ], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], @@ -1434,7 +1605,7 @@ def reset_parameters(self): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) model = MLPModule() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..c452f18e95d75 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2640,7 +2640,7 @@ def backward(ctx, grad_output): return grad_output * x, grad_output * x def f(a, b): - return FwBwMutation.apply(a, b) + return FwBwMutation.apply(a, b).sin_().clone() inps = [ torch.ones(3, 3, requires_grad=True), @@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None - return (mul, add)""", + clone = torch.ops.aten.clone.default(mul) + sin_ = torch.ops.aten.sin_.default(mul); mul = None + clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None + return (clone_1, add, clone)""", ) # important bit: there is 1 mutation in the bw self.assertExpectedInline( bw_graph[0].code.strip(), """\ -def forward(self, add, tangents_1): +def forward(self, add, clone, tangents_1): + cos = torch.ops.aten.cos.default(clone); clone = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None - mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None - return (mul_1, None)""", + mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None + return (mul_2, None)""", ) def test_fw_bw_mutation_no_functionalization2(self): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index a585f2055e89f..7b5f01d236e7f 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -911,8 +911,8 @@ def inputs_fn(): op="call_function", target=torch.ops.aten.mm.default ) self.assertEqual(len(mm_nodes), 4) - self.assertNotIn("partitioner_tag", mm_nodes[0].meta) - self.assertNotIn("partitioner_tag", mm_nodes[1].meta) + self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward") + self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward") self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0) diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index fcbf861e537db..5af4fc9ee1195 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional import torch from torch import Tensor @@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg): # Returns the number of detected copy_ -def assert_functional_graph(fx_g: torch.fx.Graph) -> int: +def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: allowed_mutation_ops = [ torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, @@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # NB: It would also be nice to verify that the mutations all happen at the # end, but we also do some administrative views after mutations so this # isn't actually true. (TODO: Could this cause problems for Inductor?) + error = None for n in fx_g.nodes: if n.op == "placeholder": placeholders.add(n) @@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): - assert n.args[0] in placeholders, ( - f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" - ) + if n.args[0] not in placeholders: + error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" mutation_count += 1 else: - assert not n.target._schema.is_mutable, ( - f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" - ) + if n.target._schema.is_mutable: + error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return error, mutation_count + + +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + error, mutation_count = _is_functional_graph(fx_g) + assert error is None, error return mutation_count diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index bc4dc87ddeced..2ef84cb488604 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -27,6 +27,7 @@ from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -295,6 +296,10 @@ def inner_fn( (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( fn, primals ) + mode = get_proxy_mode() + assert mode is not None, "Expected non-None proxy mode" + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" # TODO: I think this hook can also be eliminated now if joint_fn_handle and joint_fn_handle.post_forward: diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index e7f8075b0281e..f22b274be41ab 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -10,6 +10,7 @@ import os import os.path import re +import warnings from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass, replace @@ -51,6 +52,7 @@ ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput +from ._aot_autograd.functional_utils import _is_functional_graph from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -297,6 +299,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +def _has_tag_is_forward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_forward" + + def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" @@ -1021,105 +1027,134 @@ def default_partition( Returns: Returns the generated forward and backward Fx graph modules. """ - if has_recomputable_ops(joint_module): - return min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) + # Respect the original placement of ops rather than rely on dataflow. + forward_nodes = [] + last_node = None + for node in joint_module.graph.nodes: + if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node): + last_node = node + assert last_node is not None + for node in joint_module.graph.nodes: + if not _is_tangent(node): + forward_nodes.append(node) + if node is last_node: + break forward_node_names = OrderedSet( - node.name for node in forward_only_graph.nodes if node.op != "output" + node.name for node in forward_nodes if node.op != "output" ) - order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + if _is_functional_graph(joint_module.graph)[0] is not None: + # Fall-back to previous behavior to avoid bc-breaking, although can + # eventually flip the switch to make this a hard error. + warnings.warn( + "Trying to unsafely apply AC to a non-functional graph with the " + "default partitioner. Falling back to min-cut partitioner." + ) + return min_cut_rematerialization_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(joint_module) + + force_save_bw_mutation_src(joint_module) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) + saved_values = [] saved_sym_nodes = [] - def is_mutated_later_in_fw(node): - if _has_tag_is_backward(node): - return False - tensor_arg_aliases = [ - x - for x in node.args - if isinstance(x, fx.Node) - and "val" in x.meta - and isinstance(x.meta["val"], torch.Tensor) - ] - while len(tensor_arg_aliases) > 0: - a = tensor_arg_aliases.pop() - for u in a.users: - if not isinstance(u.target, torch._ops.OpOverload): - continue - # If we witness a mutation on our node later, and that mutation is not "must be in backward", - # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) - if ( - # one of the args was mutated - u.target._schema.is_mutable - # and the mutation happens "later" - and order[u] > order[node] - # and the mutation happened during the forward - and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) - ): - for idx, alias_info in enumerate(u.target._schema.arguments): - if alias_info.is_write and u.args[idx] is a: - return True - elif u.target.is_view: - tensor_arg_aliases.append(u) - return False + distributed_enabled = torch.distributed.is_available() + + def is_tensor(node): + return "tensor_meta" in node.meta or isinstance( + node.meta.get("val"), torch._subclasses.FakeTensor + ) + + def is_multi_output(node): + return ( + all(user.target == operator.getitem for user in node.users) + and len(node.users) > 0 + ) + + def is_impure(node): + # wait tensor is an "impure" op according to DCE's definition of impure + # (see is_impure in torch/fx/node.py), but it survives past + # functionalization and can be safely dup'd and reordered under the + # assumption SPMD. + return ( + node.is_impure(impure_random=False) + and node.op + not in ( + "placeholder", + "output", + ) + and ( + not distributed_enabled + or node.target is not torch.ops._c10d_functional.wait_tensor.default + ) + ) for node in joint_module.graph.nodes: if node.name not in forward_node_names: - # if a node isn't "required" to be in the forward, but any of its arguments - # are later mutated in the forward, then it must have been run in the forward - # (if not, and the node's arg was saved for backward, we would have mutated a saved value) - # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated - if is_mutated_later_in_fw(node): - saved_values.append(node) continue if is_sym_node(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes.append(node) - elif ( - "tensor_meta" not in node.meta - and node.op == "call_function" - and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) - ): - # Since we can't save tuple of tensor values, we need to flatten out what we're saving - users = node.users - assert all(user.target is operator.getitem for user in users) - saved_values.extend(users) - else: - backward_usages = [ - n for n in node.users if n.name not in forward_node_names - ] - if "tensor_meta" in node.meta and all( - is_sym_node(n) for n in backward_usages - ): - # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, - # and not the actual tensor data, - # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. - # - # Note that saving the tensor could also cause compilation problems: - # If the user mutated an input in the forward and uses its sizes/strides in the backward, - # then we would be obligated to clone the input before saving it to appease autograd. - # (This is how we originally found this bug). - saved_sym_nodes.extend(backward_usages) - else: - saved_values.append(node) + continue + if is_multi_output(node): + # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE. + continue + if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: + saved_values.append(node) + continue + if is_impure(node): + assert not graph_has_recomputable_ops, ( + "Trying to apply AC on a graph with impure op", + node, + node.target, + ) + saved_values.append(node) + continue + assert is_tensor(node) or node.op != "call_function", ( + f"Expected {node} to be a tensor" + ) + backward_usages = [n for n in node.users if n.name not in forward_node_names] + if all(is_sym_node(n) for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + continue + if not must_recompute(node): + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - return _extract_fwd_bwd_modules( + if config._sync_decision_cross_ranks: + saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) + + if static_lifetime_input_nodes is None: + static_lifetime_input_nodes = node_info.static_lifetime_input_nodes + fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, @@ -1127,6 +1162,37 @@ def is_mutated_later_in_fw(node): static_lifetime_input_nodes=static_lifetime_input_nodes, ) + # Run DCE while overriding the definition of is_impure_node + def is_not_collective(node): + if not distributed_enabled: + return True + if node.target is torch.ops._c10d_functional.wait_tensor.default: + return False + if node.target is torch.ops._c10d_functional.all_gather_into_tensor.default: + return False + return True + + fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + if len(node_info.required_bw_nodes) > 0: + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + INT_INF = int(1e6) @@ -1621,7 +1687,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break -def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: +def is_getitem_of_multi_output(node): + if node.target != operator.getitem: + return False + parent = node.args[0] + return "tensor_meta" not in parent.meta and node.op == "call_function" + + +def cleanup_recompute_tags( + joint_module: fx.GraphModule, *, is_default_partition: bool +) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -1658,6 +1733,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # in forward graph outputs. With this, we can break the above circular dependency. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + elif ( + "ac_graph_id" not in node.meta + and any(must_recompute(user) for user in node.users) + and not ( + # Avoid saving getitem nodes which are not labeled with "ac_graph_id" + is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta + ) + and is_default_partition + ): + # This node is not part of the AC region and a user is marked as recompute. + # This means it's an input to the AC region and we should save it. + # For ease of landing, gate this to default partitioner only, but we should think + # about flipping the switch in general as well. + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -2765,6 +2854,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module +def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + + def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2813,68 +2955,16 @@ def min_cut_rematerialization_partition( graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) force_save_bw_mutation_src(joint_module) - def classify_nodes(joint_module, static_lifetime_input_indices): - name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - elif _must_be_in_backward(node): - required_bw_nodes.add(node) - - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) - required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - ) - unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - ) - static_lifetime_input_nodes = OrderedSet( - p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices - ) - fw_cnt = 0 - fw_order = {} - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - fw_order[node] = fw_cnt - fw_cnt += 1 - return NodeInfo( - inputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - fw_order, - static_lifetime_input_nodes, - ) - if static_lifetime_input_indices is None: static_lifetime_input_indices = [] - node_info = classify_nodes(joint_module, static_lifetime_input_indices) + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) # networkx blows up on graphs with no required backward nodes # Since there's nothing to partition anyway, and the default partitioner can "handle" From acf5b204b033fa65deaa69ff127543412087a185 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 19 Nov 2025 19:26:12 +0000 Subject: [PATCH 054/230] Revert "Hide all symbols (except stable/headeronly/shim) if TORCH_STABLE_ONLY is defined (#167496)" This reverts commit 8f4dc304534529c6abf9da0b7154d49fb907d167. Reverted https://github.com/pytorch/pytorch/pull/167496 on behalf of https://github.com/atalman due to Failing validations - https://github.com/pytorch/test-infra/actions/runs/19513141127/job/55857898996 ([comment](https://github.com/pytorch/pytorch/pull/167496#issuecomment-3554287955)) --- .../smoke_test/check_binary_symbols.py | 338 ------------------ setup.py | 47 --- .../libtorch_agnostic_2_10_extension/setup.py | 1 + .../torch_stable_test_extension/setup.py | 67 ++++ .../torch_stable_test/__init__.py | 0 .../torch_stable_test/csrc/test_extension.cpp | 1 + .../torch_stable_test/test_torch_stable.py | 22 ++ torch/csrc/inductor/aoti_torch/c/shim.h | 6 +- 8 files changed, 94 insertions(+), 388 deletions(-) create mode 100644 test/cpp_extensions/torch_stable_test_extension/setup.py create mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py create mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp create mode 100644 test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index 51d5174e77912..b0c607659c72d 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -100,337 +100,6 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: ) -def _compile_and_extract_symbols( - cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None -) -> list[str]: - """ - Helper to compile a C++ file and extract all symbols. - - Args: - cpp_content: C++ source code to compile - compile_flags: Compilation flags - exclude_list: List of symbol names to exclude. Defaults to ["main"]. - - Returns: - List of all symbols found in the object file (excluding those in exclude_list). - """ - import subprocess - import tempfile - - if exclude_list is None: - exclude_list = ["main"] - - with tempfile.TemporaryDirectory() as tmpdir: - tmppath = Path(tmpdir) - cpp_file = tmppath / "test.cpp" - obj_file = tmppath / "test.o" - - cpp_file.write_text(cpp_content) - - result = subprocess.run( - compile_flags + [str(cpp_file), "-o", str(obj_file)], - capture_output=True, - text=True, - timeout=60, - ) - - if result.returncode != 0: - raise RuntimeError(f"Compilation failed: {result.stderr}") - - symbols = get_symbols(str(obj_file)) - - # Return all symbol names, excluding those in the exclude list - return [name for _addr, _stype, name in symbols if name not in exclude_list] - - -def check_stable_only_symbols(install_root: Path) -> None: - """ - Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code and comparing symbol counts. - - This approach tests: - 1. WITHOUT macros -> many torch symbols exposed - 2. WITH TORCH_STABLE_ONLY -> zero torch symbols (all hidden) - 3. WITH TORCH_TARGET_VERSION -> zero torch symbols (all hidden) - 4. WITH both macros -> zero torch symbols (all hidden) - """ - include_dir = install_root / "include" - assert include_dir.exists(), f"Expected {include_dir} to be present" - - test_cpp_content = """ -// Main torch C++ API headers -#include -#include - -// ATen tensor library -#include - -// Core c10 headers (commonly used) -#include -#include -#include -#include -#include - -int main() { return 0; } -""" - - base_compile_flags = [ - "g++", - "-std=c++17", - f"-I{include_dir}", - f"-I{include_dir}/torch/csrc/api/include", - "-c", # Compile only, don't link - ] - - # Compile WITHOUT any macros - symbols_without = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=base_compile_flags, - ) - - # We expect constexpr symbols, inline functions used by other headers etc. - # to produce symbols - num_symbols_without = len(symbols_without) - print(f"Found {num_symbols_without} symbols without any macros defined") - assert num_symbols_without != 0, ( - "Expected a non-zero number of symbols without any macros" - ) - - # Compile WITH TORCH_STABLE_ONLY (expect 0 symbols) - compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] - - symbols_with_stable_only = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_stable_only, - ) - - num_symbols_with_stable_only = len(symbols_with_stable_only) - assert num_symbols_with_stable_only == 0, ( - f"Expected no symbols with TORCH_STABLE_ONLY macro, but found {num_symbols_with_stable_only}" - ) - - # Compile WITH TORCH_TARGET_VERSION (expect 0 symbols) - compile_flags_with_target_version = base_compile_flags + [ - "-DTORCH_TARGET_VERSION=1" - ] - - symbols_with_target_version = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_target_version, - ) - - num_symbols_with_target_version = len(symbols_with_target_version) - assert num_symbols_with_target_version == 0, ( - f"Expected no symbols with TORCH_TARGET_VERSION macro, but found {num_symbols_with_target_version}" - ) - - # Compile WITH both macros (expect 0 symbols) - compile_flags_with_both = base_compile_flags + [ - "-DTORCH_STABLE_ONLY", - "-DTORCH_TARGET_VERSION=1", - ] - - symbols_with_both = _compile_and_extract_symbols( - cpp_content=test_cpp_content, - compile_flags=compile_flags_with_both, - ) - - num_symbols_with_both = len(symbols_with_both) - assert num_symbols_with_both == 0, ( - f"Expected no symbols with both macros, but found {num_symbols_with_both}" - ) - - -def check_stable_api_symbols(install_root: Path) -> None: - """ - Test that stable API headers still expose symbols with TORCH_STABLE_ONLY. - The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols - """ - include_dir = install_root / "include" - assert include_dir.exists(), f"Expected {include_dir} to be present" - - stable_dir = include_dir / "torch" / "csrc" / "stable" - assert stable_dir.exists(), f"Expected {stable_dir} to be present" - - stable_headers = list(stable_dir.rglob("*.h")) - if not stable_headers: - raise RuntimeError("Could not find any stable headers") - - includes = [] - for header in stable_headers: - rel_path = header.relative_to(include_dir) - includes.append(f"#include <{rel_path.as_posix()}>") - - includes_str = "\n".join(includes) - test_stable_content = f""" -{includes_str} -int main() {{ return 0; }} -""" - - compile_flags = [ - "g++", - "-std=c++17", - f"-I{include_dir}", - f"-I{include_dir}/torch/csrc/api/include", - "-c", - "-DTORCH_STABLE_ONLY", - ] - - symbols_stable = _compile_and_extract_symbols( - cpp_content=test_stable_content, - compile_flags=compile_flags, - ) - num_symbols_stable = len(symbols_stable) - print(f"Found {num_symbols_stable} symbols in torch/csrc/stable") - assert num_symbols_stable > 0, ( - f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, " - f"but found {num_symbols_stable} symbols" - ) - - -def check_headeronly_symbols(install_root: Path) -> None: - """ - Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY. - """ - include_dir = install_root / "include" - assert include_dir.exists(), f"Expected {include_dir} to be present" - - # Find all headers in torch/headeronly - headeronly_dir = include_dir / "torch" / "headeronly" - assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present" - headeronly_headers = list(headeronly_dir.rglob("*.h")) - if not headeronly_headers: - raise RuntimeError("Could not find any headeronly headers") - - # Filter out platform-specific headers that may not compile everywhere - platform_specific_keywords = [ - "cpu/vec", - ] - - filtered_headers = [] - for header in headeronly_headers: - rel_path = header.relative_to(include_dir).as_posix() - if not any( - keyword in rel_path.lower() for keyword in platform_specific_keywords - ): - filtered_headers.append(header) - - includes = [] - for header in filtered_headers: - rel_path = header.relative_to(include_dir) - includes.append(f"#include <{rel_path.as_posix()}>") - - includes_str = "\n".join(includes) - test_headeronly_content = f""" -{includes_str} -int main() {{ return 0; }} -""" - - compile_flags = [ - "g++", - "-std=c++17", - f"-I{include_dir}", - f"-I{include_dir}/torch/csrc/api/include", - "-c", - "-DTORCH_STABLE_ONLY", - ] - - symbols_headeronly = _compile_and_extract_symbols( - cpp_content=test_headeronly_content, - compile_flags=compile_flags, - ) - num_symbols_headeronly = len(symbols_headeronly) - print(f"Found {num_symbols_headeronly} symbols in torch/headeronly") - assert num_symbols_headeronly > 0, ( - f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, " - f"but found {num_symbols_headeronly} symbols" - ) - - -def check_aoti_shim_symbols(install_root: Path) -> None: - """ - Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY. - """ - include_dir = install_root / "include" - assert include_dir.exists(), f"Expected {include_dir} to be present" - - # There are no constexpr symbols etc., so we need to actually use functions - # so that some symbols are found. - test_shim_content = """ -#include -int main() { - int32_t (*fp1)() = &aoti_torch_device_type_cpu; - int32_t (*fp2)() = &aoti_torch_dtype_float32; - (void)fp1; (void)fp2; - return 0; -} -""" - - compile_flags = [ - "g++", - "-std=c++17", - f"-I{include_dir}", - f"-I{include_dir}/torch/csrc/api/include", - "-c", - "-DTORCH_STABLE_ONLY", - ] - - symbols_shim = _compile_and_extract_symbols( - cpp_content=test_shim_content, - compile_flags=compile_flags, - ) - num_symbols_shim = len(symbols_shim) - assert num_symbols_shim > 0, ( - f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, " - f"but found {num_symbols_shim} symbols" - ) - - -def check_stable_c_shim_symbols(install_root: Path) -> None: - """ - Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY. - """ - include_dir = install_root / "include" - assert include_dir.exists(), f"Expected {include_dir} to be present" - - # Check if the stable C shim exists - stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h" - if not stable_shim.exists(): - raise RuntimeError("Could not find stable c shim") - - # There are no constexpr symbols etc., so we need to actually use functions - # so that some symbols are found. - test_stable_shim_content = """ -#include -int main() { - // Reference stable C API functions to create undefined symbols - AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string; - AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads; - (void)fp1; (void)fp2; - return 0; -} -""" - - compile_flags = [ - "g++", - "-std=c++17", - f"-I{include_dir}", - f"-I{include_dir}/torch/csrc/api/include", - "-c", - "-DTORCH_STABLE_ONLY", - ] - - symbols_stable_shim = _compile_and_extract_symbols( - cpp_content=test_stable_shim_content, - compile_flags=compile_flags, - ) - num_symbols_stable_shim = len(symbols_stable_shim) - assert num_symbols_stable_shim > 0, ( - f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, " - f"but found {num_symbols_stable_shim} symbols" - ) - - def check_lib_symbols_for_abi_correctness(lib: str) -> None: print(f"lib: {lib}") cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS) @@ -460,13 +129,6 @@ def main() -> None: check_lib_symbols_for_abi_correctness(libtorch_cpu_path) check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path) - # Check symbols when TORCH_STABLE_ONLY is defined - check_stable_only_symbols(install_root) - check_stable_api_symbols(install_root) - check_headeronly_symbols(install_root) - check_aoti_shim_symbols(install_root) - check_stable_c_shim_symbols(install_root) - if __name__ == "__main__": main() diff --git a/setup.py b/setup.py index ef584cefdd6dd..314f719ea67f0 100644 --- a/setup.py +++ b/setup.py @@ -1358,45 +1358,6 @@ def __exit__(self, *exc_info: object) -> None: # Need to create the proper LICENSE.txt for the wheel class bdist_wheel(setuptools.command.bdist_wheel.bdist_wheel): - def _wrap_headers_with_macro(self, bdist_dir: Path) -> None: - """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). - - Excludes: - - torch/include/torch/headeronly/* - - torch/include/torch/csrc/stable/* - - torch/include/torch/csrc/inductor/aoti_torch/c/ (only shim headers) - - torch/include/torch/csrc/inductor/aoti_torch/generated/ - """ - header_extensions = (".h", ".hpp", ".cuh") - header_files = [ - f for ext in header_extensions for f in bdist_dir.rglob(f"*{ext}") - ] - - # Paths to exclude from wrapping - exclude_dir_patterns = [ - "torch/include/torch/headeronly/", - "torch/include/torch/csrc/stable/", - "torch/include/torch/csrc/inductor/aoti_torch/c/", - "torch/include/torch/csrc/inductor/aoti_torch/generated/", - ] - - for header_file in header_files: - rel_path = header_file.relative_to(bdist_dir).as_posix() - - if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): - report(f"Skipping header: {rel_path}") - continue - - original_content = header_file.read_text(encoding="utf-8") - wrapped_content = ( - "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - f"{original_content}" - "\n#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" - ) - - header_file.write_text(wrapped_content, encoding="utf-8") - report(f"Wrapped header: {rel_path}") - def run(self) -> None: with concat_license_files(include_files=True): super().run() @@ -1419,14 +1380,6 @@ def write_wheelfile(self, *args: Any, **kwargs: Any) -> None: # need an __init__.py file otherwise we wouldn't have a package (bdist_dir / "torch" / "__init__.py").touch() - # Wrap all header files with TORCH_STABLE_ONLY macro - assert self.bdist_dir is not None, "bdist_dir should be set during wheel build" - bdist_dir = Path(self.bdist_dir) - report( - "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" - ) - self._wrap_headers_with_macro(bdist_dir) - class clean(Command): user_options: ClassVar[list[tuple[str, str | None, str]]] = [] diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index 405944bc0f9bf..ff2aeff5e932b 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -35,6 +35,7 @@ def get_extension(): extra_compile_args = { "cxx": [ "-fdiagnostics-color=always", + "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=0x020a000000000000", ], } diff --git a/test/cpp_extensions/torch_stable_test_extension/setup.py b/test/cpp_extensions/torch_stable_test_extension/setup.py new file mode 100644 index 0000000000000..062d466e7ae98 --- /dev/null +++ b/test/cpp_extensions/torch_stable_test_extension/setup.py @@ -0,0 +1,67 @@ +import distutils.command.clean +import shutil +from pathlib import Path + +from setuptools import find_packages, setup + +from torch.utils.cpp_extension import BuildExtension, CppExtension + + +ROOT_DIR = Path(__file__).parent +CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc" + + +class clean(distutils.command.clean.clean): + def run(self): + # Run default behavior first + distutils.command.clean.clean.run(self) + + # Remove extension + for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"): + path.unlink() + # Remove build and dist and egg-info directories + dirs = [ + ROOT_DIR / "build", + ROOT_DIR / "dist", + ROOT_DIR / "torch_stable_test.egg-info", + ] + for path in dirs: + if path.exists(): + shutil.rmtree(str(path), ignore_errors=True) + + +def get_extension(): + extra_compile_args = { + "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"], + } + + sources = list(CSRC_DIR.glob("**/*.cpp")) + + return [ + CppExtension( + "torch_stable_test._C", + sources=sorted(str(s) for s in sources), + py_limited_api=True, + extra_compile_args=extra_compile_args, + extra_link_args=[], + ) + ] + + +setup( + name="torch_stable_test", + version="0.0", + author="PyTorch Core Team", + description="Test extension to verify TORCH_STABLE_ONLY flag", + packages=find_packages(exclude=("test",)), + package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]}, + install_requires=[ + "torch", + ], + ext_modules=get_extension(), + cmdclass={ + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), + "clean": clean, + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, +) diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp new file mode 100644 index 0000000000000..c92d56da11ba3 --- /dev/null +++ b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp @@ -0,0 +1 @@ +#include // This should trigger the TORCH_STABLE_ONLY error diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py new file mode 100644 index 0000000000000..5c5613bb5484e --- /dev/null +++ b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py @@ -0,0 +1,22 @@ +# Owner(s): ["module: cpp"] + +from pathlib import Path + +from torch.testing._internal.common_utils import ( + install_cpp_extension, + IS_WINDOWS, + run_tests, + TestCase, +) + + +if not IS_WINDOWS: + + class TestTorchStable(TestCase): + def test_setup_fails(self): + with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"): + install_cpp_extension(extension_root=Path(__file__).parent.parent) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 2eda2b218e705..4fb746ea15271 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -38,9 +38,9 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check -#include -#include -#include +#include +#include +#include #ifdef __cplusplus extern "C" { From 6c02dde0b1fdef86ffa88b8658ee0e0614ef999a Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Wed, 19 Nov 2025 08:17:45 -0800 Subject: [PATCH 055/230] Introduce missing collectives and small fixes to support local tensor mode in AutoParallel (#168110) This PR introduces support for additional functional collectives used in AutoParallel. Another change is in the semantic of the tolist() on the LocalTensor. Previously LocalTensor would reconcile first and then return a single tensor that is same on all ranks. AutoParallel uses tolist() to compute all-to-all splits during token dispatch and combine. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168110 Approved by: https://github.com/ezyang --- test/distributed/test_local_tensor.py | 47 ++++++ torch/distributed/_local_tensor/__init__.py | 157 +++++++++++++++++--- torch/distributed/_local_tensor/_c10d.py | 63 ++++++++ 3 files changed, 250 insertions(+), 17 deletions(-) diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index fa081243c2816..d4c1a7333bf34 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -373,6 +373,53 @@ def test_all_gather_collective(self): self.assertEqual(tensor_list[1], different_tensors[1]) self.assertEqual(tensor_list[2], different_tensors[2]) + def test_all_to_all_single_collective(self): + """Test that all_to_all_single collective operation works correctly with LocalTensor.""" + from torch.distributed._functional_collectives import all_to_all_single + + # Create different tensors for each rank + # Each rank will split its tensor and send parts to other ranks + different_tensors = { + 0: torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + ), # rank 0 sends [0,0], [0,0], [0,0] to ranks 0,1,2 + 1: torch.tensor( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ), # rank 1 sends [1,1], [1,1], [1,1] to ranks 0,1,2 + 2: torch.tensor( + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0] + ), # rank 2 sends [2,2], [2,2], [2,2] to ranks 0,1,2 + } + + # Each rank splits its input into 3 parts of size 2 each + input_split_sizes = [2, 2, 2] + # Each rank receives 3 parts of size 2 each from all ranks + output_split_sizes = [2, 2, 2] + + with LocalTensorMode(self.world_size): + lt_input = LocalTensor(different_tensors) + + # Test all_to_all_single using functional collectives API + result = all_to_all_single( + lt_input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=torch.distributed.distributed_c10d._get_default_group(), + ) + + result = result.wait() + # Verify result is a LocalTensor + self.assertIsInstance(result, LocalTensor) + + # After all_to_all_single: + # rank 0 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + # rank 1 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + # rank 2 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + expected_output = torch.tensor([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]) + + for rank in different_tensors: + self.assertEqual(result._local_tensors[rank], expected_output) + class TestLocalTensorWorld4(LocalTensorTestBase): world_size = 4 diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index db03d26227911..dbb0071d86ec7 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -49,6 +49,7 @@ import operator import os import sys +import threading from collections import defaultdict from collections.abc import Callable, Generator, Sequence from types import TracebackType @@ -96,6 +97,83 @@ def _is_in_fake_tensor_mode() -> bool: ) +def _reduce_multidim_lists( + lists_to_reduce: list[Any], reduce_func: Callable[[list[Any]], Any] +) -> Any: + """ + Reduces a list of multi-dimensional lists, assuming they all have + the exact same shape. + + Args: + lists_to_reduce (list): A list where each item is a multi-dimensional + list (e.g., [md_list_1, md_list_2, ...]). + All inner md_lists must have the same shape. + reduce_func (callable): A function that takes an iterable (list) of + values and returns a single reduced value. + For example: sum, max, min, or + lambda x: sum(x) / len(x) for mean. + + Returns: + A single multi-dimensional list of the same shape as the inputs, + where each value is the result of the reduce_func. + + Raises: + ValueError: If the input list is empty or if shapes are inconsistent + (which may also raise IndexError or TypeError). + """ + if not lists_to_reduce: + raise ValueError("Input 'lists_to_reduce' cannot be empty.") + + # Get the first list to inspect its structure (shape) + first_list = lists_to_reduce[0] + + # Check if the first element of this list is *also* a list. + # This determines if we are at the base case or need to recurse. + if isinstance(first_list[0], list): + # --- RECURSIVE STEP --- + # The elements are lists, so we need to go one level deeper. + + # We find the number of sub-lists from the first list. + # (e.g., for [[1,2], [3,4]], this is 2) + num_sublists = len(first_list) + + result = [] + # Iterate by the index of the sub-lists (e.g., i = 0, then i = 1) + for i in range(num_sublists): + # Build a new list to pass to the recursive call. + # This list will contain the i-th sublist from *each* of the + # input lists. + # e.g., if lists_to_reduce = [ L1, L2 ] and i = 0, + # this creates [ L1[0], L2[0] ] + sublists_to_reduce = [l[i] for l in lists_to_reduce] + + # Recurse and append the result + result.append(_reduce_multidim_lists(sublists_to_reduce, reduce_func)) + return result + else: + # --- BASE CASE --- + # The elements are values (int, float, etc.), not lists. + # We are at the innermost dimension. + + # Find the number of values in the innermost list. + # (e.g., for [1, 2], this is 2) + num_values = len(first_list) + + result = [] + # Iterate by the index of the values (e.g., i = 0, then i = 1) + for i in range(num_values): + # Get the values at this specific position (i) from *all* + # input lists. + # e.g., if lists_to_reduce = [ [1,2], [10,20] ] and i = 0, + # this creates [ 1, 10 ] + values_at_pos = [l[i] for l in lists_to_reduce] + + # Apply the user-provided reduction function to this list of values + # and append the single result. + result.append(reduce_func(values_at_pos)) + return result + + def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: return ( isinstance(op, OpOverload) @@ -284,7 +362,11 @@ def _for_each_rank_run_func( rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) - rank_ret = func(*rank_args, **rank_kwargs) + if func is torch.ops.aten.hash_tensor.default and rank_args[0].numel() == 0: + # Special case for empty tensors, hash_tensor returns an empty tensor + rank_ret = torch.empty(0, dtype=torch.uint64, device=rank_args[0].device) + else: + rank_ret = func(*rank_args, **rank_kwargs) flat_rank_rets[r] = rank_ret if use_per_rank_rng: @@ -385,6 +467,12 @@ def sym_max( } ) + def sym_sum(self, other: Any) -> "LocalIntNode | ConstantIntNode": + t = LocalIntNode(dict.fromkeys(self._local_ints, 0)) + for o in other: + t = t.add(o) + return t + def neg(self) -> "LocalIntNode | ConstantIntNode": return LocalIntNode({r: -self._local_ints[r] for r in self._local_ints}) @@ -971,10 +1059,24 @@ def is_contiguous( def tolist(self) -> list[Any]: """ - Reconcile and convert result to list. + Try to reconcile, if successful convert to list, otherwise if dtype is integer, + convert to list of local integers. """ + equal_obj = self._equal_local_tensors() + if isinstance(equal_obj, torch.Tensor): + return equal_obj.tolist() + if isinstance(equal_obj, torch.Size): + if not self.dtype.is_floating_point and not self.dtype.is_complex: + ranks = sorted(self._ranks) + local_lists = [self._local_tensors[r].tolist() for r in ranks] + return _reduce_multidim_lists( + local_lists, + lambda values: torch.SymInt( + LocalIntNode(dict(zip(ranks, values, strict=True))) + ), + ) - return self.reconcile().tolist() + raise RuntimeError("Cannot convert local tensor to list") def reconcile(self) -> torch.Tensor: """ @@ -988,16 +1090,23 @@ def reconcile(self) -> torch.Tensor: """ # Force all local tensor shards across ranks to be the same - it = iter(self._local_tensors.values()) - t1 = next(it) - for t2 in it: - assert torch.equal(t1, t2), ( - "LocalTensor shards must be the same to reconcile" - ) - cl = t1.clone().detach() + equal_obj = self._equal_local_tensors() + assert isinstance(equal_obj, torch.Tensor), ( + "LocalTensor shards must be the same to reconcile" + ) + cl = equal_obj.clone().detach() cl.requires_grad_(self.requires_grad) return cl + def _equal_local_tensors(self) -> torch.Tensor | torch.Size | None: + it = iter(self._local_tensors.values()) + t1 = next(it) + if all(t2.equal(t1) for t2 in it): + return t1 + if all(t2.shape == t1.shape for t2 in it): + return t1.shape + return None + def _sync_meta(self) -> None: with no_dispatch(): (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( @@ -1006,7 +1115,18 @@ def _sync_meta(self) -> None: self._size = shape -_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +_GLOBAL_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +# When running under local runner each thread must create its own local tensor mode +# so that they do not interfere with each other. +_THREAD_LOCAL_TENSOR_MODE: threading.local = threading.local() + + +def get_local_tensor_mode_list() -> list["LocalTensorMode"]: + if not hasattr(_THREAD_LOCAL_TENSOR_MODE, "value"): + _THREAD_LOCAL_TENSOR_MODE.value = [] + if len(_THREAD_LOCAL_TENSOR_MODE.value) > 0: + return _THREAD_LOCAL_TENSOR_MODE.value + return _GLOBAL_LOCAL_TENSOR_MODE class LocalTensorMode(TorchDispatchMode): @@ -1047,7 +1167,7 @@ def __enter__(self) -> "LocalTensorMode": self._disable = False self._patch_device_mesh() self._patch_random_functions() - _LOCAL_TENSOR_MODE.append(self) + get_local_tensor_mode_list().append(self) # _distribute_region will compute correct per-shard offsets # but we want all ranks to start with the same state @@ -1070,7 +1190,7 @@ def __exit__( self._disable = True self._unpatch_device_mesh() self._unpatch_random_functions() - _LOCAL_TENSOR_MODE.pop() + get_local_tensor_mode_list().pop() super().__exit__(exc_type, exc_val, exc_tb) def __torch_dispatch__( @@ -1160,6 +1280,10 @@ def __torch_dispatch__( return _c10d._local_functional_all_gather_into_tensor(*args, **kwargs) elif func is torch.ops._c10d_functional.reduce_scatter_tensor.default: return _c10d._local_functional_reduce_scatter_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_to_all_single.default: + return _c10d._local_functional_all_to_all_single(*args, **kwargs) + elif func is torch.ops._c10d_functional.wait_tensor.default: + return _c10d._local_functional_wait_tensor(*args, **kwargs) else: with LocalTensorMode(self.ranks): return func._op_dk( @@ -1381,8 +1505,9 @@ def local_tensor_mode() -> Optional[LocalTensorMode]: Returns: Optional[LocalTensorMode]: The current LocalTensorMode if active, else None. """ - if len(_LOCAL_TENSOR_MODE) > 0: - return _LOCAL_TENSOR_MODE[-1] + local_tensor_mode_list = get_local_tensor_mode_list() + if len(local_tensor_mode_list) > 0: + return local_tensor_mode_list[-1] return None @@ -1602,7 +1727,6 @@ def _get_recv_object(self, src: int, dst: int) -> object | None: def _signal_send(self, src: int, dst: int, obj: object) -> None: assert obj is not None, "Cannot signal None" - self._assert_holds_run_lock() # Only a single thread a time executes so it is safe to mutate # read objects queue (executing thread is already holding the lock) self._recv_objects[dst][src].put(obj) @@ -1611,7 +1735,6 @@ def _signal_send(self, src: int, dst: int, obj: object) -> None: self._run_cond.notify_all() def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: - self._assert_holds_run_lock() # Wait for the object to be available while True: obj = self._get_recv_object(src, dst) diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 0b63330dfafce..873da1ad5c626 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -216,6 +216,69 @@ def _local_functional_shard_dim_alltoall( return output +def _local_functional_all_to_all_single( + tensor: torch.Tensor, + output_split_sizes: list[torch.SymInt], + input_split_sizes: list[torch.SymInt], + group_name: str, +) -> torch.Tensor: + # "all_to_all_single(Tensor input, SymInt[] output_split_sizes, SymInt[] input_split_sizes, str group_name) -> Tensor" + from . import LocalIntNode, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + split_local_sizes: dict[int, list[int]] = {} + for input_split_size in input_split_sizes: + if isinstance(input_split_size, torch.SymInt) and isinstance( + input_split_size.node, LocalIntNode + ): + local_ints = dict(input_split_size.node._local_ints.items()) + else: + local_ints = { + rank: int(input_split_size) for rank in tensor._local_tensors.keys() + } + for rank, split_size in local_ints.items(): + if rank not in split_local_sizes: + split_local_sizes[rank] = [] + split_local_sizes[rank].append(split_size) + + split_local_tensors: dict[int, list[torch.Tensor]] = {} + + for rank, split_sizes in split_local_sizes.items(): + split_local_tensors[rank] = list( + torch.split(tensor._local_tensors[rank], split_sizes) + ) + + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + for i, dst in enumerate(group_ranks): + splits = [] + for j, src in enumerate(group_ranks): + splits.append(split_local_tensors[src][i]) + output_local_tensors[dst] = torch.cat(splits) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_wait_tensor(tensor: torch.Tensor) -> torch.Tensor: + # "wait_tensor(Tensor input) -> Tensor" + from . import LocalTensor + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + return tensor + + def _local_broadcast_( tensors: list[torch.Tensor], process_group_so: ScriptObject, From f9724db4921288a096e331cee835abd43257fbd6 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 19 Nov 2025 19:32:56 +0000 Subject: [PATCH 056/230] [torch.onnx.export] Fix onnx export on big endian machines (#167816) On big endian machines, constant values in the exported onnx model are still in big endian, they need to be converted to little endian to comply with the onnx specification. This fixes that issue by calling super's methods of `ir.Tensor` that already handles endianness well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167816 Approved by: https://github.com/titaiwangms --- torch/onnx/_internal/exporter/_core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index f1f1ac6c67e40..77e2e3049fb31 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -9,6 +9,7 @@ import logging import operator import pathlib +import sys import textwrap import traceback import typing @@ -182,6 +183,9 @@ def _get_cbytes(self): ).from_address(tensor.data_ptr()) def tobytes(self) -> bytes: + # On big-endian machines, call the super's tobytes() which returns a little-endian result. + if sys.byteorder == "big": + return super().tobytes() # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because # it avoids copying to a NumPy array @@ -189,6 +193,9 @@ def tobytes(self) -> bytes: return bytes(data) def tofile(self, file) -> None: + # On big-endian machines, call the super's tofile() which returns a little-endian result. + if sys.byteorder == "big": + return super().tofile(file) _, data = self._get_cbytes() return file.write(data) From 607e2e7f2c53b463fa0657e441cd087529698795 Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 19 Nov 2025 20:35:39 +0000 Subject: [PATCH 057/230] [Distributed] Fix @parametrize on unordered iterable in distributed test (again) (#168012) Same fix as https://github.com/pytorch/pytorch/pull/159793, as this was broken again in https://github.com/pytorch/pytorch/pull/161476, assuming to be due to a botched rebase Pull Request resolved: https://github.com/pytorch/pytorch/pull/168012 Approved by: https://github.com/Skylion007 --- test/distributed/fsdp/test_distributed_checkpoint.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 00479cf0935b9..67f8e1af9abbd 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -30,11 +30,13 @@ ) sys.exit(0) - -_DISTRIBUTED_STATE_DICT_IMPLS = { +# NB: this iterable needs to be orderd as otherwise different ranks may run with +# conflicting settings when e.g., @parametrize(_DISTRIBUTED_STATE_DICT_IMPLS) is +# used to decorate tests +_DISTRIBUTED_STATE_DICT_IMPLS = ( StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, -} +) class TestDistributedCheckpoint(FSDPTest): From 5b35cf102fbae1d964b3609bf36de0ea7eb90c20 Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Tue, 18 Nov 2025 11:20:48 -0800 Subject: [PATCH 058/230] [DTensor][ops] adding aten.std.correction propagation rule (#168057) **Summary:** I added a new sharding propagation rule for aten.std.correction so that users can call .std() on DTensors. Since aten.var.correction already has a sharding propagation rule, and aten.std.correction should just take the square root, I added it to std_var_reduction_strategy's register op strategy list. Also removed std from list of ops that should fail on dtensors. **Test Case** 1. pytest test/distributed/tensor/test_math_ops.py -k test_std Pull Request resolved: https://github.com/pytorch/pytorch/pull/168057 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_dtensor_ops.py | 2 -- test/distributed/tensor/test_math_ops.py | 24 +++++++++++++++++++++ torch/distributed/tensor/_ops/_math_ops.py | 9 ++++++-- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index df51152a90307..5880efb3734bf 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -435,8 +435,6 @@ def repurpose_ops(op_db, base_test_name, derived_test_name): xfail("signal.windows.nuttall"), xfail("signal.windows.kaiser"), xfail("stack"), - xfail("std"), - xfail("std", "unbiased"), xfail("std_mean"), xfail("std_mean", "unbiased"), xfail("stft"), diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 56321806477b9..5eb92a44188e6 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -1037,6 +1037,30 @@ def test_matching_partial_reduction_ops(self): self.assertTrue(out_with_redistribute.placements[0].is_replicate()) self.assertEqual(out_without_redistribute, out_with_redistribute) + @with_comms + def test_std(self): + mesh = DeviceMesh(self.device_type, torch.arange(4).reshape(2, 2)) + rank = self.rank + comm_mode = CommDebugMode() + + global_tensor = map_local_for_rank( + rank, + lambda rank: torch.tensor( + [[-20.0, -18.0, -12.0, 0.0], [-20.0, -18.0, -8.0, 4.0]] + ), + ) + + dt = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(1)]) + + with comm_mode: + res = dt.std(dim=1) + expected_answer = torch.tensor([9.0, 11.0]) + + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual(comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1) + self.assertEqual(res.placements, [Shard(0), Replicate()]) + self.assertEqual(res.full_tensor(), expected_answer) + DistMathOpsTestWithLocalTensor = create_local_tensor_test_class( DistMathOpsTest, diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 545895c83b6eb..63a352adc8dc7 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -405,10 +405,15 @@ def cumsum_strategy(op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [aten.var.correction, aten.var.correction_out], + [ + aten.std.correction, + aten.std.correction_out, + aten.var.correction, + aten.var.correction_out, + ], schema_info=RuntimeSchemaInfo(1, ["keepdim"]), ) -def var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: +def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: args_schema = op_schema.args_schema input_strategy = args_schema[0] if not isinstance(input_strategy, OpStrategy): From f6cde6e2988e1527efc0da2f643a4855db5809d9 Mon Sep 17 00:00:00 2001 From: Georgia Phillips Date: Wed, 19 Nov 2025 21:06:46 +0000 Subject: [PATCH 059/230] Fix tensor -> scalar variant swap (#168007) Summary: Follow up on https://github.com/pytorch/pytorch/issues/123478 Test Plan: Before change hit: ``` terminate called after throwing an instance of 'c10::Error' what(): Exception while executing node: %_FOLDED_CONST_mul_tensor_72 = torch.ops.aten.mul.Tensor(self=%main_module_module_over_arch_dhen_arch__layers_0__layernorm2_bias, other=2) with args: arg0 self: Tensor c10::BFloat16[5640]cpu arg1 other: Int 2 ``` ``` MTIAC_OPERATORS_DEFINITION=~/fbsource/fbcode/mtia/kernels/artemis/kernel_impl/op_defs/triton_arange.json AFG_RECOMPILE_ELF=true buck run mode/dev-nosan -c mtia.use_msp=true -c glow.backend=MTIA -c mtia.debug=false caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.mtia.${module_name} --moduleName=${module_name} --submodToDevice="" --benchmarkEnableProfiling=false --disableStaticRuntime=true --doNotRandomizeSampleInputs=true --benchmarkDontRebatchSamples=true --pytorch_predictor_sigmoid_static_dispatch_enable=false --pytorch_predictor_sigmoid_graph_passes_enable=false --pytorch_predictor_runtime_const_folding_enable=false --sampleInputFilePath=${SAMPLE_INPUT_DIR}/${module_name}.pt --load_lowered_merge=3 --predictor_hardware_type=101 --module_num_workers_per_gpu="merge|4;remote|2" --useTgif=true ``` Reviewed By: henryoier Differential Revision: D86979852 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168007 Approved by: https://github.com/henryoier --- torch/nativert/kernels/KernelHandlerRegistry.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torch/nativert/kernels/KernelHandlerRegistry.cpp b/torch/nativert/kernels/KernelHandlerRegistry.cpp index 3ac176a81bc3a..69655067a79c7 100644 --- a/torch/nativert/kernels/KernelHandlerRegistry.cpp +++ b/torch/nativert/kernels/KernelHandlerRegistry.cpp @@ -23,9 +23,10 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) { auto overloadName = selectScalarOverloadName(node); if (!overloadName.empty() && !c10::ends_with(node.target(), overloadName)) { - const std::string& newTarget = + const std::string newTarget = std::string(node.target()) - .replace(node.target().rfind('.'), std::string::npos, overloadName); + .replace( + node.target().rfind('.') + 1, std::string::npos, overloadName); LOG(INFO) << fmt::format( "Converting Tensor to {} for node: {} -> {}", overloadName, @@ -36,6 +37,11 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) { return std::string(node.target()); } +void updateNodeTargetIfNeeded(Node& node) { + auto newTarget = maybeRevisedStaticDispatchTarget(node); + node.setTarget(newTarget); +} + std::unique_ptr make_proxy_executor( const std::string& filename, bool is_cpu, @@ -69,6 +75,8 @@ void register_kernel_handlers() { const torch::nativert::ExecutorConfig& executorConfig, caffe2::serialize::PyTorchStreamReader* packageReader) -> std::pair { + updateNodeTargetIfNeeded(const_cast(node)); + return { torch::nativert::StaticallyDispatchedCPUKernelRegistry() ->Create(maybeRevisedStaticDispatchTarget(node), &node), From c56655268b4ae575ee4c89c312fd93ca2f5b3ba9 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 19 Nov 2025 21:27:36 +0000 Subject: [PATCH 060/230] [DebugMode] wait before hashing collectives by default (#168119) Calls wait() on collectives by default for `log_tensor_hashes()`, can be disabled if you don't care with `wait_on_collectives=False`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168119 Approved by: https://github.com/yushangdi, https://github.com/mlazos --- .../tensor/debug/test_debug_mode.py | 26 +++++++++++++++++++ torch/utils/_debug_mode.py | 11 +++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 1ba4bccf3696d..0b7acebbd8aac 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -582,6 +582,32 @@ def test_check_structure_mismatches(self): with self.assertRaisesRegex(ValueError, "Log lengths don't match"): DebugMode.check_hash_mismatches(dm1.logs, dm3.logs) + @unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(0).total_memory < 2**26, + "Being conservative, test peak memory is 25MB?", + ) + def test_tensor_hash_waits_on_collective(self): + # test that hashing collectives gives correct results + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + local_tensor = torch.ones(2**18, device=self.device_type) + dt = DTensor.from_local(local_tensor, mesh, [Shard(0)], run_check=False) + + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(): + dt.redistribute(mesh, [Replicate()]) + + # Find all_gather hash + all_gather_logs = [ + op + for op in debug_mode.logs + if isinstance(op, _OpCall) + and op.op == torch.ops._c10d_functional.all_gather_into_tensor.default + ] + self.assertEqual(len(all_gather_logs), 1) + actual_hash = all_gather_logs[0].log["hash"] + self.assertEqual(actual_hash, float(local_tensor.numel() * self.world_size)) + def test_pretty_print_dtensor_make_fx(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 745b05d1904d7..0b853997261a9 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -924,7 +924,9 @@ def dispatch_hook(func, types, args, kwargs, result): @staticmethod @contextlib.contextmanager def log_tensor_hashes( - hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False + hash_fn: Union[Callable, str, list[str]] = "norm", + hash_inputs: bool = False, + wait_on_collectives: bool = True, ): """ Installs hook for tensor hash logging. @@ -936,6 +938,7 @@ def log_tensor_hashes( - "hash_tensor": uses torch.hash_tensor (XOR sum reduction) - List of strings: returns tuple of hashes from above options hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash". + wait_on_collectives: if True (default), waits on async collective Work handles before hashing. NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes. """ @@ -966,6 +969,12 @@ def _dispatch_hash_hook(func, types, args, kwargs, result): if "empty" in str(func) or "profiler" in str(func): return None + # Wait on async collective Work handles before hashing + if wait_on_collectives and isinstance(result, (tuple, list)): + for item in result: + if isinstance(item, torch.ScriptObject) and hasattr(item, "wait"): + item.wait() + out = {} out["hash"] = _tree_hash(result) if hash_inputs: From 84a7a34e5fb22c89530791136dc5f3023d2f709b Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 19 Nov 2025 19:34:06 +0000 Subject: [PATCH 061/230] [FlexFlash] Specify lowering w/ new `BACKEND` kernel option (#168017) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): Align w/ naming convention Pull Request resolved: https://github.com/pytorch/pytorch/pull/168017 Approved by: https://github.com/Chillee, https://github.com/Skylion007 --- test/export/test_export.py | 2 +- test/inductor/test_flex_attention.py | 189 +++++++++++++++++- test/inductor/test_flex_flash.py | 26 ++- torch/_inductor/kernel/flex/flex_attention.py | 35 +++- .../kernel/flex/flex_flash_attention.py | 26 ++- torch/nn/attention/flex_attention.py | 41 +++- 6 files changed, 281 insertions(+), 38 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 204d458e77704..8545c210e1b8d 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -968,7 +968,7 @@ def forward(self, x): view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None sdpa_score0 = self.sdpa_score0 sdpa_mask0 = self.sdpa_mask0 - flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None + flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None getitem = flex_attention[0] getitem_1 = flex_attention[1]; getitem_1 = None getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 7a2f9ecdeae8b..84d179e2ca52b 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from itertools import product from typing import Optional, TypeVar, Union -from unittest import expectedFailure, skip, skipUnless +from unittest import expectedFailure, mock, skip, skipUnless from unittest.mock import patch import torch @@ -28,6 +28,7 @@ from torch.nn.attention import SDPBackend from torch.nn.attention.experimental._paged_attention import PagedAttention from torch.nn.attention.flex_attention import ( + _apply_kernel_options, _create_empty_block_mask, _DEFAULT_SPARSE_BLOCK_SIZE, _identity, @@ -3522,6 +3523,184 @@ def test_kernel_options_argument_is_respected(self, device): ) FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0]) + @supported_platform + @skip_on_cpu + def test_backend_auto_matches_triton_large(self, device): + """BACKEND='AUTO' should follow Triton heuristics on large shapes.""" + make_tensor = functools.partial( + torch.randn, + (2, 2, 256, 64), + device=device, + dtype=torch.float16, + requires_grad=False, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + + def compile_and_run(kernel_options): + return run_and_get_code( + torch.compile(flex_attention, fullgraph=True), + q, + k, + v, + kernel_options=kernel_options, + ) + + default_out, default_code = compile_and_run({"BACKEND": "AUTO"}) + triton_out, triton_code = compile_and_run({"BACKEND": "TRITON"}) + + torch.testing.assert_close(default_out, triton_out, atol=0.0, rtol=0.0) + + default_src = "\n".join(default_code) + FileCheck().check("flex_attention").check_not("flex_decoding").run(default_src) + + triton_src = "\n".join(triton_code) + FileCheck().check("flex_attention").check_not("flex_decoding").run(triton_src) + + @supported_platform + @skip_on_cpu + def test_backend_triton_decode_matches_auto(self, device): + """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes.""" + make_tensor = functools.partial( + torch.randn, + (1, 2, 64, 64), + device=device, + dtype=torch.float16, + requires_grad=False, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + + def compile_and_run(kernel_options): + return run_and_get_code( + torch.compile(flex_attention, fullgraph=True), + q, + k, + v, + kernel_options=kernel_options, + ) + + from torch._inductor.kernel.flex import flex_attention as flex_kernel_mod + + with mock.patch.object( + flex_kernel_mod, + "create_flex_decoding_kernel", + wraps=flex_kernel_mod.create_flex_decoding_kernel, + ) as decode_kernel: + default_out, _ = compile_and_run({"BACKEND": "AUTO"}) + self.assertTrue( + decode_kernel.called, + "Expected heuristics to dispatch to flex decoding kernel.", + ) + + with mock.patch.object( + flex_kernel_mod, + "create_flex_decoding_kernel", + wraps=flex_kernel_mod.create_flex_decoding_kernel, + ) as decode_kernel: + decode_out, _ = compile_and_run({"BACKEND": "TRITON_DECODE"}) + self.assertTrue( + decode_kernel.called, + "Expected explicit BACKEND='TRITON_DECODE' to use flex decoding kernel.", + ) + + self.assertEqual(decode_out.shape, (1, 2, 64, 64)) + torch.testing.assert_close(default_out, decode_out, atol=3e-3, rtol=3e-3) + + @supported_platform + @skip_on_cpu + def test_backend_triton_decode_errors_when_not_supported(self, device): + """Requesting decode on unsupported shapes should raise a helpful error.""" + make_tensor = functools.partial( + torch.randn, + (1, 2, 256, 64), + device=device, + dtype=torch.float16, + requires_grad=False, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + + flex_compiled = torch.compile(flex_attention, fullgraph=True) + with self.assertRaisesRegex( + RuntimeError, + r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used", + ): + flex_compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"}) + + @supported_platform + @skip_on_cpu + def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device): + """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two.""" + q = torch.randn( + 1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + k = torch.randn( + 1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + v = torch.randn( + 1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + + flex_compiled = torch.compile(flex_attention, fullgraph=True) + with self.assertRaisesRegex( + RuntimeError, + r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used", + ): + flex_compiled( + q, + k, + v, + enable_gqa=True, + kernel_options={"BACKEND": "TRITON_DECODE"}, + ) + + @supported_platform + @skip_on_cpu + def test_backend_rejects_legacy_force_use_flag(self, device): + """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error.""" + make_tensor = functools.partial( + torch.randn, + (2, 2, 128, 64), + device=device, + dtype=torch.float16, + requires_grad=False, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + + flex_compiled = torch.compile(flex_attention, fullgraph=True) + with self.assertRaisesRegex( + RuntimeError, + r"BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION", + ): + flex_compiled( + q, + k, + v, + kernel_options={ + "BACKEND": "TRITON", + "FORCE_USE_FLEX_ATTENTION": True, + }, + ) + + @supported_platform + def test_backend_defaults_and_rejects_invalid(self, device): + device = torch.device(device) + query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + value = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + + kernel_options = _apply_kernel_options( + query, key, value, return_lse=True, kernel_options={} + ) + self.assertEqual(kernel_options["BACKEND"], "AUTO") + + with self.assertRaisesRegex(ValueError, r"Invalid BACKEND value 'INVALID'"): + _apply_kernel_options( + query, + key, + value, + return_lse=True, + kernel_options={"BACKEND": "INVALID"}, + ) + @supported_platform def test_block_mask_non_divisible(self, device): seq = torch.arange(1023, device=device) // 128 @@ -4154,7 +4333,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_ score_mod_0 = self.score_mod_0 mask_fn_0 = self.mask_fn_0 - flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None + flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None return (out,) @@ -4190,11 +4369,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"): - full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) + full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) fw_graph0 = self.fw_graph0 joint_graph0 = self.joint_graph0 mask_graph0 = self.mask_graph0 - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0] getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1] getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None @@ -4214,7 +4393,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 class mask_graph0(torch.nn.Module): def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"): - full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) + full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) return full_default """.replace( # noqa: B950 "GPU_TYPE", torch.device(device).type diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index 5f3735ac87e0d..40ca12ea9f526 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -139,7 +139,7 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): v, score_mod=score_mod, block_mask=block_mask, - kernel_options={"force_flash": True}, + kernel_options={"BACKEND": "FLASH"}, ) out_triton = compiled_fn( q, @@ -147,7 +147,7 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): v, score_mod=score_mod, block_mask=block_mask, - kernel_options={"force_flash": False}, + kernel_options={"BACKEND": "TRITON"}, ) assert out_flash.shape == out_ref_fp32.shape == out_triton.shape @@ -200,30 +200,28 @@ def test_flash_attention_unfriendly_seqlen_with_causal( @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_kernel_called(self, device, dtype): - """Test that flash attention kernel is actually called when force_flash=True.""" + """Test that flash attention kernel is actually called when BACKEND='FLASH'.""" q, k, v = create_test_tensors(dtype=dtype, device=device) compiled_fn = torch.compile(flex_attention) - # Test that flash kernel is called with force_flash=True + # Test that flash kernel is called with BACKEND='FLASH' with cuda_kernel_profiler("flash_attncute") as prof_result: - compiled_fn( - q, k, v, score_mod=_causal, kernel_options={"force_flash": True} - ) + compiled_fn(q, k, v, score_mod=_causal, kernel_options={"BACKEND": "FLASH"}) self.assertTrue( prof_result["found"], f"Flash attention kernel not found. Available kernels: {prof_result['kernel_names']}", ) - # Test that flash kernel is NOT called with force_flash=False + # Test that flash kernel is NOT called with BACKEND='TRITON' with cuda_kernel_profiler("flash_attncute") as prof_result: compiled_fn( - q, k, v, score_mod=_causal, kernel_options={"force_flash": False} + q, k, v, score_mod=_causal, kernel_options={"BACKEND": "TRITON"} ) self.assertFalse( prof_result["found"], - f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}", + f"Flash attention kernel unexpectedly found when BACKEND='TRITON'. Kernels: {prof_result['kernel_names']}", ) @dtypes(torch.float16, torch.bfloat16) @@ -284,8 +282,8 @@ def score_view_mod(score, b, h, q_idx, kv_idx): flash_vs_triton(q, k, v, score_mod=score_view_mod) @dtypes(torch.float16, torch.bfloat16) - def test_force_flash_error_with_requires_grad(self, device, dtype): - """Test that force_flash=True raises error when tensor requires gradients.""" + def test_flash_impl_error_with_requires_grad(self, device, dtype): + """Test that BACKEND='FLASH' raises error when tensor requires gradients.""" q, k, v = create_test_tensors(dtype=dtype, device=device) bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True) @@ -296,14 +294,14 @@ def score_mod_with_grad(score, b, h, q_idx, kv_idx): compiled_fn = torch.compile(flex_attention) with self.assertRaisesRegex( RuntimeError, - r"force_flash=True but flash attention cannot be used.*require gradients", + r"BACKEND='FLASH' but flash attention cannot be used.*require gradients", ): compiled_fn( q, k, v, score_mod=score_mod_with_grad, - kernel_options={"force_flash": True}, + kernel_options={"BACKEND": "FLASH"}, ) @dtypes(torch.float16, torch.bfloat16) diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 1a72e279aab79..c555f66dbf538 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -7,12 +7,13 @@ import math from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Optional, TYPE_CHECKING, Union import sympy import torch from torch._inductor.virtualized import V +from torch.nn.attention.flex_attention import _Backend from ...ir import ComputedBuffer, ExternKernel, FixedLayout, TensorBox from ...lowering import empty, empty_strided, lowerings, register_lowering @@ -51,6 +52,17 @@ Expr = sympy.Expr +def _sanitize_kernel_options_for_triton( + kernel_options: dict[str, Any], +) -> tuple[dict[str, Any], _Backend]: + """We always strip quotes around str values, we only need this in lowering, so we pop it here + to avoid passing to triton constexpr dict + """ + sanitized = dict(kernel_options) + backend = cast(_Backend, sanitized.pop("BACKEND", "AUTO")) + return sanitized, backend + + @SymbolicGridFn def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): """How is this kernel parallelized? @@ -93,7 +105,7 @@ def flex_attention( subgraph, block_mask, scale, - kernel_options, + kernel_options: dict[str, Any], score_mod_other_buffers, mask_mod_other_buffers, ): @@ -170,7 +182,7 @@ def flex_attention( ) freeze_irnodes(mask_graph_buffer) - kernel_options = dict(kernel_options) + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v @@ -180,7 +192,19 @@ def flex_attention( enable_gqa = V.graph.sizevars.evaluate_expr( sympy.Ne(query.get_size()[1], key.get_size()[1]), ) - if _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa): + + can_use_decode = _use_flex_decoding( + query, kv_indices, value, kernel_options, enable_gqa + ) + use_decode = (backend == "TRITON_DECODE") or (backend == "AUTO" and can_use_decode) + + if backend == "TRITON_DECODE" and not can_use_decode: + raise RuntimeError( + "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used for this input. " + "flex_decoding is only available for short sequence lengths with specific configurations." + ) + + if use_decode: return create_flex_decoding_kernel( query, key, @@ -227,6 +251,7 @@ def flex_attention( mask_graph, kernel_options, num_score_mod_placeholders=len(placeholder_inps), + backend=backend, ): return create_flex_flash_attention_kernel( query, @@ -635,7 +660,7 @@ def flex_attention_backward(*args, **kwargs): f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" ) - kernel_options = dict(kernel_options) + kernel_options, _ = _sanitize_kernel_options_for_triton(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 0d3721aa730a4..05d1040d35c9b 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -5,7 +5,7 @@ import importlib from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Literal, Optional import sympy from sympy import Expr, Integer @@ -171,20 +171,34 @@ def _use_flex_flash_attention( mask_graph: Subgraph, kernel_options: dict[str, Any], num_score_mod_placeholders: int, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], ) -> bool: - """Determine if we should use flex flash attention for the given inputs.""" - force_flash = kernel_options.get("force_flash", False) + """Determine if we should use flex flash attention for the given inputs. + + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False can_use, reason = _can_use_flex_flash_attention( subgraph, mask_graph, num_score_mod_placeholders ) - if force_flash and not can_use: + if not can_use: raise RuntimeError( - f"force_flash=True but flash attention cannot be used: {reason}" + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" ) - return force_flash and can_use + return True def create_flex_flash_attention_kernel( diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index f3746bcea1264..ad922227ccff8 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -7,10 +7,11 @@ import itertools import math import operator +import typing import warnings from collections.abc import Callable from enum import Enum -from typing import Any, NamedTuple +from typing import Any, Literal, NamedTuple, TypeAlias import torch from torch import Tensor @@ -82,6 +83,7 @@ def _warn_once( _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +_Backend: TypeAlias = Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"] # pyrefly: ignore [invalid-inheritance] @@ -219,12 +221,18 @@ class FlexKernelOptions(TypedDict, total=False): """ROCm-specific waves per execution unit.""" # pyrefly: ignore [invalid-annotation] - force_flash: NotRequired[bool] - """ If True, forces use of the cute-dsl flash attention kernel. - - Raises an error if flash attention cannot be used instead of falling back - to the default implementation. Useful for ensuring flash attention is used - when expected. + BACKEND: NotRequired[_Backend] + """Selects a specific kernel backend. + + Options: + - "AUTO": Use current heuristics (typically Triton-based kernels with + automatic selection between flex_attention and flex_decoding) + - "TRITON": Standard Triton flex_attention kernel + - "TRITON_DECODE": Triton flex_decoding kernel, only available for short sequence lengths with specific configurations + - "FLASH": Experimental: Flash Attention kernel (cute-dsl), user needs to have flash installed + + This option cannot be combined with legacy knobs such as ``FORCE_USE_FLEX_ATTENTION``. + Raises an error if the requested backend cannot be used. Default: "AUTO" """ @@ -1242,6 +1250,25 @@ def _apply_kernel_options( ): kernel_options = {} if kernel_options is None else dict(kernel_options) + if "BACKEND" in kernel_options and kernel_options.get( + "FORCE_USE_FLEX_ATTENTION", False + ): + # TODO: remove FORCE_USE_FLEX_ATTENTION once BACKEND is fully adopted. + raise RuntimeError( + "BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION. " + "BACKEND supersedes the legacy knob; please drop FORCE_USE_FLEX_ATTENTION " + "and only specify the desired BACKEND." + ) + + if "BACKEND" in kernel_options: + valid_backends = typing.get_args(_Backend) + if kernel_options["BACKEND"] not in valid_backends: + raise ValueError( + f"Invalid BACKEND value '{kernel_options['BACKEND']}'. " + f"Must be one of {valid_backends}" + ) + + kernel_options.setdefault("BACKEND", "AUTO") kernel_options.setdefault("PRESCALE_QK", False) kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) From 3ecc137077021fe89f0f88399b2e0ce8f7373bbc Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Wed, 19 Nov 2025 22:27:34 +0000 Subject: [PATCH 062/230] [Caffe2] Improve AddMomentsVec and UpdateMomentsVec (#167664) Summary: RowwiseMomentsImpl accounts for about 0.4% cpu time of AdRanker: https://fburl.com/strobelight/ywf79nw3. It primarily calls AddMomentsVec and UpdateMomentsVec. These two routines are written using Pytorch's VecLib, meaning the utilized operators translate into intrinsics. Unfortunately, the compiler makes less transformations and optimizations when intrinsics are used. Therefore, if we carefully decouple and re-order operations, the emitted instruction sequence improves. Here we can see the dissassembly for the old and new AddMomentsVec: https://godbolt.org/z/83fxYvKfv We can see a much better instruction sequence is achieved in the new implementation. Test Plan: AdRanker ServiceLab Differential Revision: D86805648 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167664 Approved by: https://github.com/mcfi --- aten/src/ATen/native/cpu/moments_utils.h | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h index 8aba425e89637..8fa84b4445798 100644 --- a/aten/src/ATen/native/cpu/moments_utils.h +++ b/aten/src/ATen/native/cpu/moments_utils.h @@ -46,8 +46,11 @@ C10_ALWAYS_INLINE void AddMomentsVec( const T c = n == 0 ? static_cast(0) : static_cast(m0_add) / static_cast(n); const Vec c_vec(c); const Vec delta = m1_add - m1; - m1 += c_vec * delta; - m2 += m2_add + delta * delta * c_vec * Vec(static_cast(m0)); + const Vec m2_tmp = m2 + m2_add; + const Vec c_vec_delta = c_vec * delta; + const Vec m0_delta = delta * Vec(static_cast(m0)); + m1 = m1 + c_vec_delta; + m2 = fmadd(m0_delta, c_vec_delta, m2_tmp); m0 = n; } @@ -65,9 +68,11 @@ UpdateMomentsVec( Vec m2_vec(0); for (const auto j : c10::irange(m0)) { const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec tmpVec = c_vecs[j]; const Vec delta_vec = x_vec - m1_vec; - m1_vec += delta_vec * c_vecs[j]; - m2_vec += delta_vec * (x_vec - m1_vec); + m1_vec = fmadd(tmpVec, delta_vec, m1_vec); + const Vec tmpVec2 = x_vec - m1_vec; + m2_vec = fmadd(delta_vec, tmpVec2, m2_vec); } AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); } @@ -89,13 +94,16 @@ UpdateMomentsVec( fVec m2_fvec0(0), m2_fvec1(0); for (const auto j : c10::irange(m0)) { const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size()); + const fVec tmpVec = c_vecs[j]; auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec); const fVec delta_fvec0 = x_fvec0 - m1_fvec0; const fVec delta_fvec1 = x_fvec1 - m1_fvec1; - m1_fvec0 += delta_fvec0 * c_vecs[j]; - m1_fvec1 += delta_fvec1 * c_vecs[j]; - m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); - m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + m1_fvec0 = fmadd(delta_fvec0, tmpVec, m1_fvec0); + m1_fvec1 = fmadd(delta_fvec1, tmpVec, m1_fvec1); + const fVec delta_fvec2 = x_fvec0 - m1_fvec0; + const fVec delta_fvec3 = x_fvec1 - m1_fvec1; + m2_fvec0 = fmadd(delta_fvec0, delta_fvec2, m2_fvec0); + m2_fvec1 = fmadd(delta_fvec1, delta_fvec3, m2_fvec1); } AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); From 9c811b15c28beae35a4429df1b8c85d6086fc947 Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Wed, 19 Nov 2025 22:33:12 +0000 Subject: [PATCH 063/230] [export] Enable context manager returns for dynamo graph capture. (#168102) Summary: This adds the ability to trace through the code which has context manager enabled and returned in a user callable. Since we rely on side effects to return user defined variables, we cannot disable side effects by default anymore in the short term. So we decide to leave side effect config up to the caller side of dynamo_graph_capture_for_export, and still disable it for torch.export by default. In the short term we will just assume dynamo_graph_capture_for_export is a low level API and it's user responsibility to control side effect options. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/168102 Approved by: https://github.com/tugsbayasgalan --- test/export/test_experimental.py | 73 ++++++++++++++++++++++++++++-- torch/_dynamo/convert_frame.py | 12 ++++- torch/_dynamo/functional_export.py | 10 ++-- torch/export/_trace.py | 4 +- 4 files changed, 89 insertions(+), 10 deletions(-) diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 5efbb13c25fd2..abfbb7a6004df 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -5,7 +5,7 @@ import unittest import warnings from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple import torch import torch._dynamo @@ -17,11 +17,46 @@ from torch.export.graph_signature import OutputKind from torch.testing import FileCheck from torch.testing._internal.common_utils import TEST_CUDA +from torch.utils import _pytree as pytree GLOBAL_LIST = [] +class GlobalContext: + def __init__(self) -> None: + self._summaries: dict[str, MetricValue] = {} + self._tensors: dict[str, Tensor] = {} + + def __flatten__(self): + """Flattens into (leaves, ctx).""" + summary_leaves, summary_spec = pytree.tree_flatten(self._summaries) + tensor_leaves, tensor_spec = pytree.tree_flatten(self._tensors) + leaves = (*summary_leaves, *tensor_leaves) + ctx = (summary_spec, tensor_spec) + return leaves, ctx + + @classmethod + def __unflatten__(cls, leaves, ctx: tuple[pytree.TreeSpec, pytree.TreeSpec]): + """Reconstructs from (leaves, ctx).""" + output = cls() + summary_spec, tensor_spec = ctx + assert len(leaves) == summary_spec.num_leaves + tensor_spec.num_leaves + output._summaries = pytree.tree_unflatten( + leaves[: summary_spec.num_leaves], summary_spec + ) + output._tensors = pytree.tree_unflatten( + leaves[summary_spec.num_leaves :], tensor_spec + ) + return output + + def __enter__(self) -> "GlobalContext": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + pass + + @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") class TestExperiment(TestCase): def test_joint_basic(self) -> None: @@ -582,6 +617,33 @@ def make_inputs(b: int): self.assertIsNotNone(gm.meta["tracing_context"].fake_mode) self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1) + def test_dynamo_graph_capture_ctx_return(self): + class Module(torch.nn.Module): + def forward(self, x): + with GlobalContext() as ctx: + z = x + 1 + ctx._tensors["6"] = x + 2 + return z, ctx + + def make_inputs(): + return (torch.randn(2, 3),) + + try: + pytree.register_pytree_node( + GlobalContext, + lambda x: x.__flatten__(), + GlobalContext.__unflatten__, + ) + mod = Module() + + gm = dynamo_graph_capture_for_export(mod)(*make_inputs()) + test_inputs = make_inputs() + actual_outputs = pytree.tree_leaves(gm(*test_inputs)) + expected_outputs = pytree.tree_leaves(mod(*test_inputs)) + self.assertEqual(actual_outputs, expected_outputs) + finally: + pytree._deregister_pytree_node(GlobalContext) + def test_dynamo_graph_capture_dict_keys_getitem(self): class Module(torch.nn.Module): def forward(self, x): @@ -614,9 +676,9 @@ def forward(self, args_0): _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,)) L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1) l_args_0_ = L_args_0_ - add = l_args_0_ + 1; add = None + add = l_args_0_ + 1 mul = l_args_0_ * 2; l_args_0_ = None - return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""", + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""", ) self.assertEqual(gm(*test_inputs), foo(*test_inputs)) @@ -652,7 +714,10 @@ def make_inputs(): return (torch.randn(2, 3),) trace_inputs = make_inputs() - with warnings.catch_warnings(record=True) as w: + with ( + torch._dynamo.config.patch(replay_side_effects=False), + warnings.catch_warnings(record=True) as w, + ): gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) cnt = 0 for entry in w: diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 1c44de2c1ad1e..58767245fa9a4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -933,7 +933,11 @@ class GraphRuntimeEnv: argdefs: Optional[tuple[Any, ...]] def forward_callable( - self, backend_id: str, compiled_fn: Callable[..., Any] + self, + backend_id: str, + compiled_fn: Callable[..., Any], + *, + extra_globals: Optional[dict[str, Any]] = None, ) -> Callable[..., Any]: import_sources = { alias: importlib.import_module(module_name) @@ -942,6 +946,7 @@ def forward_callable( f_globals = { **import_sources, **self.used_globals, + **(extra_globals or {}), backend_id: compiled_fn, } return types.FunctionType( @@ -1026,13 +1031,16 @@ def forward_callable( self, *, compiled_fn: Optional[Callable[..., Any]] = None, + extra_globals: Optional[dict[str, Any]] = None, ) -> Callable[..., Any]: runtime_env = self.graph_capture_output.get_runtime_env() assert self.backend_input is not None backend_id = self.backend_input.backend_id # pyrefly: ignore [not-callable] compiled_fn = compiled_fn or self.backend_input.graph_module - return runtime_env.forward_callable(backend_id, compiled_fn) + return runtime_env.forward_callable( + backend_id, compiled_fn, extra_globals=extra_globals + ) def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]: diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 84641d66c6bd8..548a4b279b860 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -501,6 +501,7 @@ def pytreeify( torch._dynamo.eval_frame.check_user_input_output( flat_real_args[1 if root else 0 :], UserErrorType.INVALID_INPUT ) + f_globals = out.graph_capture_output.f_globals class Yield(Exception): pass @@ -522,7 +523,9 @@ def backend_dummy(*example_inputs): raise Yield try: - out.forward_callable(compiled_fn=backend_dummy)(*args, **kwargs) + out.forward_callable( + compiled_fn=backend_dummy, extra_globals=f_globals + )(*args, **kwargs) except Yield: assert self.gm_inputs is not None return self.gm_inputs @@ -557,7 +560,9 @@ def backend_dummy(*example_inputs): for i in range(self.num_outputs) ] - results = out.forward_callable(compiled_fn=backend_dummy)(*args, **kwargs) + results = out.forward_callable( + compiled_fn=backend_dummy, extra_globals=f_globals + )(*args, **kwargs) ret, self.out_spec = pytree.tree_flatten(results) return ret @@ -606,7 +611,6 @@ def dynamo_graph_capture_for_export( def inner(*args: Any, **kwargs: Any) -> Any: assert not torch._dynamo.config.install_free_tensors with ( - torch._dynamo.config.patch(replay_side_effects=False), torch._dynamo.config.patch(side_effect_replay_policy="warn"), get_metrics_context(), dynamo_timed("fullgraph_capture"), diff --git a/torch/export/_trace.py b/torch/export/_trace.py index b38986ab070f7..856f23f68b19e 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -851,7 +851,9 @@ def use_legacy_dynamo_graph_capture() -> bool: f, constraints=constraints, dynamic_shapes=dynamic_shapes ) else: - dynamo_graph_capture = dynamo_graph_capture_for_export(f) + dynamo_graph_capture = torch._dynamo.config.patch( + replay_side_effects=False + )(dynamo_graph_capture_for_export(f)) # We can't serialize entire fake mode yet, so this is to make sure # things like copy.deepcopy(ep.graph_module) not crash. # see test_export.py::test_custom_tag_metadata_re_export From fcc78410a8e51107a7f4a15431e57da137741aee Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 19 Nov 2025 09:44:14 -0800 Subject: [PATCH 064/230] [pytree][compile] Slightly faster TreeSpec init (#168024) Helps with reducing Dynamo tracing time. Earlier the generator object would cause more polyfills. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168024 Approved by: https://github.com/williamwen42 --- torch/_dynamo/polyfills/pytree.py | 7 +++++-- torch/utils/_pytree.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 1c6283e8a038f..63a72afa43a6d 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -201,8 +201,11 @@ def __post_init__(self, /) -> None: num_children = 0 else: assert callable(self._unflatten_func) - num_nodes = sum((spec.num_nodes for spec in self._children), start=1) - num_leaves = sum(spec.num_leaves for spec in self._children) + num_nodes = 1 + num_leaves = 0 + for child in self._children: + num_nodes += child.num_nodes + num_leaves += child.num_leaves num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 3d2e4d110b6b2..16877719718af 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1113,8 +1113,11 @@ def __post_init__(self) -> None: num_leaves = 1 num_children = 0 else: - num_nodes = sum((spec.num_nodes for spec in self._children), start=1) - num_leaves = sum(spec.num_leaves for spec in self._children) + num_nodes = 1 + num_leaves = 0 + for child in self._children: + num_nodes += child.num_nodes + num_leaves += child.num_leaves num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) From 159aa44a243a734e1be38ad821d9446254096a80 Mon Sep 17 00:00:00 2001 From: Johnson Wong Date: Wed, 19 Nov 2025 23:22:03 +0000 Subject: [PATCH 065/230] Replace 2**31 with explicit int (#168046) Summary: Prior to #[164333](https://github.com/pytorch/pytorch/pull/164333), the 32-bit activation range was defined as `(int(-(2**31)), int(2**31 - 1))`. The `int` was deemed unnecessary, however torch.jit.script inteprets 2**31 as a float (example error P2044074770). Instead of reverting to the old definition (introduced by our team in #[150870](https://github.com/pytorch/pytorch/pull/150870), which could be "fixed" again), I replace with the value directly. Test Plan: N8628317 demonstrates the error without this diff. No error on this diff. Differential Revision: D87278420 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168046 Approved by: https://github.com/cyyever, https://github.com/yangw-dev --- torch/ao/nn/quantized/reference/modules/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index 8311b0e5697b0..7bdbcd4a6739e 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -202,7 +202,7 @@ def _quantize_weight_decomposed( _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { torch.uint8: (0, 255), torch.int8: (-128, 127), - torch.int32: ((-(2**31)), (2**31 - 1)), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float } # TODO: add an util function for converting qdtype to dtype @@ -265,7 +265,7 @@ def _dequantize_weight_decomposed( _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { torch.uint8: (0, 255), torch.int8: (-128, 127), - torch.int32: ((-(2**31)), (2**31 - 1)), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float } # TODO: add an util function for converting qdtype to dtype _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { From 7bfe8b0b942610e647277721d75ca666f9402b27 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 19 Nov 2025 23:31:08 +0000 Subject: [PATCH 066/230] [codemod][lowrisk] Remove unused exception parameter from caffe2/torch/csrc/Storage.cpp (#168184) Summary: `-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it. This: ``` try { ... } catch (exception& e) { // no use of e } ``` should instead be written as ``` } catch (exception&) { ``` If the code compiles, this is safe to land. Test Plan: Sandcastle Differential Revision: D87467930 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168184 Approved by: https://github.com/malfet --- torch/csrc/Storage.cpp | 2 +- torch/csrc/distributed/autograd/engine/dist_engine.cpp | 6 +++--- torch/csrc/distributed/c10d/PyProcessGroup.hpp | 2 +- torch/csrc/distributed/c10d/TCPStore.cpp | 2 +- torch/csrc/distributed/c10d/python_callback_work.cpp | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 671c28adef3e3..33dfa3132cb45 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -245,7 +245,7 @@ static PyObject* THPStorage_pynew( storage_set(storage, i, value); } } - } catch (const std::exception& e) { + } catch (const std::exception&) { TORCH_CHECK( THPStorageStr "(): tried to construct a storage from a sequence (", THPUtils_typename(sequence), diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 156c9efd5ca98..2104e3030d445 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -448,7 +448,7 @@ c10::intrusive_ptr DistEngine:: const variable_list& grads = futureGrads.constValue().toTensorVector(); TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); accumulateGradFuture->markCompleted(c10::IValue()); - } catch (std::exception& e) { + } catch (std::exception&) { accumulateGradFuture->setErrorIfNeeded(std::current_exception()); } }); @@ -527,7 +527,7 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( // Perform cleanup at the end of the backward pass (before // we mark the future as completed). DistEngine::getInstance().cleanupBackwardPass(autogradContext); - } catch (std::exception& e) { + } catch (std::exception&) { callbackFuture->setErrorIfNeeded(std::current_exception()); return; } @@ -539,7 +539,7 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( callbackFuture->setError(rpcFuture.exception_ptr()); } }); - } catch (std::exception& e) { + } catch (std::exception&) { callbackFuture->setErrorIfNeeded(std::current_exception()); } }); diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index afec6bbe11a9a..39474c49052fe 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -339,7 +339,7 @@ class TORCH_PYTHON_API PythonOnCompletionHook { eptr = std::make_exception_ptr(std::runtime_error(e.what())); e.restore(); PyErr_Clear(); - } catch (std::exception& e) { + } catch (std::exception&) { eptr = std::current_exception(); } } diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 0c1cf581887d1..9f566032b5b3c 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -270,7 +270,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) // server successfully started C10D_DEBUG("The server has started on port = {}.", server_->port()); addr_.port = server_->port(); - } catch (const SocketError& e) { + } catch (const SocketError&) { bool useAgentStore = getCvarBool({"TORCHELASTIC_USE_AGENT_STORE"}, false); int masterPort = getCvarInt({"MASTER_PORT"}, 0); if (useAgentStore && masterPort == opts.port) { diff --git a/torch/csrc/distributed/c10d/python_callback_work.cpp b/torch/csrc/distributed/c10d/python_callback_work.cpp index 47bef1831a480..685b3cceeaa4c 100644 --- a/torch/csrc/distributed/c10d/python_callback_work.cpp +++ b/torch/csrc/distributed/c10d/python_callback_work.cpp @@ -40,14 +40,14 @@ bool PythonCallbackWork::wait(std::chrono::milliseconds timeout) { } return success; - } catch (py::error_already_set& e) { + } catch (py::error_already_set&) { // Capture the Python exception and store it finish(std::current_exception()); if (!future_->completed()) { future_->setErrorIfNeeded(std::current_exception()); } throw; - } catch (const std::exception& e) { + } catch (const std::exception&) { // Capture any C++ exception and store it finish(std::current_exception()); if (!future_->completed()) { From a5e9dce4f7509f2538e4487b15ec78596e4a8518 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 19 Nov 2025 13:42:25 -0800 Subject: [PATCH 067/230] [DTensor] Fix mypy on register_op_strategy (#167673) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167673 Approved by: https://github.com/zpcore, https://github.com/Skylion007 --- torch/distributed/tensor/_ops/utils.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 9a4ce12ed82fa..14a08e37c698f 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -4,8 +4,7 @@ import itertools import operator from collections.abc import Callable, Iterable, Sequence -from typing import cast, Optional, TypeVar, Union -from typing_extensions import ParamSpec +from typing import cast, Optional, TypeAlias, TypeVar, Union import torch from torch._prims_common import DimsSequenceType, DimsType @@ -30,10 +29,6 @@ ) -_T = TypeVar("_T") -_P = ParamSpec("_P") - - # convenient wrapper to register sharding propagation rules def register_prop_rule( op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], @@ -54,11 +49,20 @@ def wrapper( return wrapper -def register_op_strategy( - op, schema_info=None -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - # pyre-fixme[2]: Parameter must be annotated. +# Note: +# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy, +# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat +# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like +# MyStrategyType(StrategyType). +_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema) +_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType) +_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT] + +def register_op_strategy( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]: # For every ATen op that accepts any args in this list, # the arg itself can impact the strides (and potentially the sharding strategy) # of the output tensor. @@ -68,7 +72,7 @@ def register_op_strategy( "memory_format", ] - def wrapper(impl): + def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc: if isinstance(op, list): overloads = op else: From c9d944b614c8623ad704911938f1de124369abce Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 19 Nov 2025 13:42:26 -0800 Subject: [PATCH 068/230] [DTensor] Document some utils (#168113) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168113 Approved by: https://github.com/mlazos, https://github.com/zpcore ghstack dependencies: #167673 --- torch/distributed/tensor/_ops/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 14a08e37c698f..a19ce091e3748 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -163,7 +163,10 @@ def prod(xs: Iterable[int]) -> int: def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: - """Check if the shape is shardable according to the spec.""" + """Check if the spec matches these criteria: + * any Shard placements in spec refer to valid tensor dims + * no empty local tensors (uneven sharding OK, as long as last rank has >0 size) + """ # number of shards in each tensor dimension shards_map = [1] * len(shape) for i, placement in enumerate(spec.placements): @@ -229,6 +232,9 @@ def infer_broadcast_dims_map( ) -> list[int]: # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim # this is aligned with the broadcast semantics + # e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4], + # broadcast_dims_map will be [-1, 0, 1, 2] + # meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input common_ndim = len(common_shape) input_ndim = len(input_shape) broadcast_dims_map = [-1] * common_ndim From a4a5d03779d876043b0a1f0c565659fc2298afd2 Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Wed, 19 Nov 2025 20:02:33 +0000 Subject: [PATCH 069/230] Update linalg.norm to match numpy's handling of degenerate inputs (#168086) See https://github.com/numpy/numpy/pull/28343 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168086 Approved by: https://github.com/williamwen42 --- aten/src/ATen/native/LinearAlgebra.cpp | 19 ++++- test/test_linalg.py | 83 +++++++++++++++++-- torch/_refs/linalg/__init__.py | 16 +++- .../_internal/opinfo/definitions/linalg.py | 15 +++- 4 files changed, 119 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 07bdc19ec8ff7..169f340e955d6 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -2909,13 +2909,26 @@ Tensor linalg_matrix_norm( // Check A, dim, and dtype _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/abs_ord != 2.); - auto max_min = [ord, keepdim](const Tensor& A, int64_t dim) { return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); }; + auto max_min_wrapper = [ord, keepdim](const Tensor &A, int64_t dim) { + if (A.size(dim) == 0 && ord > 0) { + auto new_shape(DimVector(A.sizes())); + auto dim_ = maybe_wrap_dim(dim, A.dim()); + if (keepdim) { + new_shape[dim_] = 1; + } else { + new_shape.erase(std::begin(new_shape) + dim_); + } + return at::zeros(new_shape, A.options()); + } else { + return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); + } + }; if (abs_ord == 2.) { // Move dims to the end auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A.dim()); auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A; - auto result = max_min(at::linalg_svdvals(A_.permute(permutation)), -1); + auto result = max_min_wrapper(at::linalg_svdvals(A_.permute(permutation)), -1); if (keepdim) { auto permutation_reverse = create_reverse_permutation(std::move(permutation)); result = result.unsqueeze(-1).permute(permutation_reverse); @@ -2932,7 +2945,7 @@ Tensor linalg_matrix_norm( if (!keepdim && (dim_[0] < dim_[1])) { dim_[1]--; } - return max_min(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]); + return max_min_wrapper(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]); } } diff --git a/test/test_linalg.py b/test/test_linalg.py index 7e3a1ebaa6f3a..ed3ca079748fd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -27,7 +27,7 @@ runOnRocmArch, MI300_ARCH, NAVI_ARCH, TEST_CUDA) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver, - onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + onlyCPU, skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA, onlyCUDA, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm, dtypesIfMPS, largeTensorTest) from torch.testing import make_tensor @@ -2015,6 +2015,7 @@ def run_test_case(input, ord, dim, keepdim): run_test_case(input, ord, dim, keepdim) # Test degenerate shape results match numpy for linalg.norm matrix norms + @skipIf(np.lib.NumpyVersion(np.__version__) < '2.3.0', 'Numpy changed handling of degenerate inputs in 2.3.0') @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) @@ -2043,13 +2044,13 @@ def run_test_case(input, ord, dim, keepdim, should_error): S = 10 test_cases = [ # input size, p settings that cause error, dim - ((0, 0), [1, 2, inf, -1, -2, -inf], None), - ((0, S), [2, inf, -2, -inf], None), - ((S, 0), [1, 2, -1, -2], None), + ((0, 0), [-1, -2, -inf], None), + ((0, S), [-2, -inf], None), + ((S, 0), [-1, -2], None), ((S, S, 0), [], (0, 1)), ((1, S, 0), [], (0, 1)), - ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), - ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), + ((0, 0, S), [-1, -2, -inf], (0, 1)), + ((0, 0, S), [-1, -2, -inf], (1, 0)), ] for keepdim in [True, False]: @@ -2058,6 +2059,76 @@ def run_test_case(input, ord, dim, keepdim, should_error): for ord in ord_matrix: run_test_case(input, ord, dim, keepdim, ord in error_ords) + # TODO this is redundant with test_norm_matrix_degenerate_shapes above, + # remove when old numpy versions are dropped + @skipIf(np.lib.NumpyVersion(np.__version__) >= '2.3.0', 'Numpy changed handling of degenerate inputs in 2.3.0') + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_matrix_degenerate_shapes_old_numpy(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + ops = [torch.linalg.norm] + + if ord is not None and dim is not None: + ops.append(torch.linalg.matrix_norm) + + if should_error == 'both': + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + with self.assertRaises(IndexError): + op(input, ord, dim, keepdim) + elif should_error == 'np_only': + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + result = op(input, ord, dim, keepdim) + dim_ = dim + if dim_ is None: + dim_ = (0, 1) + expected_shape = list(input.shape) + if keepdim: + expected_shape[dim_[0]] = 1 + expected_shape[dim_[1]] = 1 + else: + del expected_shape[max(dim_)] + del expected_shape[min(dim_)] + expected = torch.zeros(expected_shape, dtype=dtype.to_real()) + self.assertEqual(expected, result, msg=msg) + else: + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + result = op(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, + # p settings that error numpy but not torch, dim + ((0, 0), [-1, -2, -inf], [inf, 1, 2], None), + ((0, S), [-2, -inf], [inf, 2], None), + ((S, 0), [-1, -2], [1, 2], None), + ((S, S, 0), [], [], (0, 1)), + ((1, S, 0), [], [], (0, 1)), + ((0, 0, S), [-1, -2, -inf], [inf, 1, 2], (0, 1)), + ((0, 0, S), [-1, -2, -inf], [inf, 1, 2], (1, 0)), + ] + + for keepdim in [True, False]: + for input_size, error_ords, np_error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_matrix: + if ord in error_ords: + should_error = 'both' + elif ord in np_error_ords: + should_error = 'np_only' + else: + should_error = 'no' + run_test_case(input, ord, dim, keepdim, should_error) + def test_norm_fastpaths(self, device): x = torch.randn(3, 5, device=device) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index f4281674bd118..4d194f773f859 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -269,12 +269,24 @@ def matrix_norm( max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + def _max_min_wrapper(A, dim): + # pyrefly: ignore [unsupported-operation] + if A.size(dim) == 0 and ord > 0.0: + new_size = list(A.size()) + if keepdim: + new_size[dim] = 1 + else: + del new_size[dim] + return torch.zeros(new_size, dtype=A.dtype, device=A.device) + else: + return max_min(A, dim) + if abs_ord == 2.0: if dtype is not None: A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] # pyrefly: ignore [index-error] perm = _backshift_permutation(dim[0], dim[1], A.ndim) - result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + result = _max_min_wrapper(svdvals(prims.transpose(A, perm)), dim=-1) if keepdim: inv_perm = _inverse_permutation(perm) result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) @@ -286,7 +298,7 @@ def matrix_norm( dim0, dim1 = dim1, dim0 if not keepdim and (dim0 < dim1): dim1 -= 1 - return max_min( + return _max_min_wrapper( vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 ) diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index ae5a468ddd6ae..da75f82815507 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -382,9 +382,6 @@ def sample_inputs_linalg_norm( elif is_matrix_norm: dims_to_check = { None: (0,), - np.inf: (0,), - 2: (0, 1), - 1: (1,), -1: (1,), -2: (0, 1), -np.inf: (0,), @@ -395,6 +392,18 @@ def sample_inputs_linalg_norm( # have non-zero size. continue + no_grad_dims_to_check = { + np.inf: (0,), + 2: (0, 1), + 1: (1,), + }.get(ord, ()) + + if ( + any(test_size[d] == 0 for d in no_grad_dims_to_check) + and requires_grad + ): + continue + if variant == "subgradient_at_zero": yield SampleInput( torch.zeros( From 90c57aa3b3adc7f52c73090db6e2e7e8caebb762 Mon Sep 17 00:00:00 2001 From: Ruben Rodriguez Buchillon Date: Tue, 18 Nov 2025 14:40:45 -0800 Subject: [PATCH 070/230] conv: refactor for lookup table support (#167179) \# why enable configuring conv operations through the lookup table \# what - move kwargs etc into template_heuristics - add conv specific kernel inputs - add lookup table e2e test for conv \# testing ``` python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v ``` Differential Revision: [D86474839](https://our.internmc.facebook.com/intern/diff/D86474839) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167179 Approved by: https://github.com/drisspg --- test/inductor/test_lookup_table.py | 156 +++++++++- torch/_inductor/kernel/conv.py | 130 ++++---- torch/_inductor/kernel_inputs.py | 125 +++++++- .../_inductor/template_heuristics/__init__.py | 2 +- torch/_inductor/template_heuristics/conv.py | 287 ++++++++++++++++++ 5 files changed, 611 insertions(+), 89 deletions(-) create mode 100644 torch/_inductor/template_heuristics/conv.py diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py index 250a822267833..32be3e730a6fb 100644 --- a/test/inductor/test_lookup_table.py +++ b/test/inductor/test_lookup_table.py @@ -2,14 +2,18 @@ import re import unittest from functools import partial -from typing import Any, Optional, Union +from typing import Any, Optional from unittest.mock import patch import torch import torch.nn as nn from torch._inductor import config as inductor_config from torch._inductor.choices import InductorChoices -from torch._inductor.kernel_inputs import MMKernelInputs +from torch._inductor.kernel_inputs import ( + ConvKernelInputs, + MMKernelInputs, + SerializableValue, +) from torch._inductor.lookup_table.choices import LookupTableChoices from torch._inductor.select_algorithm import ( add_preprocessing_fn, @@ -54,7 +58,7 @@ class MockMMKernelInputs(MMKernelInputs): def __init__( self, tensors: list[torch.Tensor], - scalars: Optional[dict[str, Union[float, int]]] = None, + scalars: Optional[dict[str, SerializableValue]] = None, mat1_idx: int = -2, mat2_idx: int = -1, ): @@ -80,6 +84,37 @@ def device_type(self) -> Optional[str]: return self.tensors[0].device.type +class MockConvKernelInputs(ConvKernelInputs): + """Mock ConvKernelInputs that subclasses the real class and uses real tensors""" + + def __init__( + self, + tensors: list[torch.Tensor], + scalars: Optional[dict[str, SerializableValue]] = None, + x_idx: int = 0, + weight_idx: int = 1, + bias_idx: Optional[int] = None, + ): + """Initialize with real tensors, creating mock nodes for the base class""" + mock_nodes = [MockTensorNode(t) for t in tensors] + super().__init__( + mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx + ) + self.tensors = tensors # Keep reference to original tensors + + def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int shapes""" + return self.shapes_symbolic() + + def strides_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int strides""" + return self.strides_symbolic() # pyre-ignore + + @property + def device_type(self) -> Optional[str]: + return self.tensors[0].device.type + + class BaseLookupTableTest(TestCase): """Base class for lookup table tests with common setup and utilities""" @@ -103,7 +138,7 @@ def create_mock_mm_kernel_inputs( shapes: Optional[list[tuple[int, ...]]] = None, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.float32, - scalars: Optional[dict[str, Union[float, int]]] = None, + scalars: Optional[dict[str, SerializableValue]] = None, ) -> MockMMKernelInputs: """Create MockMMKernelInputs with real tensors""" if shapes is None: @@ -1055,6 +1090,119 @@ def test_template_hash_filtering_e2e(self): with patch.object(inductor_config.lookup_table, "check_src_hash", True): self.run_model("mm", tensors) + @fresh_cache() + def test_conv2d_lookup_table_entry_e2e(self): + """Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results""" + import torch._inductor.kernel.conv + + # Create input tensors with specific shapes for conv2d + # Input: [batch=2, in_channels=3, height=32, width=32] + # Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3] + # Make them channels-last to match what conv lowering uses + x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to( + memory_format=torch.channels_last + ) + weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to( + memory_format=torch.channels_last + ) + + # Define conv parameters - use these SAME values everywhere + stride = (1, 1) + padding = (1, 1) + dilation = (1, 1) + groups = 1 + + # Create MockConvKernelInputs using the SAME tensors and SAME scalar values + mock_scalars = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": False, + "output_padding": (0, 0), + "groups": groups, + } + mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars) + + # Create lookup key for "convolution" operation + choices_handler = LookupTableChoices() + lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution") + + # Get the exact template UID from conv2d_template + template_uid = torch._inductor.kernel.conv.conv2d_template.uid + + # Create a precisely configured conv2d config + # IMPORTANT: Only include per-config tunable parameters! + # Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are + # automatically generated by get_extra_kwargs() and should NOT be in the lookup table + conv2d_config = { + "template_id": template_uid, + # Per-config tunable parameters only (what you'd tune via autotuning) + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 4, + } + + # Setup lookup table + inductor_config.lookup_table.table = {lookup_key: [conv2d_config]} + + def validate_conv_choice(choices): + assert len(choices) == 1, ( + f"Expected 1 choice from lookup table, got {len(choices)}" + ) + assert isinstance(choices[0], TritonTemplateCaller), ( + f"Expected TritonTemplateCaller, got {type(choices[0])}" + ) + assert "convolution2d" in choices[0].name, ( + f"Expected 'convolution2d' in name, got {choices[0].name}" + ) + return choices + + add_preprocessing_fn(validate_conv_choice) + + # Create and compile the model using the SAME weight tensor + class SimpleConv2d(nn.Module): + def __init__(self, weight): + super().__init__() + self.register_buffer("weight", weight) + + def forward(self, x): + return torch.conv2d( + x, + self.weight, + bias=None, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + model = SimpleConv2d(weight).to(self.device) + + with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}): + compiled_model = torch.compile(model) + result = compiled_model(x) # Use the SAME x tensor + + # Output shape: [batch=2, out_channels=64, out_h=32, out_w=32] + # (same spatial dims due to padding=1, stride=1, kernel=3) + expected_shape = (2, 64, 32, 32) + self.assertEqual( + result.shape, + expected_shape, + f"Expected shape {expected_shape}, got {result.shape}", + ) + + self.assertFalse( + torch.isnan(result).any().item(), + "Output contains NaN values", + ) + + self.assertFalse( + torch.isinf(result).any().item(), + "Output contains Inf values", + ) + if __name__ == "__main__": from torch._inductor.utils import is_big_gpu diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 8e5a2aa09d4ea..2179364c7d0c2 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -8,6 +8,7 @@ from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate from .. import config, ir +from ..kernel_inputs import ConvKernelInputs from ..lowering import ( add_layout_constraint, constrain_to_fx_strides, @@ -16,7 +17,9 @@ ) from ..select_algorithm import ( autotune_select_algorithm, + ChoiceCaller, ExternKernelChoice, + KernelTemplate, SymbolicGridFn, TritonTemplate, ) @@ -542,34 +545,40 @@ def channels_last_conv(): x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment] weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment] - ordered_kwargs_for_cpp_kernel = [ - "stride", - "padding", - "dilation", - "transposed", - "output_padding", - "groups", - ] - if bias is None: - args = [x, weight] - kwargs["bias"] = None # type: ignore[typeddict-unknown-key] - ordered_kwargs_for_cpp_kernel.insert(0, "bias") - else: - args = [x, weight, bias] + # Create ConvKernelInputs for unified template configuration + # Only include bias in input_nodes when it's not None + # - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight] + # - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present + if bias is not None: bias.realize() bias.freeze_layout() V.graph.sizevars.guard_int_seq(bias.get_size()) + input_nodes = [x, weight, bias] + bias_idx = 2 + else: + input_nodes = [x, weight] + bias_idx = None + + kernel_inputs = ConvKernelInputs( + input_nodes, + scalars={ + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + }, + x_idx=0, + weight_idx=1, + bias_idx=bias_idx, + ) + + # Build list of templates to try + templates: list[ExternKernelChoice | KernelTemplate] = [] - choices = [] if torch._inductor.utils._use_conv_autotune_backend("ATEN"): - choices = [ - aten_convolution.bind( - args, - layout, - ordered_kwargs_for_cpp_kernel, - **kwargs, - ) - ] + templates.append(aten_convolution) if ( torch._inductor.utils._use_conv_autotune_backend("TRITON") @@ -587,60 +596,23 @@ def channels_last_conv(): and is_zeros(padding) and groups == 1 ): - choices.append(aten_conv1x1_via_mm.bind(args, layout)) - - conv_configs = V.choices.get_conv_configs(device_type) - - dtype_size = x.get_dtype().itemsize - for cfg in conv_configs( - sympy_product([x.get_size()[0], *x.get_size()[2:]]), - out_chan, - in_chan, - dtype_size=dtype_size, - ): - if ndim == 2: - conv2d_template.maybe_append_choice( - choices, - input_nodes=(x, weight), - layout=layout, - KERNEL_H=kernel_shape[0], - KERNEL_W=kernel_shape[1], - STRIDE_H=stride[0], - STRIDE_W=stride[1], - PADDING_H=padding[0], - PADDING_W=padding[1], - GROUPS=groups, - # TODO(jansel): try unroll for bigger kernels once fixed: - # https://github.com/triton-lang/triton/issues/1254 - UNROLL=is_ones(kernel_shape), - ALLOW_TF32=torch.backends.cudnn.allow_tf32, - num_stages=cfg.num_stages, - num_warps=cfg.num_warps, - **cfg.kwargs, - ) - elif ndim == 3: - conv3d_template.maybe_append_choice( - choices, - input_nodes=(x, weight), - layout=layout, - KERNEL_D=kernel_shape[0], - KERNEL_H=kernel_shape[1], - KERNEL_W=kernel_shape[2], - STRIDE_D=stride[0], - STRIDE_H=stride[1], - STRIDE_W=stride[2], - PADDING_D=padding[0], - PADDING_H=padding[1], - PADDING_W=padding[2], - GROUPS=groups, - # TODO(jansel): try unroll for bigger kernels once fixed: - # https://github.com/triton-lang/triton/issues/1254 - UNROLL=is_ones(kernel_shape), - ALLOW_TF32=torch.backends.cudnn.allow_tf32, - num_stages=cfg.num_stages, - num_warps=cfg.num_warps, - **cfg.kwargs, - ) + templates.append(aten_conv1x1_via_mm) + + # Add appropriate template based on ndim + if ndim == 2: + templates.append(conv2d_template) + elif ndim == 3: + templates.append(conv3d_template) + + # Initialize choices list and extend with template configs + choices: list[ChoiceCaller] = [] + choices.extend( + V.choices.get_template_configs( + kernel_inputs, + templates, + "convolution", + ) + ) if use_ck_conv_template(layout): CKGroupedConvFwdTemplate.add_ck_conv_choices( choices, @@ -652,7 +624,9 @@ def channels_last_conv(): groups=groups, n_spatial_dimensions=ndim, ) - return autotune_select_algorithm("convolution", choices, args, layout) + return autotune_select_algorithm( + "convolution", choices, kernel_inputs.nodes(), layout + ) @register_lowering(aten._convolution) diff --git a/torch/_inductor/kernel_inputs.py b/torch/_inductor/kernel_inputs.py index c579cf7565772..9e585a4880106 100644 --- a/torch/_inductor/kernel_inputs.py +++ b/torch/_inductor/kernel_inputs.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union import torch @@ -12,10 +13,12 @@ if TYPE_CHECKING: - from collections.abc import Sequence - import sympy +# Type aliases for serializable scalar values +Serializable = Union[int, float, bool] +SerializableValue = Union[Serializable, Sequence[Serializable]] + class KernelInputs(ABC): """ @@ -27,7 +30,7 @@ class KernelInputs(ABC): def __init__( self, input_nodes: list[Any], - scalars: Optional[dict[str, Union[float, int]]] = None, + scalars: Optional[dict[str, SerializableValue]] = None, out_dtype: Optional[torch.dtype] = None, ): """ @@ -183,7 +186,7 @@ def out_dtype(self) -> torch.dtype: The output dtype """ - def get_scalar(self, name: str) -> Union[float, int]: + def get_scalar(self, name: str) -> SerializableValue: """ Get the scalar value for a given name. @@ -191,7 +194,7 @@ def get_scalar(self, name: str) -> Union[float, int]: name: Name of the scalar to get Returns: - The scalar value + The scalar value (can be int, float, bool, or tuple of these types) """ assert name in self._scalars, f"Scalar {name} not found, but required" return self._scalars[name] @@ -216,7 +219,7 @@ class MMKernelInputs(KernelInputs): def __init__( self, input_nodes: list[Any], - scalars: Optional[dict[str, Union[float, int]]] = None, + scalars: Optional[dict[str, SerializableValue]] = None, out_dtype: Optional[torch.dtype] = None, mat1_idx: int = -2, mat2_idx: int = -1, @@ -336,3 +339,113 @@ def mnk_hinted(self) -> tuple[int, int, int]: assert k == k_check, f"K dimensions don't match: {k} vs {k_check}" return (m, n, k) + + +class ConvKernelInputs(KernelInputs): + """ + Specialized KernelInputs for convolution operations. + Stores input tensor, weight tensor, and optional bias, along with conv parameters. + """ + + def __init__( + self, + input_nodes: list[Any], + scalars: Optional[dict[str, SerializableValue]] = None, + out_dtype: Optional[torch.dtype] = None, + x_idx: int = 0, + weight_idx: int = 1, + bias_idx: Optional[int] = None, + ): + """ + Initialize with convolution input nodes. + + Args: + input_nodes: List containing [x, weight] or [x, weight, bias] + scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding) + out_dtype: Optional output dtype + x_idx: Index of input tensor (default: 0) + weight_idx: Index of weight tensor (default: 1) + bias_idx: Index of bias tensor if present (default: None) + """ + super().__init__(input_nodes, scalars, out_dtype) + assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)" + + self._x_idx = x_idx + self._weight_idx = weight_idx + self._bias_idx = bias_idx + + # Validate that required scalars are present + required_scalars = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + for key in required_scalars: + assert key in self._scalars, f"Conv requires scalar '{key}'" + + def out_dtype(self) -> torch.dtype: + """ + Get the output dtype, whether passed in or inferred from the nodes + + Returns: + The output dtype + """ + if self._out_dtype is not None: + return self._out_dtype + return self._input_nodes[self._x_idx].get_dtype() + + def output_layout(self, flexible: bool = True) -> Layout: + """ + Handle output layout generation for convolution. + + Args: + flexible: If True, return FlexibleLayout, otherwise FixedLayout + + Returns: + Layout for the convolution output + """ + from torch._inductor.kernel.conv import conv_layout + + x = self._input_nodes[self._x_idx] + weight = self._input_nodes[self._weight_idx] + bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None + + # Use existing conv_layout function + # We know the types here because conv requires these specific scalar types + layout = conv_layout( + x, + weight, + bias, + self._scalars["stride"], # type: ignore[arg-type] + self._scalars["padding"], # type: ignore[arg-type] + self._scalars["dilation"], # type: ignore[arg-type] + self._scalars["transposed"], # type: ignore[arg-type] + self._scalars["output_padding"], # type: ignore[arg-type] + self._scalars["groups"], # type: ignore[arg-type] + ) + + # TODO: Handle flexible vs fixed based on config if needed + return layout + + def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]: + """ + Get x, weight, and optional bias nodes. + + Returns: + Tuple of (x, weight, bias) where bias may be None + """ + bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None + return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias + + def spatial_dims(self) -> tuple[Any, ...]: + """ + Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D). + + Returns: + Tuple of spatial dimension sizes + """ + x_shape = self._input_nodes[self._x_idx].get_size() + return x_shape[2:] # Skip batch and channel dims diff --git a/torch/_inductor/template_heuristics/__init__.py b/torch/_inductor/template_heuristics/__init__.py index eb3d731525ea8..8b980816c56dc 100644 --- a/torch/_inductor/template_heuristics/__init__.py +++ b/torch/_inductor/template_heuristics/__init__.py @@ -1,6 +1,6 @@ # NOTE: add new template heuristics here, so they get imported and registered # TODO: write a simple glob if there are many heuristics to auto import them in the right order -from . import aten, base, contiguous_mm, decompose_k, registry, triton +from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton # expose the entry function from .registry import get_template_heuristic diff --git a/torch/_inductor/template_heuristics/conv.py b/torch/_inductor/template_heuristics/conv.py new file mode 100644 index 0000000000000..7333b5a679bd8 --- /dev/null +++ b/torch/_inductor/template_heuristics/conv.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +from typing import Any, cast, TYPE_CHECKING + +import torch + +from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template +from ..kernel_inputs import ConvKernelInputs +from ..utils import is_ones, sympy_product +from ..virtualized import V +from .base import TemplateConfigHeuristics +from .registry import register_template_heuristic +from .triton import ( + CPUConfigHeuristic, + CUDAConfigHeuristic, + MTIAConfigHeuristic, + ROCmConfigHeuristic, + XPUConfigHeuristic, +) + + +if TYPE_CHECKING: + from collections.abc import Generator + + from ..kernel_inputs import KernelInputs + + +class ConvTemplateConfigMixin(TemplateConfigHeuristics): + """ + Mixin for conv templates that converts config lists to template kwargs. + Similar to MMTemplateConfigMixin but for convolutions. + + This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.) + and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps). + """ + + # Type hint for methods from BaseConfigHeuristic + get_conv_configs: Any + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + """ + Return template kwargs that don't change per-config. + These are derived from kernel_inputs and must include all template parameters. + + Args: + kernel_inputs: ConvKernelInputs containing input tensors and conv params + op_name: Operation name (e.g., "convolution") + + Returns: + Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.) + """ + assert isinstance(kernel_inputs, ConvKernelInputs), ( + f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}" + ) + + x, weight, bias = kernel_inputs.get_x_weight_bias() + + # Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape] + weight_size = V.graph.sizevars.guard_int_seq(weight.get_size()) + kernel_shape = weight_size[2:] # Skip out_chan, in_chan + ndim = len(kernel_shape) + + # Extract scalars + stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride")) + padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding")) + groups = cast(int, kernel_inputs.get_scalar("groups")) + + # Check if we should unroll (only for 1x1 kernels) + unroll = is_ones(kernel_shape) + + # Build kwargs dict based on ndim + kwargs: dict[str, Any] = { + "GROUPS": groups, + "UNROLL": unroll, + "ALLOW_TF32": torch.backends.cudnn.allow_tf32, + } + + if ndim == 2: + kwargs.update( + { + "KERNEL_H": kernel_shape[0], + "KERNEL_W": kernel_shape[1], + "STRIDE_H": stride[0], + "STRIDE_W": stride[1], + "PADDING_H": padding[0], + "PADDING_W": padding[1], + } + ) + elif ndim == 3: + kwargs.update( + { + "KERNEL_D": kernel_shape[0], + "KERNEL_H": kernel_shape[1], + "KERNEL_W": kernel_shape[2], + "STRIDE_D": stride[0], + "STRIDE_H": stride[1], + "STRIDE_W": stride[2], + "PADDING_D": padding[0], + "PADDING_H": padding[1], + "PADDING_W": padding[2], + } + ) + + return kwargs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + """ + Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps). + + Args: + kernel_inputs: ConvKernelInputs containing input tensors + op_name: Operation name + + Yields: + Dict of per-config kwargs for each configuration to try + """ + assert isinstance(kernel_inputs, ConvKernelInputs), ( + "ConvTemplateConfigMixin requires ConvKernelInputs" + ) + + x, weight, bias = kernel_inputs.get_x_weight_bias() + + # Calculate dimensions for heuristics + weight_size = weight.get_size() + out_chan = weight_size[0] + in_chan = weight_size[1] + + # Batch * spatial dimensions product + x_size = x.get_size() + batch_spatial_product = sympy_product([x_size[0], *x_size[2:]]) + + # Get conv config generator from self (which is a BaseConfigHeuristic subclass) + conv_configs_generator = self.get_conv_configs() + + dtype_size = x.get_dtype().itemsize + + # Generate configs (reusing mm preprocess_mm_configs machinery) + for c in conv_configs_generator( + batch_spatial_product, + out_chan, + in_chan, + dtype_size=dtype_size, + op_name="conv", + ): + # Yield per-config kwargs + yield { + "BLOCK_M": c.kwargs.get("BLOCK_M"), + "BLOCK_N": c.kwargs.get("BLOCK_N"), + "BLOCK_K": c.kwargs.get("BLOCK_K"), + "num_stages": c.num_stages, + "num_warps": c.num_warps, + } + + +# ATEN convolution heuristic (no per-config tuning) +@register_template_heuristic(aten_convolution.uid, None) +class ATenConvConfigHeuristic(TemplateConfigHeuristics): + """ + Pseudo heuristic for ATen convolution. + ATen doesn't have configs to tune - it's a single choice. + """ + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> Generator[dict[str, Any], None, None]: + # ATen doesn't have per-config kwargs to tune + yield dict() + + def get_extra_kwargs( + self, + kernel_inputs: KernelInputs, + op_name: str, + ) -> dict[str, Any]: + """ + ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel. + """ + assert isinstance(kernel_inputs, ConvKernelInputs) + + # Extract scalar values from kernel_inputs + stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride")) + padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding")) + dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation")) + transposed = cast(bool, kernel_inputs.get_scalar("transposed")) + output_padding = cast( + tuple[int, ...], kernel_inputs.get_scalar("output_padding") + ) + groups = cast(int, kernel_inputs.get_scalar("groups")) + + # Check if bias is None to match old behavior + # When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list + # When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs + x, weight, bias = kernel_inputs.get_x_weight_bias() + + kwargs: dict[str, Any] = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + if bias is None: + # When bias is None, torch.convolution expects it as a kwarg + kwargs["bias"] = None + kwargs["ordered_kwargs_for_cpp_kernel"] = [ + "bias", + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + else: + # When bias is present, it's passed as a positional arg (3rd in input_nodes) + kwargs["ordered_kwargs_for_cpp_kernel"] = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + + return kwargs + + +# CUDA Conv2D/Conv3D heuristics +@register_template_heuristic( + conv2d_template.uid, + "cuda", + register=torch.version.hip is None, +) +@register_template_heuristic( + conv3d_template.uid, + "cuda", + register=torch.version.hip is None, +) +class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic): + """Conv template heuristic for CUDA.""" + + +# ROCm Conv2D/Conv3D heuristics +@register_template_heuristic( + conv2d_template.uid, + "cuda", + register=torch.version.hip is not None, +) +@register_template_heuristic( + conv3d_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic): + """Conv template heuristic for ROCm.""" + + +# CPU Conv2D/Conv3D heuristics +@register_template_heuristic(conv2d_template.uid, "cpu") +@register_template_heuristic(conv3d_template.uid, "cpu") +class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic): + """Conv template heuristic for CPU.""" + + +# XPU Conv2D/Conv3D heuristics +@register_template_heuristic(conv2d_template.uid, "xpu") +@register_template_heuristic(conv3d_template.uid, "xpu") +class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic): + """Conv template heuristic for XPU.""" + + +# MTIA Conv2D/Conv3D heuristics +@register_template_heuristic(conv2d_template.uid, "mtia") +@register_template_heuristic(conv3d_template.uid, "mtia") +class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic): + """Conv template heuristic for MTIA.""" From 6461548b4dcca39077aa83c7199173ad946e0898 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 20 Nov 2025 00:00:41 +0000 Subject: [PATCH 071/230] [vLLM] Update xformers and remove flashinfer-python (#168141) A couple of changes: * Update `xformers==0.0.33.post1`. This is the latest version for 2.9 release * Remove `flashinfer-python` build, we don't need to compile it anymore after https://github.com/vllm-project/vllm/pull/26443. This is now a regular dependency for vLLM * I also switch the base image to 12.9.1 to match what is vLLM is using nowadays ### Testing https://github.com/pytorch/pytorch/actions/runs/19490188972/job/55780754518 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168141 Approved by: https://github.com/yangw-dev --- .ci/lumen_cli/cli/lib/core/vllm/vllm_test.py | 1 - .github/ci_configs/vllm/Dockerfile | 35 +++----------------- .github/scripts/prepare_vllm_wheels.sh | 2 +- 3 files changed, 6 insertions(+), 32 deletions(-) diff --git a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py index 224f078788702..aea27ca7dddae 100644 --- a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py +++ b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py @@ -84,7 +84,6 @@ def __init__(self, args: Any): self.VLLM_TEST_WHLS_REGEX = [ "xformers/*.whl", "vllm/vllm*.whl", - "flashinfer-python/flashinfer*.whl", ] def prepare(self): diff --git a/.github/ci_configs/vllm/Dockerfile b/.github/ci_configs/vllm/Dockerfile index a57793151de66..13fdb036abfe7 100644 --- a/.github/ci_configs/vllm/Dockerfile +++ b/.github/ci_configs/vllm/Dockerfile @@ -1,4 +1,4 @@ -ARG CUDA_VERSION=12.8.1 +ARG CUDA_VERSION=12.9.1 ARG PYTHON_VERSION=3.12 # BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine, @@ -124,7 +124,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' git clone https://github.com/facebookresearch/xformers.git pushd xformers - git checkout v0.0.32.post2 + git checkout v0.0.33.post1 git submodule update --init --recursive python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose popd @@ -256,7 +256,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" # Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy -# Install build and runtime dependencies, this is needed for flashinfer install +# Install build and runtime dependencies COPY requirements/build.txt requirements/build.txt COPY use_existing_torch.py use_existing_torch.py RUN python3 use_existing_torch.py @@ -294,33 +294,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system /wheels/xformers/*.whl --verbose -# Build FlashInfer from source -ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0' -ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} - -# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip -# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784 -ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -ARG FLASHINFER_GIT_REF="v0.2.14.post1" - -RUN --mount=type=cache,target=/root/.cache/uv \ - git clone --depth 1 --recursive --shallow-submodules \ - --branch ${FLASHINFER_GIT_REF} \ - ${FLASHINFER_GIT_REPO} flashinfer \ - && echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \ - && cd flashinfer \ - && python3 -m flashinfer.aot \ - && python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \ - && cd .. \ - && rm -rf flashinfer - -# Install FlashInfer -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system wheels/flashinfer/*.whl --verbose - # Logging to confirm the torch versions -RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer' -RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm\|^flashinfer' > build_summary.txt +RUN pip freeze | grep -E 'torch|xformers|vllm' +RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm' > build_summary.txt ################### VLLM INSTALLED IMAGE #################### @@ -331,4 +307,3 @@ FROM scratch as export-wheels COPY --from=base /workspace/xformers-dist /wheels/xformers COPY --from=build /workspace/vllm-dist /wheels/vllm COPY --from=vllm-base /workspace/build_summary.txt /wheels/build_summary.txt -COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python diff --git a/.github/scripts/prepare_vllm_wheels.sh b/.github/scripts/prepare_vllm_wheels.sh index 62362c7ff207c..0d56a4ef43273 100755 --- a/.github/scripts/prepare_vllm_wheels.sh +++ b/.github/scripts/prepare_vllm_wheels.sh @@ -88,7 +88,7 @@ repackage_wheel() { ${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1 pushd externals/vllm/wheels -for package in xformers flashinfer-python vllm; do +for package in xformers vllm; do repackage_wheel $package done popd From cda1b8d23a31b7d93fc366265cb1fe66067eec1b Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 19 Nov 2025 19:34:06 +0000 Subject: [PATCH 072/230] [FlexFlash] Blackwell fwd support (#167040) Need to land: https://github.com/Dao-AILab/flash-attention/pull/1985 ^^First^^ Pull Request resolved: https://github.com/pytorch/pytorch/pull/167040 Approved by: https://github.com/Skylion007, https://github.com/albanD ghstack dependencies: #168017 --- test/inductor/test_flex_flash.py | 50 ++++++++++++++++--- .../kernel/flex/flex_flash_attention.py | 2 +- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index 40ca12ea9f526..0c877ff33b5e2 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -6,7 +6,11 @@ import torch from torch._inductor.kernel.flex.flex_flash_attention import ensure_flash_available from torch._inductor.test_case import TestCase as InductorTestCase -from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from torch.nn.attention.flex_attention import ( + _DEFAULT_SPARSE_BLOCK_SIZE, + create_block_mask, + flex_attention, +) from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_device_type import ( dtypes, @@ -105,6 +109,28 @@ def create_test_tensors( return q, k, v +def _create_block_mask_for_device( + mask_mod, batch_size, num_heads, q_len, kv_len, *, device +): + """Match FlexAttention's block-height expectations per compute capability.""" + q_block = _DEFAULT_SPARSE_BLOCK_SIZE + kv_block = _DEFAULT_SPARSE_BLOCK_SIZE + dev = torch.device(device) + if dev.type == "cuda": + major, _ = torch.cuda.get_device_capability(dev) + if major >= 10: + q_block *= 2 + return create_block_mask( + mask_mod, + batch_size, + num_heads, + q_len, + kv_len, + device=device, + BLOCK_SIZE=(q_block, kv_block), + ) + + @contextmanager def cuda_kernel_profiler(kernel_pattern="flash_attncute"): """Context manager for profiling CUDA kernels.""" @@ -312,7 +338,9 @@ def test_flash_attention_with_block_mask(self, device, dtype): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -323,7 +351,9 @@ def test_flash_attention_block_mask_with_score_mod(self, device, dtype): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -339,7 +369,9 @@ def custom_mask(b, h, q_idx, kv_idx): bias_value = mask_bias[h] return (q_idx >= kv_idx) | (bias_value > 0) - block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + custom_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -368,7 +400,7 @@ def document_mask(b, _h, q_idx, kv_idx): doc_id_kv = document_ids[b, kv_idx] return doc_id_q == doc_id_kv - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( document_mask, 2, 1, seq_len, seq_len, device=device ) flash_vs_triton(q, k, v, block_mask=block_mask) @@ -390,7 +422,7 @@ def mask_with_view_buffer(b, h, q_idx, kv_idx): double_bias = bias_value * 2 return (q_idx >= kv_idx) | (double_bias > 0) - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( mask_with_view_buffer, batch_size, num_heads, @@ -418,7 +450,7 @@ def dual_buffer_mask(b, h, q_idx, kv_idx): bias_cond = (head_term + batch_term).to(torch.float32) > 0 return causal | bias_cond - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device ) flash_vs_triton(q, k, v, block_mask=block_mask) @@ -461,7 +493,9 @@ def mask_with_buffer(b, h, q_idx, kv_idx): bias_value = mask_bias[h] return (q_idx >= kv_idx) | (bias_value > 0) - block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + mask_with_buffer, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask) diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 05d1040d35c9b..78a79f9664b68 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -133,7 +133,7 @@ def is_trivial_mask_graph(graph_module: GraphModule) -> bool: @functools.lru_cache(maxsize=1) def _supports_nontrivial_mask_graphs() -> bool: """Currently only supported on Hopper (SM90) GPUs.""" - return torch.cuda.get_device_capability()[0] == 9 + return torch.cuda.get_device_capability()[0] in [9, 10] def _can_use_flex_flash_attention( From a6bfe2dedaa6655e7558e52e2b760936615828ff Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 00:21:11 +0000 Subject: [PATCH 073/230] Revert "[invoke_subgraph] Don't run the graph twice when autograd enabled (#167245)" This reverts commit 789240bae27c957fb59f737a8e171bd221bf4a40. Reverted https://github.com/pytorch/pytorch/pull/167245 on behalf of https://github.com/yangw-dev due to the base pr is broken internal tests in the stack ([comment](https://github.com/pytorch/pytorch/pull/167245#issuecomment-3555175850)) --- test/higher_order_ops/test_with_effects.py | 6 +- torch/_higher_order_ops/invoke_subgraph.py | 64 ++-------------------- 2 files changed, 11 insertions(+), 59 deletions(-) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index e995959afba47..38e38c9e13f01 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -960,7 +960,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): ) recorded_list.clear() - out2 = ep.module()(x) + # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph + # eagerly twice. Once for get_output_metadata and then once for + # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. + with torch.no_grad(): + out2 = ep.module()(x) self.assertEqual(len(recorded_list), 4) self.assertTrue(torch.allclose(model(x)[0], out2[0])) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index bb0d6cef3ee6f..7d066e132e011 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -305,62 +305,6 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): def get_output_metadata(subgraph, *operands): - """ - Extract metadata about the subgraph outputs WITHOUT executing the subgraph. - This avoids running side-effectful operations twice (once here, once in forward). - We analyze the graph structure statically to extract metadata. - """ - # Unwrap FunctionalizeCtxWrapper if present - if isinstance(subgraph, FunctionalizeCtxWrapper): - subgraph = subgraph.subgraph - - # If not a GraphModule, fall back to execution-based metadata extraction - if not isinstance(subgraph, torch.fx.GraphModule): - return _get_output_metadata_by_execution(subgraph, *operands) - - output_metadata = OutputMetadata() - - # Extract output arguments from the output node - # The output node has args=(output_values,) where output_values is a tuple/list - output_node = next(reversed(subgraph.graph.find_nodes(op="output"))) - output_metadata.num_fw_outs = len(output_node.args[0]) - - for idx, output_arg in enumerate(output_node.args[0]): - if not isinstance(output_arg, torch.fx.Node): - if isinstance(output_arg, int): - output_metadata.indexes_with_symint.add(idx) - output_metadata.indexes_with_no_grad.add(idx) - continue - - # Check node metadata for type information - if output_arg.meta.get("val") is None: - # If we don't have complete metadata for all outputs, fall back to execution - # This is important for correctness (e.g., detecting SymInts) even though it - # runs side-effectful operations - return _get_output_metadata_by_execution(subgraph, *operands) - - val = output_arg.meta["val"] - if isinstance(val, torch.SymInt): - output_metadata.indexes_with_symint.add(idx) - output_metadata.indexes_with_no_grad.add(idx) - elif isinstance(val, torch.Tensor): - # Check if tensor requires grad from metadata - if hasattr(val, "requires_grad") and not val.requires_grad: - output_metadata.indexes_with_no_grad.add(idx) - else: - # Non-tensor, non-symint (shouldn't happen but be safe) - output_metadata.indexes_with_no_grad.add(idx) - - return output_metadata - - -def _get_output_metadata_by_execution(subgraph, *operands): - """ - Fallback: Extract metadata by executing the subgraph. - This should only be used when static analysis fails. - WARNING: This will run side-effectful operations! - """ - with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): # args are functional tensors, generate some example tensors @@ -380,15 +324,19 @@ def _get_output_metadata_by_execution(subgraph, *operands): num_fw_outs = len(fw_outs) + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. output_metadata = OutputMetadata() - output_metadata.num_fw_outs = num_fw_outs + output_metadata.num_fw_outs = num_fw_outs for idx, fw_out in enumerate(fw_outs): if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) - return output_metadata From ca6175c8f0acb7fea3c5cd088fc55409ba03567f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 00:24:05 +0000 Subject: [PATCH 074/230] Revert "[hoo] Invoke subgraph + effect (#167231)" This reverts commit f49833de54450b03b808a5b9ad774ce14ff2c8a2. Reverted https://github.com/pytorch/pytorch/pull/167231 on behalf of https://github.com/yangw-dev due to the diff breaks tests internally ([comment](https://github.com/pytorch/pytorch/pull/167231#issuecomment-3555183647)) --- test/export/test_converter.py | 2 +- test/export/test_passes.py | 15 +- test/export/test_torchbind.py | 12 +- test/higher_order_ops/test_with_effects.py | 98 -------- torch/_guards.py | 18 -- torch/_higher_order_ops/invoke_subgraph.py | 50 ---- torch/_library/effects.py | 15 -- torch/export/_remove_effect_tokens_pass.py | 267 +++++++++------------ torch/export/_unlift.py | 24 +- torch/fx/node.py | 4 +- 10 files changed, 135 insertions(+), 370 deletions(-) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 5b608503a1168..e739e5c346677 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841 ) # qnnpack not supported on s390x @xfailIfS390X - def test_ts2ep_convert_quantized_model1(self): + def test_ts2ep_convert_quantized_model(self): class Standalone(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 866eeaaee3986..9cf442c27a2bb 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -640,13 +640,16 @@ def forward(self, x): self.assertExpectedInline( without_token_ep.graph_module.code.strip(), """\ -def forward(self, obj_attr, x): - takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None - getitem_1 = takes_foo_tuple_return_default[0] - getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None +def forward(self, token, obj_attr, x): + with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None + getitem = with_effects[0] + getitem_1 = with_effects[1] + getitem_2 = with_effects[2]; with_effects = None add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None - takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None - return (takes_foo_default,)""", # noqa: B950 + with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None + getitem_3 = with_effects_1[0] + getitem_4 = with_effects_1[1]; with_effects_1 = None + return (getitem_3, getitem_4)""", # noqa: B950 ) def test_fakify_script_objects(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index adf0986811648..246122433e06c 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -461,9 +461,9 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr _guards_fn = self._guards_fn(x); _guards_fn = None - takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) - takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None - add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None + takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None + add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 ) self.assertExpectedInline( @@ -1087,12 +1087,10 @@ def forward(self, token, tq, x): str(ep.graph_module.graph).strip(), """\ graph(): - %token : [num_users=1] = placeholder[target=token] %tq : [num_users=2] = placeholder[target=tq] %x : [num_users=1] = placeholder[target=x] - %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {}) - %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) - return (getitem, tq)""", # noqa: B950 + %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {}) + return (tq,)""", # noqa: B950 ) def test_deepcopy(self): diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 38e38c9e13f01..2c4cf02bc1c8a 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -870,104 +870,6 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token): finally: handle.destroy() - @unittest.skipIf(not TEST_CUDA, "triton") - def test_export_invoke_subgraph(self): - with torch.library._scoped_library("mylib", "FRAGMENT") as lib: - recorded_list = [] - - @torch.library.custom_op("mylib::record_memory", mutates_args=()) - def record_memory(prefix: str, module_name: str) -> None: - torch.cuda.synchronize() - mem_alloc = torch.cuda.memory_allocated() / 1024**2 - mem_reserved = torch.cuda.memory_reserved() / 1024**2 - memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB" - recorded_list.append(memory_str) - - @record_memory.register_fake - def record_memory_fake(prefix, module_name): - return - - record_memory.register_effect(_EffectType.ORDERED) - - class N(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(1024, 1024) - self.relu = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(1024, 1024) - - @torch.compiler.nested_compile_region - def forward(self, x): - torch.ops.mylib.record_memory("forward", "N") - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x - - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod_list = torch.nn.ModuleList(N() for _ in range(3)) - - def forward(self, x): - for m in self.mod_list: - x = m(x) - torch.ops.mylib.record_memory("forward", "N") - return (x,) - - model = M().to("cuda") - torch.cuda.reset_peak_memory_stats() - - x = torch.randn(32, 1024, requires_grad=True, device="cuda") - - ep = torch.export.export(model, (x,)) - ep = ep.run_decompositions() - self.assertEqual(len(list(ep.graph_module.named_modules())), 2) - - self.assertExpectedInline( - ep.graph_module.code.strip(), - """\ -def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x): - repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None - getitem = invoke_subgraph[0] - getitem_1 = invoke_subgraph[1]; invoke_subgraph = None - repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None - getitem_2 = invoke_subgraph_1[0] - getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None - repeated_subgraph0_2 = self.repeated_subgraph0 - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None - getitem_4 = invoke_subgraph_2[0] - getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None - with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None - getitem_6 = with_effects[0]; with_effects = None - return (getitem_6, getitem_5)""", - ) - - self.assertExpectedInline( - ep.graph_module.repeated_subgraph0.code.strip(), - """\ -def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): - with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None - getitem = with_effects[0]; with_effects = None - permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None - addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None - relu = torch.ops.aten.relu.default(addmm); addmm = None - permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None - addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None - return (getitem, addmm_1)""", - ) - - recorded_list.clear() - # TODO: seems like invoke_subgraph's py_autograd impl calls the subgraph - # eagerly twice. Once for get_output_metadata and then once for - # InvokeSubgraphAutogradOp. This causes record_memory to be called twice. - with torch.no_grad(): - out2 = ep.module()(x) - self.assertEqual(len(recorded_list), 4) - self.assertTrue(torch.allclose(model(x)[0], out2[0])) - if __name__ == "__main__": run_tests() diff --git a/torch/_guards.py b/torch/_guards.py index 1bd32fc7f08ec..32b796d71eea7 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -713,9 +713,6 @@ def __init__(self) -> None: self.lazy_bwd_cache: dict[ str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) - self.effects_cache: dict[ - str, set - ] = {} # Maps identifier -> set of effect types def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: self.dynamo_installed_submodules[fn_id].append(identifier) @@ -754,21 +751,6 @@ def get_lazy_bwd_entry( return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None)) - def add_effects(self, identifier: str, effects: set) -> None: - """Store the effect types for a given invoke_subgraph identifier.""" - if prev_effects := self.effects_cache.get(identifier, None): - assert effects == prev_effects, ( - "Different number of effects were found for invoke_subgraph " - f"call with identifier {identifier}. \n" - f"Previously we had the following effects: {prev_effects}.\n" - f"But now we have: {effects}." - ) - self.effects_cache[identifier] = effects - - def get_effects(self, identifier: str) -> Optional[set]: - """Retrieve the effect types for a given invoke_subgraph identifier.""" - return self.effects_cache.get(identifier, None) - class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 7d066e132e011..e22b741631d3f 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -80,7 +80,6 @@ def __call__( assert all( isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator)) for o in operands - if o is not None ), ( f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}" ) @@ -563,34 +562,7 @@ def _(ctx, subgraph, identifier, *operands): do_auto_functionalize_v2, ) - # (in the functionalization metadata phase) Capture tokens before - tokens_before = dict(ctx.mode._tokens) - - # Check if this subgraph has effects stored in the cache - invoke_subgraph_cache = get_invoke_subgraph_cache() - effects = None - if invoke_subgraph_cache: - effects = invoke_subgraph_cache.get_effects(identifier) - - if effects: - assert len(effects) == 1, "Multiple effects within a subgraph NYI" - tokens = ctx.mode._tokens - effects = next(iter(effects)) - token_input = tokens[effects] - - operands = (token_input, *operands) - - def wrap_subgraph(subgraph): - def wrapped_subgraph(token, *args): - res = subgraph(*args) - return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res - - return wrapped_subgraph - - subgraph = wrap_subgraph(subgraph) - unwrapped_operands = ctx.unwrap_tensors(operands) - hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) if can_auto_functionalize(hop_instance): # NOTE: [auto_functionalize x invoke_subgraph caching] @@ -615,28 +587,6 @@ def wrapped_subgraph(token, *args): # of invoke_subgraph ops if input aliasing/mutation is detected. functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph) out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands) - - if effects: - (new_token, *out) = out - ctx.mode._tokens[effects] = new_token - - # (in the functionalization metadata phase) Capture tokens after and see if - # there are any differences (there are new effects or the token value for an - # effect type has changed) - tokens_after = dict(ctx.mode._tokens) - discovered_effects = set() - for effect_type, token in tokens_after.items(): - if effect_type not in tokens_before or tokens_before[effect_type] is not token: - discovered_effects.add(effect_type) - - if discovered_effects: - assert ctx.mode._allow_token_discovery, ( - f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}." - ) - # Store discovered effects in the cache by identifier - if invoke_subgraph_cache: - invoke_subgraph_cache.add_effects(identifier, discovered_effects) - return ctx.wrap_tensors(out) diff --git a/torch/_library/effects.py b/torch/_library/effects.py index 3f765f380eab1..41fbaa4c1c7b4 100644 --- a/torch/_library/effects.py +++ b/torch/_library/effects.py @@ -35,18 +35,6 @@ def _set_default_effect(self) -> None: if namespace == "higher_order": return - # These classes do not have side effects as they just store quantization - # params, so we dont need to mark them as ordered - skip_classes = ( - "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", - "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", - "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", - "__torch__.torch.classes.quantized.LinearPackedParamsBase", - "__torch__.torch.classes.xnnpack.Conv2dOpContext", - "__torch__.torch.classes.xnnpack.LinearOpContext", - "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", - ) - opname = f"{namespace}::{opname}" if torch._C._get_operation_overload(opname, overload) is not None: # Since we call this when destroying the library, sometimes the @@ -54,9 +42,6 @@ def _set_default_effect(self) -> None: schema = torch._C._get_schema(opname, overload) for arg in schema.arguments: if isinstance(arg.type, torch.ClassType): - type_str = arg.type.str() # pyrefly: ignore[missing-attribute] - if type_str in skip_classes: - continue self._effect = EffectType.ORDERED return diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 3ebcf6180d660..21930d81fe092 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -15,105 +15,113 @@ ) -def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants): - """Extract the custom object from a node's arguments.""" - custom_obj_node = node - custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr] - assert isinstance(custom_obj_meta, CustomObjArgument) - - if custom_obj_meta.fake_val: - return custom_obj_meta.fake_val - elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] - return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr] - else: - raise RuntimeError(f"Unable to find custom obj for node {node}") - - -def _replace_with_effects_node( - node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module +def _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names ): - """Replace a with_effects node with the underlying function call.""" - # Get the input nodes - token_node, func, *node_args = node.args - if token_node.op == "placeholder": - input_tokens.append(token_node) - - assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) - - # Get the schema for the function - if func is torch.ops.higher_order.call_torchbind: - custom_obj = _get_custom_obj_for_node( - node_args[0], inputs_to_lifted_custom_objs, ep.constants - ) - schema = _get_schema(func, [custom_obj] + node_args[1:]) - else: - schema = _get_schema(func, node_args) - - # Create the replacement node - with module.graph.inserting_before(node): - new_node = module.graph.call_function(func, tuple(node_args), node.kwargs) - - # Update getitem nodes that extract outputs from with_effects - for user in list(node.users.keys()): - assert user.target is operator.getitem - # getitem(with_effects, 0) is the token node - if user.args[1] == 0: - for user_user in list(user.users.keys()): - if user_user.op == "output": - output_tokens.append(user) - - # Fix up the getitem nodes based on return count - if len(schema.returns) == 1: - # Single return: replace getitem(with_effects, 1) with the node itself - for user in list(node.users.keys()): - if user.args[1] == 1: + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + output_node = None + with_effect_nodes: list[torch.fx.Node] = [] + + # Output node need to check its args against output_token_names (collected from output_spec) + # Therefore, we only need to find the top-levele output node + output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + for node in module.graph.nodes: + if not (node.op == "call_function" and node.target is with_effects): + continue + + with_effect_nodes.append(node) + + # Remove tokens from outputs + assert output_node is not None + output_args = output_node.args[0] + assert len(output_args) >= num_tokens + out_token_nodes = output_args[:num_tokens] + output_node.args = (tuple(output_args[num_tokens:]),) + for out_token in out_token_nodes: + assert out_token.name in output_token_names + out_token.users.clear() + ep.graph.erase_node(out_token) + + # Replace with_effects(token, func, args) with just func(args) + for node in reversed(with_effect_nodes): + func = node.args[1] + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + if func is torch.ops.higher_order.call_torchbind: + custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + if custom_obj_meta.fake_val: + custom_obj = custom_obj_meta.fake_val + elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + custom_obj = ep.constants[ + inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] + ] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + schema = _get_schema(func, (custom_obj,) + node.args[3:]) + else: + schema = _get_schema(func, node.args[2:]) + + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + node.replace_all_uses_with(new_node) + + # Update user getitem nodes + for user in list(new_node.users.keys()): + assert user.target is operator.getitem + # getitem(with_effects, 0) == token + if user.args[1] == 0: + ep.graph.erase_node(user) + + if len(schema.returns) == 1: + # If the function has 1 return then it will just directly return the + # result -- we don't need a getitem. So we can replace all the + # getitem(with_effects, 1) with just the note itself. + for user in list(new_node.users.keys()): + assert user.args[1] == 1 user.replace_all_uses_with(new_node) - new_node.meta["val"] = node.meta["val"][1] - elif len(schema.returns) > 1: - # Multiple returns: shift getitem indices down by 1 - for user in list(node.users.keys()): - if user.args[1] >= 1: - user.args = (new_node, user.args[1] - 1) - new_node.meta["val"] = node.meta["val"][1:] - else: - # No returns - assert len(schema.returns) == 0 - assert len(new_node.users) == 0 - new_node.meta["val"] = None - - # Copy metadata from old node to new node - for k, v in node.meta.items(): - new_node.meta[k] = v - if k == "unbacked_bindings": - # Remove the extra layer for effect token - old_bindings = new_node.meta[k] - new_bindings = { - k: path[1:] if path else path for k, path in old_bindings.items() - } - new_node.meta[k] = new_bindings - - -def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens): - """Replace an invoke_subgraph node to remove the token argument.""" - assert node.args[0].op == "get_attr" - submod = getattr(module, node.args[0].target) - if not submod.meta.get("has_with_effects", False): - return - - # Remove token from inputs - subgraph, identifier, token, *operands = node.args - node.args = (subgraph, identifier, *operands) - if token.op == "placeholder": - input_tokens.append(token) - - # Update getitem nodes to account for removed token output - for user in list(node.users.keys()): - if user.args[1] >= 1: - user.args = (node, user.args[1] - 1) - elif user.args[1] == 0: - for user_user in list(user.users.keys()): - if user_user.op == "output": - output_tokens.append(user) + + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # If the function has more than 1 return then since we got rid of + # the 1st return value (the token), we need to bump all the other + # getitem calls by 1 down + for user in list(new_node.users.keys()): + assert user.args[1] >= 1 + user.args = (user.args[0], user.args[1] - 1) + + new_node.meta["val"] = node.meta["val"][1:] + else: + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + ep.graph.erase_node(node) + + # Remove tokens from inputs + placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] + assert len(placeholders) >= num_tokens + inp_token_nodes = placeholders[:num_tokens] + for inp_token in inp_token_nodes: + assert inp_token.name in input_token_names + ep.graph.erase_node(inp_token) + + ep.graph.eliminate_dead_code() def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: @@ -124,65 +132,6 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: This function does an inplace modification on the given ExportedProgram. """ - print("before", ep) - inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs - - # mark submodules with effects as having effects. This will be used in the following pass to remove effects from subgraphs - for _, module in ep.graph_module.named_modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - with_effect_nodes = [ - node for node in module.graph.nodes if node.target is with_effects - ] - if len(with_effect_nodes) > 0: - module.meta["has_with_effects"] = True - - # Process each module with the replace hook to ensure graph signature is updated - with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - for _, module in ep.graph_module.named_modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - input_tokens = [] - output_tokens = [] - - # Process with_effects and invoke_subgraph nodes - for node in module.graph.nodes: - if node.target is with_effects: - _replace_with_effects_node( - node, - ep, - inputs_to_lifted_custom_objs, - output_tokens, - input_tokens, - module, - ) - elif node.target is torch.ops.higher_order.invoke_subgraph: - _replace_invoke_subgraph_node( - node, module, output_tokens, input_tokens - ) - - # Remove tokens from the output node - if len(output_tokens) > 0: - output_node = next(reversed(module.graph.find_nodes(op="output"))) - output_args = output_node.args[0] - assert len(output_args) >= len(output_tokens), ( - f"{output_args} output arguments found\n" - f"{output_tokens} output tokens found\n" - f"{module.graph}" - ) - output_node.args = (tuple(output_args[len(output_tokens) :]),) - - module.graph.eliminate_dead_code() - - # Remove tokens from the input placeholders - for node in module.graph.nodes: - if node.op == "placeholder" and node in input_tokens: - module.graph.erase_node(node) - - module.recompile() - num_tokens: int = 0 input_token_names: list[str] = [] new_input_specs: list[InputSpec] = [] @@ -210,5 +159,9 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: assert num_tokens == num_out_tokens - print("after", ep) + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names + ) + return ep diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 6239c5899c233..52d06a294fac1 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -748,23 +748,11 @@ def _unlift_exported_program_lifted_states( ) -> torch.fx.GraphModule: check_guards = check_guards and _ok_to_generate_guards_fn() - source_node_dict = { - node.name: node for node in ep.graph.nodes if node.op != "placeholder" - } - # placeholder node name might change after deepcopy - placeholder_source_node_dict = { - node.target: node for node in ep.graph.nodes if node.op == "placeholder" - } - - new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) - new_gm.meta.update(ep.graph_module.meta) - ep = copy.copy(ep) - ep._graph_module = new_gm - # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) forward_arg_names = ( sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None @@ -798,13 +786,19 @@ def _unlift_exported_program_lifted_states( for out_spec in ep.graph_signature.output_specs ] + source_node_dict = { + node.name: node for node in ep.graph.nodes if node.op != "placeholder" + } + # placeholder node name might change after deepcopy + placeholder_source_node_dict = { + node.target: node for node in ep.graph.nodes if node.op == "placeholder" + } for node in new_gm.graph.nodes: source_node = None if node.op == "placeholder": source_node = placeholder_source_node_dict.get(node.target) else: - if node.name in source_node_dict: - source_node = source_node_dict.get(node.name) + source_node = source_node_dict.get(node.name) node.meta["from_node"] = [ NodeSource( source_node, diff --git a/torch/fx/node.py b/torch/fx/node.py index cb37b6ece75dd..294e15c550235 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -753,9 +753,7 @@ def is_impure(self, impure_random: bool = True) -> bool: # between eager and compiled execution, regardless of generator usage return True - from torch._higher_order_ops.effects import has_effects - - return self.target in _side_effectful_functions or has_effects(self.target) + return self.target in _side_effectful_functions # Check if an impure module. if self.op == "call_module": From 771be8c062abe71fcf77cfdc02e5fd8197bdb041 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 00:31:01 +0000 Subject: [PATCH 075/230] Revert "[inductor] fix the decision of inner reduction (#167697)" This reverts commit 689d731ece80ceed232b59040afecabc1df520ec. Reverted https://github.com/pytorch/pytorch/pull/167697 on behalf of https://github.com/yangw-dev due to break internal tests, need to include internal changes ([comment](https://github.com/pytorch/pytorch/pull/167697#issuecomment-3555203239)) --- test/inductor/test_mix_order_reduction.py | 19 ++----------------- test/inductor/test_torchinductor.py | 15 --------------- torch/_inductor/ir.py | 4 +--- 3 files changed, 3 insertions(+), 35 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 1114810ceccdf..592e42ce41735 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -270,20 +270,11 @@ def f(x, y): ], ) @parametrize("split_reductions", (False, True)) - @parametrize( - "shape", ((1000000, 256), (32768, 2048), (32768, 768), (32768 + 1023, 768)) - ) + @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768))) @parametrize("max_autotune", (False, True)) @parametrize("initial_xblock", (1, 2)) - @parametrize("add_1dim", (False, True)) def test_rms_norm_bwd( - self, - wdtype, - split_reductions, - shape, - max_autotune, - initial_xblock, - add_1dim, + self, wdtype, split_reductions, shape, max_autotune, initial_xblock ): # max_autotune can be slow and cost resource, trim down the tests # for max autotune @@ -296,9 +287,6 @@ def test_rms_norm_bwd( ): self.skipTest("Skip non-critical tests to save resources.") - if shape != (1000000, 256) and add_1dim: - self.skipTest("Skip non-critical tests to save resources.") - def f(x, w, eps): orig_dtype = x.dtype @@ -319,9 +307,6 @@ def fwd_bwd(f): # M, N = 1152 * 500, 384 M, N = shape x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True) - if add_1dim: - x = x[:, None, :] - w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True) dy = torch.randn_like(x) eps = 1e-5 diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f5d5c5107313f..ed65f07742945 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14629,21 +14629,6 @@ def test_weight_norm_conv2d(self): self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3)) - @skipIfMPS - def test_inner_reduction_detection(self): - if self.device == "cpu": - self.skipTest("Skip for CPU device") - - x = torch.randn(100000, 1, 256, device=self.device) - - @torch.compile - def f(x): - return x.sum(dim=(0, 1)) - - code = run_and_get_triton_code(f, x) - self.assertTrue("ReductionHint.OUTER" in code) - self.assertFalse("ReductionHint.INNER" in code) - @skip_if_halide @requires_cuda_and_triton @skip_if_cpp_wrapper("skip cpp wrapper") diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 72d8383d2b812..67e0174443882 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1435,9 +1435,7 @@ def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: strides = V.graph.sizevars.stride_hints( j, reduction_vars, list(ranges1.keys()) ) - # A 0 stride does not make a reduction contiguous. - # This can happen when the reduction ranges contains a 1. - outer = all(s == 0 or s > 1 for s in strides) + outer = all(s > 1 for s in strides) if outer: num_outer += 1 else: From f890837d979a9e9fb59fdd5bff3d3f1857278126 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 00:34:06 +0000 Subject: [PATCH 076/230] Revert "dist: add list_keys to Store API (#167883)" This reverts commit ef7fa96fbfa4854e863683ff39cf572bc65b32a5. Reverted https://github.com/pytorch/pytorch/pull/167883 on behalf of https://github.com/yangw-dev due to break some internal test, error: use of undeclared identifier, reached out author but no resp, so revert this to keep diff train hygiene ([comment](https://github.com/pytorch/pytorch/pull/167883#issuecomment-3555212038)) --- test/distributed/test_store.py | 8 ------ torch/_C/_distributed_c10d.pyi | 1 - torch/csrc/distributed/c10d/FileStore.cpp | 13 --------- torch/csrc/distributed/c10d/FileStore.hpp | 2 -- torch/csrc/distributed/c10d/HashStore.cpp | 10 ------- torch/csrc/distributed/c10d/HashStore.hpp | 2 -- torch/csrc/distributed/c10d/PrefixStore.cpp | 14 ---------- torch/csrc/distributed/c10d/PrefixStore.hpp | 2 -- torch/csrc/distributed/c10d/Store.hpp | 5 ---- torch/csrc/distributed/c10d/TCPStore.cpp | 24 ---------------- torch/csrc/distributed/c10d/TCPStore.hpp | 2 -- .../csrc/distributed/c10d/TCPStoreBackend.cpp | 10 ------- .../csrc/distributed/c10d/TCPStoreBackend.hpp | 1 - .../distributed/c10d/TCPStoreLibUvBackend.cpp | 28 ------------------- torch/csrc/distributed/c10d/init.cpp | 6 ---- 15 files changed, 128 deletions(-) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index e1412701807b6..5e063d373ffb5 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -253,14 +253,6 @@ def test_clone(self): a.set("foo", "bar") self.assertEqual(b.get("foo"), b"bar") - def test_list_keys(self): - a = self._create_store() - a.set("foo", "bar") - a.set("baz", "qux") - keys = a.list_keys() - self.assertIn("foo", keys) - self.assertIn("baz", keys) - # This is the number of keys used in test_set_get. Adding this as a class # property instead of hardcoding in the test since some Store # implementations will have differing number of keys. In the base case, diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 477b35b1811e4..a80efc696e17d 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -218,7 +218,6 @@ class Store: def queue_pop(self, key: str, block: bool = True) -> bytes: ... def queue_push(self, key: str, value: Union[bytes, str]) -> None: ... def queue_len(self, key: str) -> int: ... - def list_keys(self) -> list[str]: ... class FileStore(Store): def __init__(self, path: str, numWorkers: int = ...) -> None: ... diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 969379e739438..7e22aa6fd0bd5 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -492,17 +492,4 @@ void FileStore::wait( } } -std::vector FileStore::listKeys() { - std::unique_lock l(activeFileOpLock_); - File file(path_, O_RDONLY, timeout_); - auto lock = file.lockShared(); - pos_ = refresh(file, pos_, cache_, deletePrefix_); - std::vector keys; - keys.reserve(cache_.size()); - for (const auto& kv : cache_) { - keys.push_back(kv.first.substr(regularPrefix_.size())); - } - return keys; -} - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/FileStore.hpp b/torch/csrc/distributed/c10d/FileStore.hpp index 11ded19d8125a..563ac76e03bf5 100644 --- a/torch/csrc/distributed/c10d/FileStore.hpp +++ b/torch/csrc/distributed/c10d/FileStore.hpp @@ -45,8 +45,6 @@ class TORCH_API FileStore : public Store { return path_; } - std::vector listKeys() override; - protected: int64_t addHelper(const std::string& key, int64_t i); diff --git a/torch/csrc/distributed/c10d/HashStore.cpp b/torch/csrc/distributed/c10d/HashStore.cpp index 9073333fb9a48..15befd9ec34e2 100644 --- a/torch/csrc/distributed/c10d/HashStore.cpp +++ b/torch/csrc/distributed/c10d/HashStore.cpp @@ -217,14 +217,4 @@ int64_t HashStore::queueLen(const std::string& key) { return static_cast(it->second.size()); } -std::vector HashStore::listKeys() { - std::unique_lock lock(m_); - std::vector keys; - keys.reserve(map_.size()); - for (const auto& kv : map_) { - keys.push_back(kv.first); - } - return keys; -} - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index f7aca03de8b22..4007d543a9371 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -59,8 +59,6 @@ class TORCH_API HashStore : public Store { int64_t queueLen(const std::string& key) override; - std::vector listKeys() override; - protected: bool checkLocked( const std::unique_lock& lock, diff --git a/torch/csrc/distributed/c10d/PrefixStore.cpp b/torch/csrc/distributed/c10d/PrefixStore.cpp index fa228c4467f01..057d198f93c2d 100644 --- a/torch/csrc/distributed/c10d/PrefixStore.cpp +++ b/torch/csrc/distributed/c10d/PrefixStore.cpp @@ -146,18 +146,4 @@ c10::intrusive_ptr PrefixStore::getUnderlyingNonPrefixStore() { return store; } -std::vector PrefixStore::listKeys() { - auto keys = store_->listKeys(); - std::vector filteredKeys; - filteredKeys.reserve(keys.size()); - - for (auto& key : keys) { - if (key.find(prefix_) == 0) { - key = key.substr(prefix_.size() + 1); - filteredKeys.push_back(std::move(key)); - } - } - return filteredKeys; -} - } // namespace c10d diff --git a/torch/csrc/distributed/c10d/PrefixStore.hpp b/torch/csrc/distributed/c10d/PrefixStore.hpp index f950ff96590a3..627d2153bb22b 100644 --- a/torch/csrc/distributed/c10d/PrefixStore.hpp +++ b/torch/csrc/distributed/c10d/PrefixStore.hpp @@ -64,8 +64,6 @@ class TORCH_API PrefixStore : public Store { // Recursively to fetch the store before layers of wrapping with PrefixStore. c10::intrusive_ptr getUnderlyingNonPrefixStore(); - std::vector listKeys() override; - protected: std::string prefix_; c10::intrusive_ptr store_; diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 9a037c65ee7c2..8260d33597d9c 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -114,11 +114,6 @@ class TORCH_API Store : public torch::CustomClassHolder { C10_THROW_ERROR(NotImplementedError, "queue support is not implemented."); } - virtual std::vector listKeys() { - C10_THROW_ERROR( - NotImplementedError, "listKeys support is not implemented."); - } - protected: std::chrono::milliseconds timeout_; }; diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 9f566032b5b3c..b664c5d3bb963 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -723,30 +723,6 @@ int64_t TCPStore::queueLen(const std::string& key) { return client_->receiveValue(); } -std::vector TCPStore::listKeys() { - STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__list); - - const std::lock_guard lock(activeOpLock_); - - detail::SendBuffer buffer(*client_, detail::QueryType::LIST_KEYS); - buffer.flush(); - - auto numKeys = client_->receiveValue(); - std::vector keys; - keys.reserve(numKeys); - for (auto i = 0; i < numKeys; ++i) { - auto bits = client_->receiveBits(); - std::string str(bits.begin(), bits.end()); - if (str.find(keyPrefix_) == 0) { - str = str.substr(keyPrefix_.size()); - } else { - continue; - } - keys.emplace_back(str); - } - return keys; -} - bool TCPStore::hasExtendedApi() const { return true; } diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 09d7ae111c57a..2caab088a609a 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -121,8 +121,6 @@ class TORCH_API TCPStore : public Store { int64_t queueLen(const std::string& key) override; - std::vector listKeys() override; - // Waits for all workers to join. void waitForWorkers(); diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp index dd25729a6ee13..22455a22a4610 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp @@ -78,7 +78,6 @@ class TCPStoreMasterDaemon : public BackgroundThread { void multiGetHandler(int socket); void multiSetHandler(int socket); void cancelWaitHandler(int socket); - void listKeysHandler(int socket); void addMiscellaneousSocket(int socket); void removeMiscellaneousSocket(int socket); bool isMiscellaneousSocket(int socket); @@ -296,8 +295,6 @@ void TCPStoreMasterDaemon::query(int socket) { multiSetHandler(socket); } else if (qt == QueryType::CANCEL_WAIT) { cancelWaitHandler(socket); - } else if (qt == QueryType::LIST_KEYS) { - listKeysHandler(socket); } else { TORCH_CHECK(false, "Unexpected query type"); } @@ -485,13 +482,6 @@ void TCPStoreMasterDaemon::cancelWaitHandler(int socket) { socket, detail::WaitResponseType::WAIT_CANCELED); } -void TCPStoreMasterDaemon::listKeysHandler(int socket) { - tcputil::sendValue(socket, tcpStore_.size()); - for (const auto& kv : tcpStore_) { - tcputil::sendString(socket, kv.first); - } -} - bool TCPStoreMasterDaemon::checkKeys( const std::vector& keys) const { return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) { diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.hpp b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp index d176ccb702838..d5f7f0248bba5 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.hpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp @@ -36,7 +36,6 @@ enum class QueryType : uint8_t { QUEUE_PUSH, QUEUE_POP, QUEUE_LEN, - LIST_KEYS, }; enum class CheckResponseType : uint8_t { READY, NOT_READY }; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 7427848b8445b..edb640785a170 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -683,7 +683,6 @@ class LibUVStoreDaemon : public BackgroundThread { const std::string& queueName, const c10::intrusive_ptr& client); int64_t queueLen(const std::string& queueName); - std::vector listKeys(); void registerClient(const c10::intrusive_ptr& client); void unregisterClient(const c10::intrusive_ptr& client); @@ -823,10 +822,6 @@ class UvClient : public UvTcpSocket { if (!parse_queue_len_command()) return; break; - case QueryType::LIST_KEYS: - if (!parse_list_keys_command()) - return; - break; default: C10D_DEBUG( "Client sent invalid command. client:{} command:{}", @@ -1169,19 +1164,6 @@ class UvClient : public UvTcpSocket { return true; } - bool parse_list_keys_command() { - C10D_TRACE("list_keys address:{}", this->address()); - - auto keys = store->listKeys(); - StreamWriter sw(iptr()); - sw.write_value(static_cast(keys.size())); - for (const auto& key : keys) { - sw.write_string(key); - } - sw.send(); - return true; - } - public: explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store) : UvTcpSocket(loop), store(store) {} @@ -1560,16 +1542,6 @@ int64_t LibUVStoreDaemon::queueLen(const std::string& key) { } return static_cast(it->second.size()); } - -std::vector LibUVStoreDaemon::listKeys() { - std::vector keys; - keys.reserve(tcpStore_.size()); - for (const auto& kv : tcpStore_) { - keys.push_back(kv.first); - } - return keys; -} - #endif std::unique_ptr create_libuv_tcpstore_backend( diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 255e793eaa4df..6f38cd9cd2c6f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1657,12 +1657,6 @@ See queue_push for more details. Arguments: key (str): The key of the queue to get the length. -)") - .def( - "list_keys", - &::c10d::Store::listKeys, - R"( -Returns a list of all keys in the store. )") .def( "has_extended_api", From bc8da6339631a0570e9860024375607aa39f49dc Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 19 Nov 2025 11:29:47 -0800 Subject: [PATCH 077/230] Move MemoryFormat/Layout to headeronly (#168034) ~This PR does change the semantics of the >> operator by using STD_TORCH_CHECK to throw the error instead of TORCH_CHECK. Jane (who is writing this message) thinks it is okay because it is the error case when an invalid MemoryFormat or Layout is getting passed into >>, so the UX benefits of TORCH_CHECK over STD_TORCH_CHECK there are not significant enough to warrant making a new copy of Layout and MemoryFormat's >> APIs.~ Never mind! We shouldn't change TORCH_CHECK to STD_TORCH_CHECK for core usage ever, cuz the traceback info and c10::Error is very much desired!! So the solution is to not migrate the >>s. I pushed new commits to the stack to remove the >> code, but for reference, https://github.com/pytorch/pytorch/pull/168034/commits/8a30179fab3a52ef23fffe19ddad765d5a230ca5 has all the code that I ended up deleting. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168034 Approved by: https://github.com/janeyx99 ghstack dependencies: #168025, #167802, #167803, #167804, #167962 Co-authored-by: Jane Xu --- c10/core/Layout.h | 23 +--- c10/core/MemoryFormat.h | 32 +---- test/cpp/aoti_abi_check/CMakeLists.txt | 2 + test/cpp/aoti_abi_check/test_layout.cpp | 20 +++ test/cpp/aoti_abi_check/test_memoryformat.cpp | 23 ++++ .../libtorch_agnostic_2_10/csrc/my_empty.cpp | 8 +- .../libtorch_agnostic_2_10/ops.py | 10 +- test/cpp_extensions/test_libtorch_agnostic.py | 80 ++++++++--- torch/csrc/stable/ops.h | 18 +-- torch/csrc/stable/stableivalue_conversions.h | 126 ++++++++++++++++++ torch/header_only_apis.txt | 15 +++ torch/headeronly/core/Layout.h | 44 ++++++ torch/headeronly/core/MemoryFormat.h | 46 +++++++ 13 files changed, 366 insertions(+), 81 deletions(-) create mode 100644 test/cpp/aoti_abi_check/test_layout.cpp create mode 100644 test/cpp/aoti_abi_check/test_memoryformat.cpp create mode 100644 torch/headeronly/core/Layout.h create mode 100644 torch/headeronly/core/MemoryFormat.h diff --git a/c10/core/Layout.h b/c10/core/Layout.h index a85f2ee6911ce..7cd25b04c5bb6 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -3,30 +3,9 @@ #include #include -#include -#include +#include namespace c10 { -enum class Layout : int8_t { - Strided, - Sparse, - SparseCsr, - Mkldnn, - SparseCsc, - SparseBsr, - SparseBsc, - Jagged, - NumOptions -}; - -constexpr auto kStrided = Layout::Strided; -constexpr auto kSparse = Layout::Sparse; -constexpr auto kSparseCsr = Layout::SparseCsr; -constexpr auto kMkldnn = Layout::Mkldnn; -constexpr auto kSparseCsc = Layout::SparseCsc; -constexpr auto kSparseBsr = Layout::SparseBsr; -constexpr auto kSparseBsc = Layout::SparseBsc; -constexpr auto kJagged = Layout::Jagged; inline Layout layout_from_backend(Backend backend) { C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index 8c8531d014713..7271a281e5ddb 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -3,46 +3,18 @@ #include #include +#include + #include -#include #include -// Memory format is not the property of a Tensor. It is the way to tell an -// operator how the result should be organized in memory and nothing more. That -// means memory format should never be used as return value for any tensor state -// interrogation functions (internally and externally). -// -// Possible options are: -// Preserve: -// If any of the input tensors is in channels_last format, operator output -// should be in channels_last format -// -// Contiguous: -// Regardless of input tensors format, the output should be contiguous -// Tensor. -// -// ChannelsLast: -// Regardless of input tensors format, the output should be in channels_last -// format. - namespace c10 { -enum class MemoryFormat : int8_t { - Contiguous, - Preserve, - ChannelsLast, - ChannelsLast3d, - NumOptions -}; // If you are seeing this, it means that this call site was not checked if // the memory format could be preserved, and it was switched to old default // behaviour of contiguous #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() -inline MemoryFormat get_contiguous_memory_format() { - return MemoryFormat::Contiguous; -} - inline std::ostream& operator<<( std::ostream& stream, at::MemoryFormat memory_format) { diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 483814a0326d2..4146819e2f1a1 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -16,8 +16,10 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_layout.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_memoryformat.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_metaprogramming.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp diff --git a/test/cpp/aoti_abi_check/test_layout.cpp b/test/cpp/aoti_abi_check/test_layout.cpp new file mode 100644 index 0000000000000..7bb45a6897434 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_layout.cpp @@ -0,0 +1,20 @@ +#include + +#include + +TEST(TestLayout, TestLayout) { + using torch::headeronly::Layout; + constexpr Layout expected_layouts[] = { + torch::headeronly::kStrided, + torch::headeronly::kSparse, + torch::headeronly::kSparseCsr, + torch::headeronly::kMkldnn, + torch::headeronly::kSparseCsc, + torch::headeronly::kSparseBsr, + torch::headeronly::kSparseBsc, + torch::headeronly::kJagged, + }; + for (int8_t i = 0; i < static_cast(Layout::NumOptions); i++) { + EXPECT_EQ(static_cast(i), expected_layouts[i]); + } +} diff --git a/test/cpp/aoti_abi_check/test_memoryformat.cpp b/test/cpp/aoti_abi_check/test_memoryformat.cpp new file mode 100644 index 0000000000000..b0a584b15e299 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_memoryformat.cpp @@ -0,0 +1,23 @@ +#include + +#include + +TEST(TestMemoryFormat, TestMemoryFormat) { + using torch::headeronly::MemoryFormat; + constexpr MemoryFormat expected_memory_formats[] = { + MemoryFormat::Contiguous, + MemoryFormat::Preserve, + MemoryFormat::ChannelsLast, + MemoryFormat::ChannelsLast3d, + }; + for (int8_t i = 0; i < static_cast(MemoryFormat::NumOptions); i++) { + EXPECT_EQ(static_cast(i), expected_memory_formats[i]); + } +} + +TEST(TestMemoryFormat, get_contiguous_memory_format) { + using torch::headeronly::get_contiguous_memory_format; + using torch::headeronly::MemoryFormat; + + EXPECT_EQ(get_contiguous_memory_format(), MemoryFormat::Contiguous); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp index 6278dca9f281d..4b17b113135e6 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp @@ -10,14 +10,16 @@ using torch::stable::Tensor; Tensor my_empty( torch::headeronly::HeaderOnlyArrayRef size, std::optional dtype, + std::optional layout, std::optional device, - std::optional pin_memory) { - return empty(size, dtype, device, pin_memory); + std::optional pin_memory, + std::optional memory_format) { + return empty(size, dtype, layout, device, pin_memory, memory_format); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def( - "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor"); + "my_empty(int[] size, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index a740df8c9e25f..8d05741869ebd 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -156,20 +156,24 @@ def test_get_num_threads() -> int: return torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default() -def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor: +def my_empty( + size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +) -> Tensor: """ - Creates an empty tensor with the specified size, dtype, device, and pin_memory. + Creates an empty tensor with the specified size, dtype, layout, device, pin_memory, and memory_format. Args: size: list[int] - size of the tensor to create dtype: ScalarType or None - data type of the tensor + layout: Layout or None - layout of the tensor device: Device or None - device on which to create the tensor pin_memory: bool or None - whether to use pinned memory + memory_format: MemoryFormat or None - memory format of the tensor Returns: Tensor - an uninitialized tensor with the specified properties """ return torch.ops.libtorch_agnostic_2_10.my_empty.default( - size, dtype, device, pin_memory + size, dtype, layout, device, pin_memory, memory_format ) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index ef92fc316daa7..f24731ee5666a 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( install_cpp_extension, IS_WINDOWS, + parametrize, run_tests, skipIfTorchDynamo, TestCase, @@ -618,7 +619,11 @@ def test_get_num_threads(self, device): self.assertEqual(num_threads, expected_num_threads) @skipIfTorchVersionLessThan(2, 10) - def test_my_empty(self, device): + @parametrize("layout", [None, torch.strided, torch.sparse_coo]) + @parametrize( + "memory_format", [None, torch.channels_last, torch.contiguous_format] + ) + def test_my_empty(self, device, layout, memory_format): import libtorch_agnostic_2_10 as libtorch_agnostic deterministic = torch.are_deterministic_algorithms_enabled() @@ -626,35 +631,80 @@ def test_my_empty(self, device): # set use_deterministic_algorithms to fill uninitialized memory torch.use_deterministic_algorithms(True) - size = [2, 3] - result = libtorch_agnostic.ops.my_empty(size, None, None, None) - expected = torch.empty(size) - self.assertEqual(result, expected, exact_device=True) + # Use 4D size for channels_last, 2D otherwise + size = [2, 3, 4, 5] if memory_format == torch.channels_last else [2, 3] + + # sparse_coo layout doesn't support memory_format parameter + if layout == torch.sparse_coo and memory_format is not None: + return + + # Test default parameters + result = libtorch_agnostic.ops.my_empty( + size, None, layout, None, None, memory_format + ) + expected = torch.empty(size, layout=layout, memory_format=memory_format) + self.assertEqual(result, expected, exact_device=True, exact_layout=True) + # Test with dtype result_float = libtorch_agnostic.ops.my_empty( - size, torch.float32, None, None + size, torch.float32, layout, None, None, memory_format + ) + expected_float = torch.empty( + size, + dtype=torch.float32, + layout=layout, + memory_format=memory_format, + ) + self.assertEqual( + result_float, expected_float, exact_device=True, exact_layout=True ) - expected_float = torch.empty(size, dtype=torch.float32) - self.assertEqual(result_float, expected_float, exact_device=True) + # Test with dtype and device result_with_device = libtorch_agnostic.ops.my_empty( - size, torch.float64, device, None + size, torch.float64, layout, device, None, memory_format ) expected_with_device = torch.empty( - size, dtype=torch.float64, device=device + size, + dtype=torch.float64, + layout=layout, + device=device, + memory_format=memory_format, ) self.assertEqual( - result_with_device, expected_with_device, exact_device=True + result_with_device, + expected_with_device, + exact_device=True, + exact_layout=True, ) - if device == "cuda": + # Verify layout if specified + if layout is not None: + self.assertEqual(result_with_device.layout, layout) + + # Verify memory format if specified + if memory_format == torch.channels_last: + self.assertTrue( + result_with_device.is_contiguous( + memory_format=torch.channels_last + ) + ) + elif memory_format == torch.contiguous_format: + self.assertTrue(result_with_device.is_contiguous()) + + # Test pin_memory on CUDA (only once, not for every parameter combination) + if device == "cuda" and layout is None and memory_format is None: result_pinned = libtorch_agnostic.ops.my_empty( - size, torch.float32, "cpu", True + [2, 3], torch.float32, None, "cpu", True, None ) expected_pinned = torch.empty( - size, dtype=torch.float32, device="cpu", pin_memory=True + [2, 3], dtype=torch.float32, device="cpu", pin_memory=True + ) + self.assertEqual( + result_pinned, + expected_pinned, + exact_device=True, + exact_layout=True, ) - self.assertEqual(result_pinned, expected_pinned) self.assertTrue(result_pinned.is_pinned()) finally: torch.use_deterministic_algorithms(deterministic) diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index c90db39cb1b98..923cbf398a104 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -326,24 +326,26 @@ inline uint32_t get_num_threads() { return num_threads; } -// We expect this to be the stable version of the empty op that takes in -// device and dtype parameters. The empty op creates a tensor with uninitialized -// values of the specified size, dtype, and device. -// This function is only available in 2.10 because it uses the stableivalue -// conversion for HeaderOnlyArrayRef, which is only available in 2.10. +// We expect this to be the stable version of the empty.memory_format op that +// takes in device and dtype parameters. This function is only available in 2.10 +// because it uses the stableivalue conversion for HeaderOnlyArrayRef, which +// is only available in 2.10. inline torch::stable::Tensor empty( torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt, + std::optional layout = std::nullopt, std::optional device = std::nullopt, - std::optional pin_memory = std::nullopt) { + std::optional pin_memory = std::nullopt, + std::optional memory_format = + std::nullopt) { const auto num_args = 6; std::array stack{ torch::stable::detail::from(size), torch::stable::detail::from(dtype), - torch::stable::detail::from(std::nullopt), + torch::stable::detail::from(layout), torch::stable::detail::from(device), torch::stable::detail::from(pin_memory), - torch::stable::detail::from(std::nullopt)}; + torch::stable::detail::from(memory_format)}; TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( "aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 0e09eeb7f7b14..4538781594785 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include #include @@ -268,6 +270,68 @@ struct FromImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +// Specialization for c10::Layout => StableIValue +// Note that we call into the shim to translate between the user's +// Layout and libtorch's Layout, which can be different! +using c10::Layout; +template <> +struct FromImpl { + static StableIValue call( + Layout val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case Layout::Strided: + return from(aoti_torch_layout_strided()); + case Layout::Sparse: + return from(aoti_torch_layout_sparse_coo()); + case Layout::SparseCsr: + return from(aoti_torch_layout_sparse_csr()); + case Layout::SparseCsc: + return from(aoti_torch_layout_sparse_csc()); + case Layout::SparseBsr: + return from(aoti_torch_layout_sparse_bsr()); + case Layout::SparseBsc: + return from(aoti_torch_layout_sparse_bsc()); + case Layout::Mkldnn: + return from(aoti_torch_layout__mkldnn()); + case Layout::Jagged: + return from(aoti_torch_layout_jagged()); + default: + STD_TORCH_CHECK( + false, + "Not yet supported Layout, please file an issue describing your use case."); + } + } +}; + +// Specialization for c10::MemoryFormat => StableIValue +// Note that we call into the shim to translate between the user's +// MemoryFormat and libtorch's MemoryFormat, which can be different! +using c10::MemoryFormat; +template <> +struct FromImpl { + static StableIValue call( + MemoryFormat val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case MemoryFormat::Contiguous: + return from(aoti_torch_memory_format_contiguous_format()); + case MemoryFormat::Preserve: + return from(aoti_torch_memory_format_preserve_format()); + case MemoryFormat::ChannelsLast: + return from(aoti_torch_memory_format_channels_last()); + case MemoryFormat::ChannelsLast3d: + return from(aoti_torch_memory_format_channels_last_3d()); + default: + STD_TORCH_CHECK( + false, + "Not yet supported MemoryFormat, please file an issue describing your use case."); + } + } +}; + // Specialization for torch::headeronly::HeaderOnlyArrayRef => StableIValue // Returns a new owning reference of the underlying list. template @@ -529,6 +593,68 @@ struct ToImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +// Specialization for StableIValue => c10::Layout +template <> +struct ToImpl { + static Layout call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_layout = to(val); + if (shim_layout == aoti_torch_layout_strided()) { + return Layout::Strided; + } else if (shim_layout == aoti_torch_layout_sparse_coo()) { + return Layout::Sparse; + } else if (shim_layout == aoti_torch_layout_sparse_csr()) { + return Layout::SparseCsr; + } else if (shim_layout == aoti_torch_layout_sparse_csc()) { + return Layout::SparseCsc; + } else if (shim_layout == aoti_torch_layout_sparse_bsr()) { + return Layout::SparseBsr; + } else if (shim_layout == aoti_torch_layout_sparse_bsc()) { + return Layout::SparseBsc; + } else if (shim_layout == aoti_torch_layout__mkldnn()) { + return Layout::Mkldnn; + } else if (shim_layout == aoti_torch_layout_jagged()) { + return Layout::Jagged; + } else { + STD_TORCH_CHECK( + false, + "Not yet supported Layout ", + std::to_string(shim_layout), + ", please file an issue describing your use case."); + } + } +}; + +// Specialization for StableIValue => c10::MemoryFormat +template <> +struct ToImpl { + static MemoryFormat call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_memory_format = to(val); + if (shim_memory_format == aoti_torch_memory_format_contiguous_format()) { + return MemoryFormat::Contiguous; + } else if ( + shim_memory_format == aoti_torch_memory_format_preserve_format()) { + return MemoryFormat::Preserve; + } else if (shim_memory_format == aoti_torch_memory_format_channels_last()) { + return MemoryFormat::ChannelsLast; + } else if ( + shim_memory_format == aoti_torch_memory_format_channels_last_3d()) { + return MemoryFormat::ChannelsLast3d; + } else { + STD_TORCH_CHECK( + false, + "Not yet supported MemoryFormat ", + std::to_string(shim_memory_format), + ", please file an issue describing your use case."); + } + } +}; + // Specialization for StableIValue => std::vector // std::vector should be represented as a StableListHandle // filled with StableIValues diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 598ca377f794b..9f422b720d4e6 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -179,6 +179,21 @@ toString << toUnderlying +# torch/headeronly/core/Layout.h +Layout +kStrided +kSparse +kSparseCsr +kSparseCsc +kSparseBsr +kSparseBsc +kMkldnn +kJagged + +# torch/headeronly/core/MemoryFormat.h +MemoryFormat +get_contiguous_memory_format + # torch/headeronly/core/Dispatch_v2.h THO_DISPATCH_V2_TMPL THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL diff --git a/torch/headeronly/core/Layout.h b/torch/headeronly/core/Layout.h new file mode 100644 index 0000000000000..62e34ff67b457 --- /dev/null +++ b/torch/headeronly/core/Layout.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +} // namespace c10 + +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) +using c10::kJagged; +using c10::kMkldnn; +using c10::kSparse; +using c10::kSparseBsc; +using c10::kSparseBsr; +using c10::kSparseCsc; +using c10::kSparseCsr; +using c10::kStrided; +using c10::Layout; +HIDDEN_NAMESPACE_END(torch, headeronly) diff --git a/torch/headeronly/core/MemoryFormat.h b/torch/headeronly/core/MemoryFormat.h new file mode 100644 index 0000000000000..ad02a901e0169 --- /dev/null +++ b/torch/headeronly/core/MemoryFormat.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include +#include + +// Memory format is not the property of a Tensor. It is the way to tell an +// operator how the result should be organized in memory and nothing more. That +// means memory format should never be used as return value for any tensor state +// interrogation functions (internally and externally). +// +// Possible options are: +// Preserve: +// If any of the input tensors is in channels_last format, operator output +// should be in channels_last format +// +// Contiguous: +// Regardless of input tensors format, the output should be contiguous +// Tensor. +// +// ChannelsLast: +// Regardless of input tensors format, the output should be in channels_last +// format. + +namespace c10 { + +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +inline MemoryFormat get_contiguous_memory_format() { + return MemoryFormat::Contiguous; +} + +} // namespace c10 + +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) +using c10::get_contiguous_memory_format; +using c10::MemoryFormat; +HIDDEN_NAMESPACE_END(torch, headeronly) From c055ebebf9282d896a5c6d71813a493a238f3765 Mon Sep 17 00:00:00 2001 From: morrison-turnansky Date: Thu, 20 Nov 2025 00:58:41 +0000 Subject: [PATCH 078/230] Change NamedTupleVariable implementation to subclass UserDefinedTupleVariable (#167468) Continuation of work from previous PR, see link for context https://github.com/pytorch/pytorch/pull/161645#discussion_r2323094922 I think this PR is a step in that direction. There is probably some room for simplification. At a high level, the new class NamedTupleVariable handles methods that branch on structseq or the more dynamic subclasses of namedtuple, and falls back to UserDefinedTupleVariable otherwise. Please let me know what you think. @StrongerXi Pull Request resolved: https://github.com/pytorch/pytorch/pull/167468 Approved by: https://github.com/guilhermeleobas, https://github.com/StrongerXi, https://github.com/mlazos --- test/dynamo/test_functions.py | 4 +- torch/_dynamo/variables/dicts.py | 5 +- torch/_dynamo/variables/lists.py | 201 +++++++----------------- torch/_dynamo/variables/user_defined.py | 24 ++- 4 files changed, 85 insertions(+), 149 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index bac435cebfdfc..a6ba5bd0e8a20 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2053,7 +2053,7 @@ def test_namedtuple_defaults(a, b): return mytuple(tmp.x, tmp[1], tmp.xy + b) @make_test - def test_namedtuple_replace(a, b): + def test_namedtuple_replace_1(a, b): mytuple = collections.namedtuple("mytuple", ["x", "y"]) t = mytuple(a, b) t._replace(x=b) @@ -2109,7 +2109,7 @@ def test_namedtuple_user_methods(a, b): return mytuple.add(), mytuple.static_method(), mytuple.class_method() @make_test - def test_namedtuple_replace(a, b): + def test_namedtuple_replace_2(a, b): mytuple = FunctionTests.MyNamedTuple(a, b) replaced = mytuple._replace(first=b) return mytuple.first + mytuple.second + replaced.first + replaced.second diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 24cd5007da37d..b651c1d454bac 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -44,7 +44,6 @@ ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable -from .lists import ListIteratorVariable if TYPE_CHECKING: @@ -792,6 +791,8 @@ def call_method( self.call_method(tx, "update", args, kwargs) return self elif name == "__iter__": + from .lists import ListIteratorVariable + if self.source and not is_constant_source(self.source): tx.output.guard_on_key_order.add(self.source) return ListIteratorVariable( @@ -1462,6 +1463,8 @@ def call_method( if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": + from .lists import ListIteratorVariable + return ListIteratorVariable( self.view_items_vt, mutation_type=ValueMutationNew() ) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 2ac355bd53417..1959af40d7654 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -15,7 +15,6 @@ class that handles its unique behaviors while integrating with Dynamo's """ import collections -import inspect import operator import sys from collections.abc import Sequence @@ -39,7 +38,6 @@ class that handles its unique behaviors while integrating with Dynamo's get_fake_value, guard_if_dyn, iter_contains, - Lit, namedtuple_fields, odict_values, raise_args_mismatch, @@ -48,8 +46,8 @@ class that handles its unique behaviors while integrating with Dynamo's ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable -from .functions import UserFunctionVariable, UserMethodVariable from .iter import IteratorVariable +from .user_defined import UserDefinedTupleVariable if TYPE_CHECKING: @@ -1296,24 +1294,51 @@ def call_obj_hasattr( return variables.ConstantVariable.create(hasattr(torch.Size, name)) -class NamedTupleVariable(TupleVariable): +class NamedTupleVariable(UserDefinedTupleVariable): _nonvar_fields = { "tuple_cls", "dynamic_attributes", - *TupleVariable._nonvar_fields, + *UserDefinedTupleVariable._nonvar_fields, } def __init__( self, items: list[VariableTracker], - tuple_cls: type, + tuple_cls: type[tuple], dynamic_attributes: Optional[dict[str, VariableTracker]] = None, **kwargs: Any, ) -> None: - super().__init__(items, **kwargs) + tuple_vt = variables.TupleVariable( + items, mutation_type=kwargs.get("mutation_type", ValueMutationNew()) + ) + + # Create a dummy instance for method resolution + # This allows _maybe_get_baseclass_method to work correctly + fields = namedtuple_fields(tuple_cls) + num_fields = len(fields) + if tuple_cls.__module__ == "torch.return_types": + # Structseq: single iterable argument + dummy_value = tuple_cls([None] * num_fields) + else: + # Namedtuple: positional arguments + dummy_value = tuple_cls(*([None] * num_fields)) # type: ignore[arg-type] + + super().__init__( + value=dummy_value, + tuple_vt=tuple_vt, + init_args=None, + **kwargs, + ) + self.tuple_cls = tuple_cls + if len(self.tuple_cls.__mro__) < 3: + raise ValueError("NamedTuple should inherit from Tuple and Object.") self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} + @property + def items(self) -> list[VariableTracker]: + return self._tuple_vt.items + def is_namedtuple(self) -> bool: return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( getattr(self.tuple_cls, "_make", None) @@ -1325,17 +1350,7 @@ def is_structseq(self) -> bool: def fields(self) -> tuple[str, ...]: return namedtuple_fields(self.tuple_cls) - def debug_repr(self) -> str: - if self.is_structseq(): - # StructSequenceType(iterable) - return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) - # NamedTupleType(*iterable) - return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) - - def python_type(self) -> type: - return self.tuple_cls - - def as_python_constant(self) -> Any: + def as_python_constant(self): if self.is_structseq(): # StructSequenceType(iterable) result = self.python_type()([x.as_python_constant() for x in self.items]) @@ -1357,37 +1372,32 @@ def as_python_constant(self) -> Any: return result - def as_proxy(self) -> Any: - assert self.python_type() is not SizeVariable + def as_proxy(self): if self.is_structseq(): - # StructSequenceType(iterable) - return self.python_type()(self._as_proxy()) - # NamedTupleType(*iterable) - return self.python_type()(*self._as_proxy()) + return self.python_type()([x.as_proxy() for x in self._tuple_vt.items]) + return self.python_type()(*[x.as_proxy() for x in self._tuple_vt.items]) def reconstruct(self, codegen: "PyCodegen") -> None: - # Always reconstruct the NamedTuple normally first - # Constructors: - # StructSequenceType(iterable) - # NamedTupleType(*iterable) - # NamedTupleType._make(iterable) if self.is_structseq(): create_fn = self.tuple_cls else: create_fn = self.tuple_cls._make # type: ignore[attr-defined] + codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_const_unchecked(create_fn) ) ) - codegen.foreach(self.items) + codegen.foreach(self._tuple_vt.items) codegen.extend_output( [ - create_build_tuple(len(self.items)), + create_build_tuple(len(self._tuple_vt.items)), ] + create_call_function(1, False) ) + # Apply initial dynamic attributes after construction (if any) + # Runtime dynamic attributes are tracked via side effects system for name, value in self.dynamic_attributes.items(): codegen.dup_top() codegen(value) @@ -1395,19 +1405,6 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.store_attr(name) def _is_method_overridden(self, method_name: str) -> bool: - """Checks if a method is overridden in the NamedTuple subclass. - - Args: - method_name (str): The name of the method to check. - - Returns: - bool: True if the method is overridden in the subclass, False otherwise. - - Raises: - ValueError: If the NamedTuple class does not inherit from both Tuple and Object. - """ - if len(self.tuple_cls.__mro__) < 3: - raise ValueError("NamedTuple should inherit from Tuple and Object.") if getattr(self.tuple_cls, method_name, None) == getattr( self.tuple_cls.__mro__[-3], method_name, None ): @@ -1421,7 +1418,10 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if name == "__setattr__": + if self._is_method_overridden(name): + # Fall back to UserDefinedTupleVariable + return super().call_method(tx, name, args, kwargs) + elif name == "__setattr__": if kwargs or len(args) != 2: raise_args_mismatch( tx, @@ -1429,121 +1429,42 @@ def call_method( "2 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - attr, value = args - attr = attr.as_python_constant() + attr_var, value = args + attr = attr_var.as_python_constant() + if ( # structseq is immutable self.is_structseq() # namedtuple directly created by `collections.namedtuple` is immutable or self.tuple_cls.__bases__ == (tuple,) - # fields are immutable or attr in self.fields() ): raise_observed_exception(AttributeError, tx) - # Subclass of namedtuple type can have dynamic attributes - tx.output.side_effects.mutation(self) - if self.source: - tx.output.side_effects.store_attr(self, attr, value) - self.dynamic_attributes[attr] = value - return ConstantVariable.create(None) - elif name == "_replace": - # NamedTuple._replace should create a new instance with replaced fields - if args: - raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") - - # Get the field names for validation - fields = self.fields() - - # Start with current items (copy them) - new_items = list(self.items) - - # Replace fields specified in kwargs - for field_name, new_value in kwargs.items(): - if field_name not in fields: - raise_observed_exception( - ValueError, - tx, - args=[ - ConstantVariable.create( - f"Got unexpected field name: '{field_name}'" - ) - ], - ) - - # Replace the item at the field's index - field_index = fields.index(field_name) - new_items[field_index] = new_value - return NamedTupleVariable(new_items, self.tuple_cls) + result = self.method_setattr_standard(tx, attr_var, value) + # Also update self.dynamic_attributes + self.dynamic_attributes[attr] = value + return result return super().call_method(tx, name, args, kwargs) - def getitem_const( - self, tx: "InstructionTranslator", arg: VariableTracker - ) -> VariableTracker: - if isinstance(arg, SliceVariable): - # slicing a namedtuple produces a tuple - return TupleVariable( - self.items[arg.as_python_constant()], - source=None, - ) - return super().getitem_const(tx, arg) - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - def check_and_create_method() -> Optional[VariableTracker]: - method = inspect.getattr_static(self.tuple_cls, name, None) - if isinstance(method, classmethod): - # We need the unbounded cls method to avoid the inline __self__ - return UserMethodVariable( - method.__func__, - variables.UserDefinedClassVariable(self.tuple_cls), - ) - elif isinstance(method, staticmethod): - # pyrefly: ignore[bad-argument-type] - return UserFunctionVariable(method.__func__) - elif inspect.isfunction(method): - return UserMethodVariable(method, self) - else: - return None - - # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten. - if ( - name == "_replace" - and not self._is_method_overridden("_replace") - and not self._is_method_overridden("__getattr__") - ): - # Return a BuiltinVariable for the _replace method - # Get the actual _replace method from the tuple class - actual_replace_method = getattr(self.tuple_cls, "_replace", None) - if actual_replace_method: - from ..source import AttrSource - - source = AttrSource(self.source, name) if self.source else None - return variables.GetAttrVariable(self, name, source=source) - # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples) - return super().var_getattr(tx, name) + def python_type(self) -> type: + return self.tuple_cls + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "_fields": - result_source = NamedTupleFieldsSource(self.source) if self.source else None - return VariableTracker.build(tx, self.fields(), source=result_source) + source = NamedTupleFieldsSource(self.source) if self.source else None + return VariableTracker.build(tx, self.fields(), source=source) if name in self.dynamic_attributes: return self.dynamic_attributes[name] fields = self.fields() - if name not in fields: - method = check_and_create_method() - if not method: - return super().var_getattr(tx, name) - return method - return self.items[fields.index(name)] + if name in fields: + field_index = fields.index(name) + return self._tuple_vt.items[field_index] - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: - return variables.ConstantVariable.create( - name in self.dynamic_attributes or hasattr(self.tuple_cls, name) - ) + return super().var_getattr(tx, name) class SliceVariable(VariableTracker): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ec378a5512a01..0f145b21fd402 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -38,7 +38,7 @@ import types import warnings import weakref -from typing import TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from typing_extensions import is_typeddict import torch._dynamo.config @@ -96,7 +96,6 @@ ) from .base import raise_type_error_exc, ValueMutationNew, VariableTracker from .dicts import ConstDictVariable, DefaultDictVariable -from .lists import SizeVariable try: @@ -114,6 +113,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .lists import TupleVariable + def is_standard_setattr(val): return val in (object.__setattr__, BaseException.__setattr__) @@ -737,13 +738,16 @@ def deque_signature(iterable=None, maxlen=None): # Modify mutability of namedtuple for sourcelesss instantiations. from .base import AttributeMutationNew + from .lists import NamedTupleVariable # XXX Today come back to this - return variables.NamedTupleVariable( + return NamedTupleVariable( items, self.value, mutation_type=AttributeMutationNew() ) elif self.value is torch.Size: # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + from .lists import SizeVariable + return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) @@ -2210,10 +2214,15 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable): _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): + def __init__( + self, + value, + tuple_vt: Optional["TupleVariable"] = None, + init_args=None, + **kwargs, + ): super().__init__(value, init_args=init_args, **kwargs) - self._tuple_vt = tuple_vt - if self._tuple_vt is None: + if tuple_vt is None: assert self.source is None, ( "tuple_vt must be constructed by builder.py when source is present" ) @@ -2229,6 +2238,9 @@ def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): elems, mutation_type=ValueMutationNew() ) + else: + self._tuple_vt = tuple_vt + def call_method( self, tx, From 192b96e42b82b8e61bccef1c389e1f03a3c58356 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 01:59:23 +0000 Subject: [PATCH 079/230] Revert "[AOTI] Fix a GPU memory leak caused by reference circle (#168063)" This reverts commit cdca10b2753909d1eaeb096c4e91c47add3935b9. Reverted https://github.com/pytorch/pytorch/pull/168063 on behalf of https://github.com/yangw-dev due to Internal test breaks, contacted author to revert it and fix it test_codegen_int_array_var_fix_memory_leak, self.assertTrue(allocated_memory[1] == allocated_memory[2]) AssertionError: False is not true ([comment](https://github.com/pytorch/pytorch/pull/168063#issuecomment-3555419672)) --- test/inductor/test_aot_inductor.py | 44 ---------------------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 24 +----------- 2 files changed, 2 insertions(+), 66 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 69f5eb92b58ce..5f0447c32264e 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7437,50 +7437,6 @@ def forward(self, x): "RAIIAtenTensorHandle buf0(buf0_handle_restrided);" ).run(code) - def test_codegen_int_array_var_fix_memory_leak(self): - """ - Fix https://github.com/pytorch/pytorch/issues/167630 - """ - if self.device != "cuda": - raise unittest.SkipTest("test is only for cuda") - - def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): - layers = [] - d = in_dim - for _ in range(depth): - layers += [nn.Linear(d, hidden), nn.ReLU()] - d = hidden - layers += [nn.Linear(d, out_dim)] - return nn.Sequential(*layers) - - batch = 32 - in_dim = 2048 - hidden = 512 - out_dim = 10 - depth = 6 - - import gc - - allocated_memory = [] - for _ in range(3): - torch.cuda.reset_peak_memory_stats() - - model = make_mlp(in_dim, hidden, out_dim, depth).to(self.device) - example_inputs = (torch.randn(batch, in_dim, device=self.device),) - ep = torch.export.export( - model, - example_inputs, - ) - torch._inductor.aoti_compile_and_package(ep) - - del model, example_inputs, ep - torch.cuda.synchronize() - torch.cuda.empty_cache() - gc.collect() - allocated_memory.append(torch.cuda.memory_allocated()) - - self.assertTrue(allocated_memory[1] == allocated_memory[2]) - @unittest.skipIf(IS_MACOS, "might have no readelf on Mac") def test_libtorch_free_so(self): class Model(torch.nn.Module): diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 65d356dce0979..61a97fd740cbc 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -96,7 +96,6 @@ def __init__(self): self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] self._include_extra_header ) - self.codegen_int_array_var_cache = {} @staticmethod def create( @@ -1637,33 +1636,14 @@ def codegen_memory_format(self, memory_format): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" + @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, writeline: Callable[..., None], known_statically=False, graph=None, # for per-graph caching - ) -> str: - # Use id(graph) for caching to avoid circular references - cache_key = ( - int_array, - id(writeline), - known_statically, - id(graph) if graph else None, - ) - if cache_key not in self.codegen_int_array_var_cache: - self.codegen_int_array_var_cache[cache_key] = ( - self._codegen_int_array_var_impl(int_array, writeline, known_statically) - ) - - return self.codegen_int_array_var_cache[cache_key] - - def _codegen_int_array_var_impl( - self, - int_array: str, - writeline: Callable[..., None], - known_statically: bool, - ) -> str: + ): # Used for size/stride declaration # # Because the memory planning is done in two passes (see the implementation From 9e9e8fae2f492cf1e3e7ea3766e5dbae701cc60b Mon Sep 17 00:00:00 2001 From: Divyansh Khanna Date: Thu, 20 Nov 2025 02:04:26 +0000 Subject: [PATCH 080/230] [torch/utils/data] Update CODEOWNERS (#168172) Adding Lavender to the list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168172 Approved by: https://github.com/ramanishsingh, https://github.com/aelavender --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 137031066090e..7516c4ad7ec06 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -135,7 +135,7 @@ torch/profiler/ @sraikund16 test/functorch/test_aotdispatch.py @ezyang @Chillee # Dataloader -torch/utils/data/ @divyanshk @ramanishsingh @scotts +torch/utils/data/ @divyanshk @ramanishsingh @scotts @aelavender # hipify torch/utils/hipify/ @jeffdaily @jithunnair-amd From bb4009a28d727012ca7c8f105f5acc6cbe56c0ce Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Thu, 20 Nov 2025 02:19:29 +0000 Subject: [PATCH 081/230] [Inductor] Naive foreach autotune support (#162053) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code. Before: triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 | triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 | triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 | After: triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 | triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 | triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 | num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053 Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily Co-authored-by: Nichols A. Romero --- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 615913933326e..41b12d05cd32e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -627,7 +627,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d5851eeceeb24..d59d9bbf3fee4 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -3621,13 +3621,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, From c3320ed9cb14df1f39dae5f097e47b425e52afcf Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Wed, 19 Nov 2025 14:57:24 -0800 Subject: [PATCH 082/230] [3.14] Add python version adjustment for frame count changes (#168190) ``sys.getrefcount(lib)`` got impacted due to python3.13 optimization. ``sys.getrefcount(lib._op_impls)`` and others remain the same. Test plan: ``python test/test_python_dispatch.py TestPythonRegistration.test_finalizer`` in local ``python=3.14`` env Pull Request resolved: https://github.com/pytorch/pytorch/pull/168190 Approved by: https://github.com/williamwen42, https://github.com/azahed98 --- test/test_python_dispatch.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 515ce435b72a7..359236602e61e 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -299,8 +299,12 @@ def test_finalizer(self): lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 lib.define("foo123(Tensor x) -> Tensor") - # 1 for `lib`, 1 for sys.getrefcount - self.assertEqual(sys.getrefcount(lib), 2) + # 1 for `lib`, 1 for sys.getrefcount' for previous python version (<=3.12) + # In Python 3.13+, sys.getrefcount() was optimized to not create + # a temporary reference, so expected counts are 1 less than before + expected_refcount = 1 if sys.version_info >= (3, 14) else 2 + self.assertEqual(sys.getrefcount(lib), expected_refcount) + # We gained an additional reference that gets cleared when the finalizer runs self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1) # 1 for `lib` @@ -318,7 +322,7 @@ def foo123(x): saved_op_impls = lib._op_impls # del will definitely work if the following passes - self.assertEqual(sys.getrefcount(lib), 2) + self.assertEqual(sys.getrefcount(lib), expected_refcount) del lib # 1 for saved_op_impls @@ -326,7 +330,7 @@ def foo123(x): # This function should be the last user of lib._op_impls: # - lib should not have a reference anymore (it was del'ed) # - lib's finalizer should not have a reference anymore - self.assertEqual(sys.getrefcount(saved_op_impls), 2) + self.assertEqual(sys.getrefcount(saved_op_impls), expected_refcount) self.assertTrue(key not in torch.library._impls) From 7a064ed3eafa43f17412d434b395240c727b3000 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 03:32:02 +0000 Subject: [PATCH 083/230] Revert "Change NamedTupleVariable implementation to subclass UserDefinedTupleVariable (#167468)" This reverts commit c055ebebf9282d896a5c6d71813a493a238f3765. Reverted https://github.com/pytorch/pytorch/pull/167468 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167468#issuecomment-3555613000)) --- test/dynamo/test_functions.py | 4 +- torch/_dynamo/variables/dicts.py | 5 +- torch/_dynamo/variables/lists.py | 201 +++++++++++++++++------- torch/_dynamo/variables/user_defined.py | 24 +-- 4 files changed, 149 insertions(+), 85 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index a6ba5bd0e8a20..bac435cebfdfc 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2053,7 +2053,7 @@ def test_namedtuple_defaults(a, b): return mytuple(tmp.x, tmp[1], tmp.xy + b) @make_test - def test_namedtuple_replace_1(a, b): + def test_namedtuple_replace(a, b): mytuple = collections.namedtuple("mytuple", ["x", "y"]) t = mytuple(a, b) t._replace(x=b) @@ -2109,7 +2109,7 @@ def test_namedtuple_user_methods(a, b): return mytuple.add(), mytuple.static_method(), mytuple.class_method() @make_test - def test_namedtuple_replace_2(a, b): + def test_namedtuple_replace(a, b): mytuple = FunctionTests.MyNamedTuple(a, b) replaced = mytuple._replace(first=b) return mytuple.first + mytuple.second + replaced.first + replaced.second diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index b651c1d454bac..24cd5007da37d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -44,6 +44,7 @@ ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable +from .lists import ListIteratorVariable if TYPE_CHECKING: @@ -791,8 +792,6 @@ def call_method( self.call_method(tx, "update", args, kwargs) return self elif name == "__iter__": - from .lists import ListIteratorVariable - if self.source and not is_constant_source(self.source): tx.output.guard_on_key_order.add(self.source) return ListIteratorVariable( @@ -1463,8 +1462,6 @@ def call_method( if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": - from .lists import ListIteratorVariable - return ListIteratorVariable( self.view_items_vt, mutation_type=ValueMutationNew() ) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 1959af40d7654..2ac355bd53417 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -15,6 +15,7 @@ class that handles its unique behaviors while integrating with Dynamo's """ import collections +import inspect import operator import sys from collections.abc import Sequence @@ -38,6 +39,7 @@ class that handles its unique behaviors while integrating with Dynamo's get_fake_value, guard_if_dyn, iter_contains, + Lit, namedtuple_fields, odict_values, raise_args_mismatch, @@ -46,8 +48,8 @@ class that handles its unique behaviors while integrating with Dynamo's ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable +from .functions import UserFunctionVariable, UserMethodVariable from .iter import IteratorVariable -from .user_defined import UserDefinedTupleVariable if TYPE_CHECKING: @@ -1294,51 +1296,24 @@ def call_obj_hasattr( return variables.ConstantVariable.create(hasattr(torch.Size, name)) -class NamedTupleVariable(UserDefinedTupleVariable): +class NamedTupleVariable(TupleVariable): _nonvar_fields = { "tuple_cls", "dynamic_attributes", - *UserDefinedTupleVariable._nonvar_fields, + *TupleVariable._nonvar_fields, } def __init__( self, items: list[VariableTracker], - tuple_cls: type[tuple], + tuple_cls: type, dynamic_attributes: Optional[dict[str, VariableTracker]] = None, **kwargs: Any, ) -> None: - tuple_vt = variables.TupleVariable( - items, mutation_type=kwargs.get("mutation_type", ValueMutationNew()) - ) - - # Create a dummy instance for method resolution - # This allows _maybe_get_baseclass_method to work correctly - fields = namedtuple_fields(tuple_cls) - num_fields = len(fields) - if tuple_cls.__module__ == "torch.return_types": - # Structseq: single iterable argument - dummy_value = tuple_cls([None] * num_fields) - else: - # Namedtuple: positional arguments - dummy_value = tuple_cls(*([None] * num_fields)) # type: ignore[arg-type] - - super().__init__( - value=dummy_value, - tuple_vt=tuple_vt, - init_args=None, - **kwargs, - ) - + super().__init__(items, **kwargs) self.tuple_cls = tuple_cls - if len(self.tuple_cls.__mro__) < 3: - raise ValueError("NamedTuple should inherit from Tuple and Object.") self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} - @property - def items(self) -> list[VariableTracker]: - return self._tuple_vt.items - def is_namedtuple(self) -> bool: return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( getattr(self.tuple_cls, "_make", None) @@ -1350,7 +1325,17 @@ def is_structseq(self) -> bool: def fields(self) -> tuple[str, ...]: return namedtuple_fields(self.tuple_cls) - def as_python_constant(self): + def debug_repr(self) -> str: + if self.is_structseq(): + # StructSequenceType(iterable) + return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) + # NamedTupleType(*iterable) + return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) + + def python_type(self) -> type: + return self.tuple_cls + + def as_python_constant(self) -> Any: if self.is_structseq(): # StructSequenceType(iterable) result = self.python_type()([x.as_python_constant() for x in self.items]) @@ -1372,32 +1357,37 @@ def as_python_constant(self): return result - def as_proxy(self): + def as_proxy(self) -> Any: + assert self.python_type() is not SizeVariable if self.is_structseq(): - return self.python_type()([x.as_proxy() for x in self._tuple_vt.items]) - return self.python_type()(*[x.as_proxy() for x in self._tuple_vt.items]) + # StructSequenceType(iterable) + return self.python_type()(self._as_proxy()) + # NamedTupleType(*iterable) + return self.python_type()(*self._as_proxy()) def reconstruct(self, codegen: "PyCodegen") -> None: + # Always reconstruct the NamedTuple normally first + # Constructors: + # StructSequenceType(iterable) + # NamedTupleType(*iterable) + # NamedTupleType._make(iterable) if self.is_structseq(): create_fn = self.tuple_cls else: create_fn = self.tuple_cls._make # type: ignore[attr-defined] - codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_const_unchecked(create_fn) ) ) - codegen.foreach(self._tuple_vt.items) + codegen.foreach(self.items) codegen.extend_output( [ - create_build_tuple(len(self._tuple_vt.items)), + create_build_tuple(len(self.items)), ] + create_call_function(1, False) ) - # Apply initial dynamic attributes after construction (if any) - # Runtime dynamic attributes are tracked via side effects system for name, value in self.dynamic_attributes.items(): codegen.dup_top() codegen(value) @@ -1405,6 +1395,19 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.store_attr(name) def _is_method_overridden(self, method_name: str) -> bool: + """Checks if a method is overridden in the NamedTuple subclass. + + Args: + method_name (str): The name of the method to check. + + Returns: + bool: True if the method is overridden in the subclass, False otherwise. + + Raises: + ValueError: If the NamedTuple class does not inherit from both Tuple and Object. + """ + if len(self.tuple_cls.__mro__) < 3: + raise ValueError("NamedTuple should inherit from Tuple and Object.") if getattr(self.tuple_cls, method_name, None) == getattr( self.tuple_cls.__mro__[-3], method_name, None ): @@ -1418,10 +1421,7 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if self._is_method_overridden(name): - # Fall back to UserDefinedTupleVariable - return super().call_method(tx, name, args, kwargs) - elif name == "__setattr__": + if name == "__setattr__": if kwargs or len(args) != 2: raise_args_mismatch( tx, @@ -1429,42 +1429,121 @@ def call_method( "2 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - attr_var, value = args - attr = attr_var.as_python_constant() - + attr, value = args + attr = attr.as_python_constant() if ( # structseq is immutable self.is_structseq() # namedtuple directly created by `collections.namedtuple` is immutable or self.tuple_cls.__bases__ == (tuple,) + # fields are immutable or attr in self.fields() ): raise_observed_exception(AttributeError, tx) - - result = self.method_setattr_standard(tx, attr_var, value) - # Also update self.dynamic_attributes + # Subclass of namedtuple type can have dynamic attributes + tx.output.side_effects.mutation(self) + if self.source: + tx.output.side_effects.store_attr(self, attr, value) self.dynamic_attributes[attr] = value - return result + return ConstantVariable.create(None) + elif name == "_replace": + # NamedTuple._replace should create a new instance with replaced fields + if args: + raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") + + # Get the field names for validation + fields = self.fields() + + # Start with current items (copy them) + new_items = list(self.items) + + # Replace fields specified in kwargs + for field_name, new_value in kwargs.items(): + if field_name not in fields: + raise_observed_exception( + ValueError, + tx, + args=[ + ConstantVariable.create( + f"Got unexpected field name: '{field_name}'" + ) + ], + ) + + # Replace the item at the field's index + field_index = fields.index(field_name) + new_items[field_index] = new_value + + return NamedTupleVariable(new_items, self.tuple_cls) return super().call_method(tx, name, args, kwargs) - def python_type(self) -> type: - return self.tuple_cls + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + if isinstance(arg, SliceVariable): + # slicing a namedtuple produces a tuple + return TupleVariable( + self.items[arg.as_python_constant()], + source=None, + ) + return super().getitem_const(tx, arg) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def check_and_create_method() -> Optional[VariableTracker]: + method = inspect.getattr_static(self.tuple_cls, name, None) + if isinstance(method, classmethod): + # We need the unbounded cls method to avoid the inline __self__ + return UserMethodVariable( + method.__func__, + variables.UserDefinedClassVariable(self.tuple_cls), + ) + elif isinstance(method, staticmethod): + # pyrefly: ignore[bad-argument-type] + return UserFunctionVariable(method.__func__) + elif inspect.isfunction(method): + return UserMethodVariable(method, self) + else: + return None + + # Avoid UserMethodVariable fallback precisely when methods NamedTuple methods have not been overwritten. + if ( + name == "_replace" + and not self._is_method_overridden("_replace") + and not self._is_method_overridden("__getattr__") + ): + # Return a BuiltinVariable for the _replace method + # Get the actual _replace method from the tuple class + actual_replace_method = getattr(self.tuple_cls, "_replace", None) + if actual_replace_method: + from ..source import AttrSource + + source = AttrSource(self.source, name) if self.source else None + return variables.GetAttrVariable(self, name, source=source) + # Fallback if _replace doesn't exist (shouldn't happen for proper NamedTuples) + return super().var_getattr(tx, name) - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "_fields": - source = NamedTupleFieldsSource(self.source) if self.source else None - return VariableTracker.build(tx, self.fields(), source=source) + result_source = NamedTupleFieldsSource(self.source) if self.source else None + return VariableTracker.build(tx, self.fields(), source=result_source) if name in self.dynamic_attributes: return self.dynamic_attributes[name] fields = self.fields() - if name in fields: - field_index = fields.index(name) - return self._tuple_vt.items[field_index] + if name not in fields: + method = check_and_create_method() + if not method: + return super().var_getattr(tx, name) + return method + return self.items[fields.index(name)] - return super().var_getattr(tx, name) + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + return variables.ConstantVariable.create( + name in self.dynamic_attributes or hasattr(self.tuple_cls, name) + ) class SliceVariable(VariableTracker): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0f145b21fd402..ec378a5512a01 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -38,7 +38,7 @@ import types import warnings import weakref -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from typing_extensions import is_typeddict import torch._dynamo.config @@ -96,6 +96,7 @@ ) from .base import raise_type_error_exc, ValueMutationNew, VariableTracker from .dicts import ConstDictVariable, DefaultDictVariable +from .lists import SizeVariable try: @@ -113,8 +114,6 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator - from .lists import TupleVariable - def is_standard_setattr(val): return val in (object.__setattr__, BaseException.__setattr__) @@ -738,16 +737,13 @@ def deque_signature(iterable=None, maxlen=None): # Modify mutability of namedtuple for sourcelesss instantiations. from .base import AttributeMutationNew - from .lists import NamedTupleVariable # XXX Today come back to this - return NamedTupleVariable( + return variables.NamedTupleVariable( items, self.value, mutation_type=AttributeMutationNew() ) elif self.value is torch.Size: # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) - from .lists import SizeVariable - return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) @@ -2214,15 +2210,10 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable): _nonvar_fields = UserDefinedObjectVariable._nonvar_fields - def __init__( - self, - value, - tuple_vt: Optional["TupleVariable"] = None, - init_args=None, - **kwargs, - ): + def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): super().__init__(value, init_args=init_args, **kwargs) - if tuple_vt is None: + self._tuple_vt = tuple_vt + if self._tuple_vt is None: assert self.source is None, ( "tuple_vt must be constructed by builder.py when source is present" ) @@ -2238,9 +2229,6 @@ def __init__( elems, mutation_type=ValueMutationNew() ) - else: - self._tuple_vt = tuple_vt - def call_method( self, tx, From 34bb9c4f5d06f9370a954ad377117ceb41e5e547 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Thu, 20 Nov 2025 03:32:30 +0000 Subject: [PATCH 084/230] [AOTI] Fix unknown constant type for device-moved constants (#168138) ### Issue When we have the flag `use_runtime_constant_folding=False`, if we move a constant (buffer or parameter) to a different device, we'll generate a new buf/param during compilation time with a new name where the new device (+counter) will be appended, e.g.: ``` # noramlised name orig buf: model_x_submodule_y_buf0_name moved buf: model_x_submodule_y_buf0_name_cpu0 ``` However, these new names are not registered in `V.graph.constants`. During cpp wrapper code generation, they won't be recognised, hence will get the `ConstantType::Unknown`. It'll cause issues for model loading during runtime. https://github.com/pytorch/pytorch/blob/b8a3165d28b672ac6d84128e66265bf471b92a55/torch/_inductor/codegen/cpp_wrapper_cpu.py#L851-L862 ### Fix After we do the new const name allocation following device movement, check if the original constant is any recognised buffer or parameter, if so, register the new ones with graph as well. ### Failed Unittest before the patch ``` =========================================================================== short test summary info ============================================================================ FAILED [3.9054s] test/inductor/test_aot_inductor.py::AOTInductorTestABICompatibleCpu::test_device_moved_constant_cpu - RuntimeError: Expected to not find "torch::aot_inductor::ConstantType::Unknown" but found it FAILED [3.1852s] test/inductor/test_aot_inductor.py::AOTInductorTestABICompatibleGpu::test_device_moved_constant_cuda - RuntimeError: Expected to not find "torch::aot_inductor::ConstantType::Unknown" but found it ================================================================ 2 failed, 1 skipped, 916 deselected in 11.81s ================================================================= ``` cc. @muchulee8 @desertfire Pull Request resolved: https://github.com/pytorch/pytorch/pull/168138 Approved by: https://github.com/muchulee8 --- test/inductor/test_aot_inductor.py | 43 ++++++++++++++++++++++++++++++ torch/_inductor/graph.py | 26 +++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5f0447c32264e..56700bdac835f 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -671,6 +671,49 @@ def forward(self, x): code ) + @requires_gpu + def test_device_moved_constant(self): + # testing both directions + device_movements = [ + (torch.device(type=GPU_TYPE, index=0), torch.device("cpu")), + (torch.device("cpu"), torch.device(type=GPU_TYPE, index=0)), + ] + + class Model(torch.nn.Module): + def __init__(self, from_device): + super().__init__() + self.register_buffer("_buf", torch.randn(6, 7, device=from_device)) + self._param = torch.nn.Parameter( + torch.rand(6, 7, device=from_device), requires_grad=False + ) + + def forward(self, x): + to_device = x.device + moved_buf = self._buf.to(to_device) + moved_param = self._param.to(to_device) + return moved_buf, moved_param + + with config.patch( + { + "aot_inductor.use_runtime_constant_folding": False, + } + ): + for from_device, to_device in device_movements: + model = Model(from_device) + example_inputs = (torch.randn(6, 7, device=to_device),) + _, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + FileCheck().check_not("torch::aot_inductor::ConstantType::Unknown").run( + code + ) + FileCheck().check_count( + "torch::aot_inductor::ConstantType::Buffer", 2, exactly=True + ).run(code) + FileCheck().check_count( + "torch::aot_inductor::ConstantType::Parameter", 2, exactly=True + ).run(code) + def test_subclasses(self): device_to_init = self.device diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 1eaab41130675..517d6c3e39d1b 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1114,11 +1114,35 @@ def constant_name(self, name: str, device_override: Optional[torch.device]) -> s with torch.utils._python_dispatch._disable_current_modes(): # caller might have OrderedSet fake tensor mode which will create a fake tensor # when calling .to, so unset modes here - return self.allocate_non_dup_const_name( + non_dup_const_name = self.allocate_non_dup_const_name( f"{name}_{device_override.type}{device_override.index or 0}", self.constants[name].to(device_override), ) + assert non_dup_const_name in self.constants, ( + f"{non_dup_const_name} should be in V.graph.constants already" + ) + + # register device-copied buffers and parameters to graph as well + # to codegen correct torch::aot_inductor::ConstantType for them rather than `Unknown` + if any( + name == normalize_name(buffer_name) + for buffer_name in self.named_buffers + ): + self.named_buffers[non_dup_const_name] = self.constants[ + non_dup_const_name + ] + + if any( + name == normalize_name(param_name) + for param_name in self.named_parameters + ): + self.named_parameters[non_dup_const_name] = self.constants[ + non_dup_const_name + ] + + return non_dup_const_name + # pyrefly: ignore [bad-override] def placeholder( self, From 9177d6ec23aec3fb83ceda5d96dfcd077b463b9f Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Thu, 20 Nov 2025 04:25:42 +0000 Subject: [PATCH 085/230] [ROCm][CI] Add ROCm noble image caching to docker-cache-rocm.yml (#168202) Needed due to https://github.com/pytorch/pytorch/pull/168162, which means rocm-mi300.yml (uses noble images) and periodic-rocm-mi300.yml (uses jammy images) will both run on the new MI3xx capacity. Also re-enable `workflow_dispatch` with inputs required to run successfully Pull Request resolved: https://github.com/pytorch/pytorch/pull/168202 Approved by: https://github.com/jeffdaily --- .github/workflows/docker-cache-rocm.yml | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index c973656018944..380b8c2d1e257 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -6,9 +6,19 @@ on: branches: [main, release] types: - completed + workflow_dispatch: + inputs: + branch: + type: string + description: Branch corresponding to the docker images being cached + required: true + run_id: + type: string + description: Workflow run id to pull artifacts from + required: true concurrency: - group: ${{ github.workflow }}-${{ github.event.workflow_run.head_branch }} + group: ${{ github.workflow }}-${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} cancel-in-progress: true permissions: @@ -29,7 +39,7 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4.1.7 with: - run-id: ${{ github.event.workflow_run.id }} + run-id: ${{ github.event.workflow_run.id || github.event.inputs.run_id }} path: ./docker-builds-artifacts merge-multiple: true github-token: ${{ secrets.GITHUB_TOKEN }} @@ -49,9 +59,8 @@ jobs: matrix: runner: [linux.rocm.gfx942.docker-cache] docker-image: [ - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}" - #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", - #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", + "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", + "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}" #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" ] runs-on: "${{ matrix.runner }}" @@ -91,7 +100,7 @@ jobs: docker_image_tag=${{ matrix.docker-image }} docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":" docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-" - ref_name=${{ github.event.workflow_run.head_branch }} + ref_name=${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} if [[ $ref_name =~ "release/" ]]; then ref_suffix="release" elif [[ $ref_name == "main" ]]; then From 9bca3c14d7e058fb4e30f6b82a3578aa355d77dc Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Thu, 20 Nov 2025 04:38:13 +0000 Subject: [PATCH 086/230] [ROCm][CI] Expand trunk.yml coverage for ROCm (#168162) We are expanding the test coverage on pre-submit (PR-based) trunk.yml runs for ROCm to the full list of unit tests. Consequently, we are swapping the labels (CSPs) for the rocm-mi300.yml and periodic-rocm-mi300.yml workflows to balance capacity concerns. We will be disabling the shadow workflow trunk-rocm-mi300.yml as it is not required due to this PR anymore. Fixes https://github.com/pytorch/pytorch/issues/166108 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168162 Approved by: https://github.com/jeffdaily --- .github/workflows/periodic-rocm-mi300.yml | 6 +++--- .github/workflows/rocm-mi300.yml | 12 ++++++------ .github/workflows/trunk.yml | 13 +++++++++---- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index ce68ee8bc8e03..12a20a2993f8d 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -60,9 +60,9 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, ]} secrets: inherit diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index d20b37be20876..99059a1ff857c 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -48,12 +48,12 @@ jobs: docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, ]} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index eeba4c08a0c68..d458bde5f9d30 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -203,9 +203,15 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, ]} secrets: inherit @@ -223,7 +229,6 @@ jobs: build-environment: linux-jammy-rocm-py3.10 docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} - tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" secrets: inherit inductor-build: From c614128a0c1277aa7e708cd6a4b39981ee27c85c Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 19 Nov 2025 18:58:27 -0800 Subject: [PATCH 087/230] [DTensor] support Replicate -> Partial("avg") + support distribute_tensor with Partial placements (#168133) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168133 Approved by: https://github.com/ezyang --- .../tensor/debug/test_debug_mode.py | 46 +++++++++---------- test/distributed/tensor/test_api.py | 12 ++++- test/distributed/tensor/test_dtensor_ops.py | 2 +- test/distributed/tensor/test_pointwise_ops.py | 25 ++++++++++ torch/distributed/tensor/_api.py | 5 ++ torch/distributed/tensor/_ops/_math_ops.py | 10 ++++ torch/distributed/tensor/placement_types.py | 24 ++++++---- 7 files changed, 89 insertions(+), 35 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 0b7acebbd8aac..5d4db74b6a929 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -76,7 +76,7 @@ def test_debug_mode_mm(self): _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] - (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P(sum) aten::sum(dt$6: f32[8, 32]| S(0)) aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) @@ -179,8 +179,8 @@ def test_debug_mode_backward(self): (dt: f32[8, 8]| S(0)) aten::sum(dt: f32[8, 8]| S(0)) aten::sum(t: f32[1, 8]) - torch._tensor.backward(dt: f32[]| P, gradient=None, retain_graph=None, create_graph=False, inputs=None) - aten::ones_like(dt: f32[]| P, pin_memory=False, memory_format=torch.preserve_format) + torch._tensor.backward(dt: f32[]| P(sum), gradient=None, retain_graph=None, create_graph=False, inputs=None) + aten::ones_like(dt: f32[]| P(sum), pin_memory=False, memory_format=torch.preserve_format) aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format) aten::expand(dt: f32[]| R, [8, 8]) aten::expand(t: f32[], [8, 8]) @@ -189,9 +189,9 @@ def test_debug_mode_backward(self): aten::clone(t: f32[8, 1]) aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu) redistribute_input(t: f32[8, 8], trace: R->S(0)) - aten::detach(t: f32[8, 1]) aten::split.Tensor(t: f32[8, 8], 1) aten::clone(t: f32[1, 8]) + aten::detach(t: f32[8, 1]) aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu) aten::detach(t: f32[1, 8])""", ) @@ -253,38 +253,38 @@ def test_debug_mode_einsum(self): self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| PR, dt: f32[8, 4, 4]| RP) - aten::unsqueeze(dt: f32[16, 6, 8]| PR, 3) + torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| P(sum)R, dt: f32[8, 4, 4]| RP(sum)) + aten::unsqueeze(dt: f32[16, 6, 8]| P(sum)R, 3) aten::unsqueeze(t: f32[16, 6, 8], 3) - aten::unsqueeze(dt: f32[16, 6, 8, 1]| PR, 4) + aten::unsqueeze(dt: f32[16, 6, 8, 1]| P(sum)R, 4) aten::unsqueeze(t: f32[16, 6, 8, 1], 4) - aten::permute(dt: f32[16, 6, 8, 1, 1]| PR, [0, 1, 3, 4, 2]) + aten::permute(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [0, 1, 3, 4, 2]) aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2]) - aten::unsqueeze(dt: f32[8, 4, 4]| RP, 3) + aten::unsqueeze(dt: f32[8, 4, 4]| RP(sum), 3) aten::unsqueeze(t: f32[8, 4, 4], 3) - aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP, 4) + aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP(sum), 4) aten::unsqueeze(t: f32[8, 4, 4, 1], 4) - aten::permute(dt: f32[8, 4, 4, 1, 1]| RP, [3, 4, 1, 2, 0]) + aten::permute(dt: f32[8, 4, 4, 1, 1]| RP(sum), [3, 4, 1, 2, 0]) aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0]) - aten::permute(dt: f32[16, 6, 1, 1, 8]| PR, [0, 1, 4, 2, 3]) + aten::permute(dt: f32[16, 6, 1, 1, 8]| P(sum)R, [0, 1, 4, 2, 3]) aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3]) - aten::view(dt: f32[16, 6, 8, 1, 1]| PR, [1, 96, 8]) + aten::view(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [1, 96, 8]) aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8]) - aten::permute(dt: f32[1, 1, 4, 4, 8]| RP, [4, 2, 3, 0, 1]) + aten::permute(dt: f32[1, 1, 4, 4, 8]| RP(sum), [4, 2, 3, 0, 1]) aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1]) - aten::view(dt: f32[8, 4, 4, 1, 1]| RP, [1, 8, 16]) + aten::view(dt: f32[8, 4, 4, 1, 1]| RP(sum), [1, 8, 16]) aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16]) - aten::bmm(dt: f32[1, 96, 8]| PR, dt: f32[1, 8, 16]| RP) - redistribute_input(0, PR -> S(2)[0]S(2)[1]) - redistribute_input(t: f32[1, 96, 8], trace: PR->S(2)R->S(2)[0]S(2)[1]) + aten::bmm(dt: f32[1, 96, 8]| P(sum)R, dt: f32[1, 8, 16]| RP(sum)) + redistribute_input(0, P(sum)R -> S(2)[0]S(2)[1]) + redistribute_input(t: f32[1, 96, 8], trace: P(sum)R->S(2)R->S(2)[0]S(2)[1]) aten::chunk(t: f32[1, 96, 8], 4, 2) aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]']) _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1) _c10d_functional::wait_tensor(t: f32[1, 96, 2]) aten::chunk(t: f32[1, 96, 2], 2, 2) aten::clone(t: f32[1, 96, 1]) - redistribute_input(1, RP -> S(1)[0]S(1)[1]) - redistribute_input(t: f32[1, 8, 16], trace: RP->S(1)P->S(1)[0]S(1)[1]) + redistribute_input(1, RP(sum) -> S(1)[0]S(1)[1]) + redistribute_input(t: f32[1, 8, 16], trace: RP(sum)->S(1)P(sum)->S(1)[0]S(1)[1]) aten::chunk(t: f32[1, 8, 16], 4, 1) aten::clone(t: f32[1, 2, 16]) aten::chunk(t: f32[1, 2, 16], 2, 1) @@ -292,11 +292,11 @@ def test_debug_mode_einsum(self): _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3) _c10d_functional::wait_tensor(t: f32[1, 1, 16]) aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16]) - aten::view(dt: f32[1, 96, 16]| PP, [16, 6, 1, 4, 4]) + aten::view(dt: f32[1, 96, 16]| P(sum)P(sum), [16, 6, 1, 4, 4]) aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4]) - aten::permute(dt: f32[16, 6, 1, 4, 4]| PP, [0, 1, 3, 4, 2]) + aten::permute(dt: f32[16, 6, 1, 4, 4]| P(sum)P(sum), [0, 1, 3, 4, 2]) aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2]) - aten::view(dt: f32[16, 6, 4, 4, 1]| PP, [16, 6, 4, 4]) + aten::view(dt: f32[16, 6, 4, 4, 1]| P(sum)P(sum), [16, 6, 4, 4]) aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""", ) diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py index e1790f4829907..12897ee822e87 100644 --- a/test/distributed/tensor/test_api.py +++ b/test/distributed/tensor/test_api.py @@ -79,7 +79,13 @@ def test_distribute_tensor_rank(self): dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec) self.assertEqual(dist_tensor.placements[0].dim, 1) - placement_combs = [[Shard(0)], [Shard(1)], [Replicate()]] + placement_combs = [ + [Shard(0)], + [Shard(1)], + [Replicate()], + [Partial(reduce_op="sum")], + [Partial(reduce_op="avg")], + ] if not self.is_local_tensor_enabled: # test src_data_rank == 1 @@ -125,6 +131,10 @@ def test_distribute_tensor_errors(self): shard_spec = [Shard(0)] distribute_tensor(tensor_to_distribute, device_mesh, shard_spec) + with self.assertRaisesRegex(ValueError, "conversion is not supported"): + new_spec = [Replicate(), Partial(reduce_op="prod")] + distribute_tensor(tensor_to_distribute, device_mesh, new_spec) + with self.assertRaisesRegex(RuntimeError, "distribute leaf tensor"): shard_spec = [Shard(0)] global_tensor = torch.randn(*tensor_shape, requires_grad=True) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index 5880efb3734bf..c3eb791fd0e41 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -725,7 +725,7 @@ def run_mean(self): self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim)) if is_evenly_shardable: - self.assertTrue("P->R" in debug_mode.debug_string()) + self.assertTrue("P(avg)->R" in debug_mode.debug_string()) else: self.assertTrue("S(0)->R" in debug_mode.debug_string()) diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index d2c4e7dea06b4..9d35e10f24ba8 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -148,6 +148,30 @@ def test_partial_add(self): d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) + def test_partial_replicate_add(self): + device_mesh = self.build_device_mesh() + comm_mode = CommDebugMode() + + for reduce_op in ("sum", "avg"): + d_1 = DTensor.from_local( + torch.rand(2, 2), + device_mesh, + [Partial(reduce_op=reduce_op)], + ) + d_2 = DTensor.from_local( + torch.rand(2, 1), + device_mesh, + [Replicate()], + run_check=True, + ) + + with comm_mode: + d_3 = d_1 + d_2 + + self.assertEqual(comm_mode.get_total_counts(), 0) + self.assertEqual(d_3.placements, (Partial(reduce_op=reduce_op),)) + self.assertEqual(d_3.full_tensor(), d_1.full_tensor() + d_2.full_tensor()) + def test_activations(self): device_mesh = self.build_device_mesh() self._run_sharded_elementwise_ops( @@ -247,6 +271,7 @@ def test_dropout_backward(self): ), ) + @skip_unless_torch_gpu def test_dropout_errors(self): device_mesh = self.build_device_mesh() with self.assertRaisesRegex(RuntimeError, "supported"): diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index f21ef72533658..3946f9249d0de 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -818,6 +818,11 @@ def distribute_tensor( local_tensor = Replicate._make_replicate_tensor( local_tensor, device_mesh, idx, src_data_rank ) + elif isinstance(placement, Partial): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + local_tensor = placement._partition_value(local_tensor, device_mesh, idx) else: raise RuntimeError( f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 63a352adc8dc7..7ccca23c0dab5 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -163,6 +163,16 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return 1 + hash(self.norm_type) + def __repr__(self) -> str: + """ + machine readable representation of the _NormPartial placement + """ + return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})" + + def __str__(self) -> str: + """human readable representation of the _NormPartial placement""" + return f"_NormP({self.reduce_op}, {self.norm_type})" + def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]: if dims_arg is None: diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 2cf6e572dcdf7..726abc5971376 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -816,14 +816,18 @@ def _partition_value( # Partial placement contract #3: # _partition_value: partition the value of a replicated tensor on the mesh dimension - # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a division operation - # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation - # TODO: if the reduce_op is min/max, etc. the _partition_value should be a - # different operation - assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + # _partition_value is the conjugate operation of _reduce_value, e.g. + # - _partition_value on a sum reduce op is just a division operation + # - _reduce_value on a sum reduce op would just be a sum(allreduce) operation num_chunks = mesh.size(mesh_dim=mesh_dim) - return tensor / num_chunks + if self.reduce_op == "sum": + return tensor / num_chunks + elif self.reduce_op in ("avg", "min", "max"): + return tensor + else: + raise ValueError( + f"Replicate to Partial({self.reduce_op}) conversion is not supported." + ) def __hash__(self) -> int: return 1 + hash(self.reduce_op) @@ -838,7 +842,7 @@ def __str__(self) -> str: """ human readable representation of the Partial placement """ - return "P" + return f"P({self.reduce_op})" # We keep the old _Partial name for a while for BC reason @@ -982,10 +986,10 @@ def __repr__(self) -> str: """ machine readable representation of the MaskPartial placement """ - return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + return f"MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" def __str__(self) -> str: """ human readable representation of the MaskPartial placement """ - return "MaskP" + return f"MaskP({self.reduce_op}, {self.offset_shape}, {self.offset_dim})" From 6fa7791bab2785bdcae096bd2f80b2528112b859 Mon Sep 17 00:00:00 2001 From: Shuai Yang Date: Thu, 20 Nov 2025 06:41:35 +0000 Subject: [PATCH 088/230] Reland"Fix different seq length (#167481)" (#168144) Differential Revision: D87413883 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168144 Approved by: https://github.com/eellison --- torch/_inductor/scheduler.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b7f36aa306a43..45cf9e409b656 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2714,12 +2714,22 @@ def _init(self, nodes: list[ir.Operation]) -> None: if ( used_non_deterministic_runtime_estimations() and config_comms.runtime_estimations_align_across_all_distributed_ranks - ): - from .comms import ( - align_runtime_estimations_across_all_distributed_ranks, + and ( + config.runtime_estimations_mms_benchmark + or config_comms.runtime_estimations_use_nccl_lib_estimations ) + ): + has_collectives = False + for node in self.nodes: + if is_collective(node.node): + has_collectives = True + break + if has_collectives: + from .comms import ( + align_runtime_estimations_across_all_distributed_ranks, + ) - align_runtime_estimations_across_all_distributed_ranks(self.nodes) + align_runtime_estimations_across_all_distributed_ranks(self.nodes) from torch._logging import trace_structured @@ -2742,8 +2752,11 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.process_grouped_nodes() if ( + # pyrefly: ignore[unbound-name] config.graph_partition + # pyrefly: ignore[unbound-name] and config.triton.cudagraphs + # pyrefly: ignore[unbound-name] and config.triton.reorder_for_reducing_graph_partitions ): self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) @@ -2755,6 +2768,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.insert_memory_check_nodes() log_ir_post_fusion(self.nodes) + # pyrefly: ignore[unbound-name] V.debug.graph_diagram(self.nodes) self.debug_draw_graph() From 25a64df8d4bff6c79214383e4d01c202e30c04f8 Mon Sep 17 00:00:00 2001 From: skishore Date: Thu, 20 Nov 2025 07:23:10 +0000 Subject: [PATCH 089/230] [ROCm] add torch.version.rocm, distinct from torch.version.hip (#168097) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Historically, HIP and ROCm versions were interchangeable, but moving forward these versions are allowed to diverge. ROCm version represents the full ROCm software stack, while HIP is a component of the ROCm stack. Issue #166068 was fixed by [switching from using HIP_VERSION to ROCM_VERSION_DEV](https://github.com/pytorch/pytorch/pull/166336). However, this broke the build of ROCm apex because the hip version from `hipcc --version` no longer matched `torch.version.hip`. This highlights the need for both versions to be exposed. Bitsandbytes has also been impacted by the change in behavior of `torch.version.hip`: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1799#issuecomment-3534269635 The solution is to fix the `torch.version.hip` so that it uses the hipcc header values and removes the trailing hash code. In addition, `torch.version.rocm` variable is created to store the ROCm version. ## Technical Details ### Fix torch.version.hip HIP_VERSION variable is computed in https://github.com/ROCm/hip/blob/develop/cmake/FindHIP.cmake. This runs hipcc –version and extracts the output of HIP version line, e.g., ``` hipcc --version HIP version: 7.1.25421-32f9fa6ca5 ``` The HIP_VERSION variable may contain a hash code at the end. This trailing hashcode is removed from the HIP_VERSION variable so that the torch.version.hip can be parsed by packaging version parse method, e.g., ``` import torch from packaging import version print(version.parse(torch.version.hip)) ``` ### Add torch.version.rocm Code changes: - Add rocm variable to torch/version.py.tpl - Add code to write rocm variable in tools/generate_torch_version.py - Write rocm version in installation process - torch/CMakeLists.txt ## Testing Tested on a preview of ROCm 7.2. Successfully built pytorch and apex. Tested above parsing torch.version.hip code. ``` >>> import torch >>> torch.version.hip '7.1.25421' >>> torch.version.rocm '7.2.0' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168097 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- cmake/public/LoadHIP.cmake | 12 +++++++++++- tools/generate_torch_version.py | 5 ++++- torch/CMakeLists.txt | 3 ++- torch/version.py.tpl | 1 + 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 018bca837a5a8..7ecaff5109f42 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -83,8 +83,18 @@ find_package_and_print_version(HIP 1.0 MODULE) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) find_package_and_print_version(hip REQUIRED CONFIG) + if(HIP_VERSION) + # Check if HIP_VERSION contains a dash (e.g., "7.1.25421-32f9fa6ca5") + # and strip everything after it to get clean numeric version + string(FIND "${HIP_VERSION}" "-" DASH_POS) + if(NOT DASH_POS EQUAL -1) + string(SUBSTRING "${HIP_VERSION}" 0 ${DASH_POS} HIP_VERSION_CLEAN) + set(HIP_VERSION "${HIP_VERSION_CLEAN}") + endif() + message("HIP version: ${HIP_VERSION}") +endif() - # The rocm-core package was only introduced in ROCm 6.4, so we make it optional. +# The rocm-core package was only introduced in ROCm 6.4, so we make it optional. find_package(rocm-core CONFIG) # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index d1004cdc3a955..ff19dadcfdc02 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -119,6 +119,7 @@ def get_torch_version(sha: str | None = None) -> str: ) parser.add_argument("--cuda-version", "--cuda_version", type=str) parser.add_argument("--hip-version", "--hip_version", type=str) + parser.add_argument("--rocm-version", "--rocm_version", type=str) parser.add_argument("--xpu-version", "--xpu_version", type=str) args = parser.parse_args() @@ -126,6 +127,7 @@ def get_torch_version(sha: str | None = None) -> str: assert args.is_debug is not None args.cuda_version = None if args.cuda_version == "" else args.cuda_version args.hip_version = None if args.hip_version == "" else args.hip_version + args.rocm_version = None if args.rocm_version == "" else args.rocm_version args.xpu_version = None if args.xpu_version == "" else args.xpu_version pytorch_root = Path(__file__).parent.parent @@ -141,7 +143,7 @@ def get_torch_version(sha: str | None = None) -> str: with open(version_path, "w") as f: f.write("from typing import Optional\n\n") f.write( - "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n" + "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', 'xpu']\n" ) f.write(f"__version__ = '{version}'\n") # NB: This is not 100% accurate, because you could have built the @@ -151,4 +153,5 @@ def get_torch_version(sha: str | None = None) -> str: f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n") f.write(f"git_version = {repr(sha)}\n") f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n") + f.write(f"rocm: Optional[str] = {repr(args.rocm_version)}\n") f.write(f"xpu: Optional[str] = {repr(args.xpu_version)}\n") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index d92b9e19a76c5..c7a43f30e49d5 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -490,7 +490,8 @@ add_custom_target( "${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py" --is-debug=${TORCH_VERSION_DEBUG} --cuda-version=${CUDA_VERSION} - --hip-version=${ROCM_VERSION_DEV} + --hip-version=${HIP_VERSION_CLEAN} + --rocm-version=${ROCM_VERSION_DEV} --xpu-version=${SYCL_COMPILER_VERSION} BYPRODUCTS ${TORCH_SRC_DIR}/version.py COMMENT "Regenerating version file..." diff --git a/torch/version.py.tpl b/torch/version.py.tpl index 1b7eab07ac949..ee37a91b7ffdc 100644 --- a/torch/version.py.tpl +++ b/torch/version.py.tpl @@ -4,6 +4,7 @@ cuda = '{{CUDA_VERSION}}' # TODO: use workspace status to stamp the correct version git_version = "" hip = None +rocm = None # This is a gross monkey-patch hack that depends on the order of imports # in torch/__init__.py From 7ffa5111a054e6c1610256fcd6c01ae4fad7b17b Mon Sep 17 00:00:00 2001 From: mansiag05 Date: Thu, 20 Nov 2025 07:33:04 +0000 Subject: [PATCH 090/230] [Distributed] Optimize ND shard overlap detection (#167073) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixes high quadratic cost in `validate_non_overlapping_shards_metadata` when shard count is large by replacing the O(n²) nested-loop scan in `_find_nd_overlapping_shards` with a sweep-line pass giving O(n log n) behavior for ND overlap detection. * Add test cases in `test_check_overlapping` covering 2D grid patterns, adjacent shards, and 3D multi-shard overlap scenarios to validate the optimized path. Fixes #166941 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167073 Approved by: https://github.com/Skylion007, https://github.com/wconstab --- .../sharding_spec/test_sharding_spec.py | 63 ++++++++++++++++ .../_shard/sharding_spec/_internals.py | 75 ++++++++++++------- 2 files changed, 109 insertions(+), 29 deletions(-) diff --git a/test/distributed/_shard/sharding_spec/test_sharding_spec.py b/test/distributed/_shard/sharding_spec/test_sharding_spec.py index 73018c1025619..37ad69075068f 100644 --- a/test/distributed/_shard/sharding_spec/test_sharding_spec.py +++ b/test/distributed/_shard/sharding_spec/test_sharding_spec.py @@ -490,6 +490,69 @@ def test_check_overlapping(self): with self.assertRaisesRegex(ValueError, "overlap"): validate_non_overlapping_shards_metadata(shards) + shards = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_sizes=[5, 5], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_sizes=[5, 5], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="cuda:3", + ), + ] + validate_non_overlapping_shards_metadata(shards) + + shards = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="cuda:1", + ), + ] + validate_non_overlapping_shards_metadata(shards) + + shards = [ + ShardMetadata( + shard_offsets=[0, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[10, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[10, 3, 0], + shard_sizes=[5, 5, 5], + placement="cuda:3", + ), + ] + with self.assertRaisesRegex(ValueError, "overlap"): + validate_non_overlapping_shards_metadata(shards) + # Custom ShardingSpec, an simple example to do grid sharding @dataclass diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 26788f4054bce..9825edd352c1f 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs import math +import sys +from bisect import bisect_right, insort from typing import Optional from torch.distributed._shard.metadata import ShardMetadata @@ -27,31 +29,48 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad def _find_nd_overlapping_shards( shards: list[ShardMetadata], sharded_dims: list[int] ) -> Optional[tuple[int, int]]: - # Each rank has len(sharded_dims) tuples. Each tuple represent the - # [begin, end] (inclusive) pair of that dimension. - shard_intervals = [ - [ - (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1) - for dim in sharded_dims - ] - for s in shards - ] - - for i in range(len(shards)): - shard_i = shard_intervals[i] - for j in range(i + 1, len(shards)): - shard_j = shard_intervals[j] - # For each dim of each shard, check if one shard resides on the other - # end of second shard with respect to that dim. As an example for a 2D - # shard, we would check if one shard is above or on the left of the - # other shard. - overlap = True - for interval_i, interval_j in zip(shard_i, shard_j): - if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]: - overlap = False - break - if overlap: - return (i, j) + """Find overlapping shards using sweep-line algorithm.""" + if len(shards) <= 1: + return None + + dims = len(sharded_dims) + if dims == 0: + return None + + sweep_dim_idx = 0 + if dims > 1: + max_size = 0 + for i, dim in enumerate(sharded_dims): + dim_size = shards[0].shard_offsets[dim] + shards[0].shard_sizes[dim] + if dim_size > max_size: + max_size = dim_size + sweep_dim_idx = i + sweep_dim = sharded_dims[sweep_dim_idx] + + sorted_indices = sorted( + range(len(shards)), + key=lambda idx: ( + shards[idx].shard_offsets[sweep_dim], + *(shards[idx].shard_offsets[d] for d in sharded_dims if d != sweep_dim), + ), + ) + active: list[tuple[int, int]] = [] + + for idx in sorted_indices: + current = shards[idx] + start = current.shard_offsets[sweep_dim] + end = start + current.shard_sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = shards[other_idx] + + if _check_shard_metadata_pair_overlap(current, other): + return (other_idx, idx) + insort(active, (end, idx)) return None @@ -112,10 +131,8 @@ def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): # using a O(nlogn) overlapping interval algorithm. pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) else: - # Shards are partitioned over more than one dimension. Fall back to - # pair-wise check. Even though O(nlogn) algorithms (line sweep) exist - # for 2D overlap, the implementation is not trivial and may not justify - # the time saving in most cases. + # Shards are partitioned over more than one dimension. + # Use sweep-line algorithm for O(n log n) complexity. pair = _find_nd_overlapping_shards(shards, sharded_dims) if pair: From 21c11daffad809034507a2218dca6be946892b34 Mon Sep 17 00:00:00 2001 From: hipudding Date: Thu, 20 Nov 2025 09:24:22 +0000 Subject: [PATCH 091/230] Improve OpenReg test coverage (#167819) - add failure-path tests for device, stream, memory, event APIs - cover async memcpy, pointer attributes, event timing, addTask errors - verified via cmake --build build && ctest --test-dir build Pull Request resolved: https://github.com/pytorch/pytorch/pull/167819 Approved by: https://github.com/fffrog --- .../openreg/tests/device_tests.cpp | 15 +++ .../third_party/openreg/tests/event_tests.cpp | 93 +++++++++++++++++++ .../openreg/tests/memory_tests.cpp | 76 +++++++++++++++ .../openreg/tests/stream_tests.cpp | 54 +++++++++++ 4 files changed, 238 insertions(+) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp index b7501c81d7b7c..f8fc5946fd6e8 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp @@ -16,12 +16,22 @@ TEST_F(DeviceTest, GetDeviceCountValid) { EXPECT_EQ(count, 2); } +TEST_F(DeviceTest, GetDeviceCountNullptr) { + // orGetDeviceCount should reject null output pointers. + EXPECT_EQ(orGetDeviceCount(nullptr), orErrorUnknown); +} + TEST_F(DeviceTest, GetDeviceValid) { int device = -1; EXPECT_EQ(orGetDevice(&device), orSuccess); EXPECT_EQ(device, 0); } +TEST_F(DeviceTest, GetDeviceNullptr) { + // Defensive path: null output pointer must return an error. + EXPECT_EQ(orGetDevice(nullptr), orErrorUnknown); +} + TEST_F(DeviceTest, SetDeviceValid) { EXPECT_EQ(orSetDevice(1), orSuccess); @@ -38,4 +48,9 @@ TEST_F(DeviceTest, SetDeviceInvalidNegative) { EXPECT_EQ(orSetDevice(-1), orErrorUnknown); } +TEST_F(DeviceTest, SetDeviceInvalidTooLarge) { + // Device indices are 0-based and strictly less than DEVICE_COUNT (2). + EXPECT_EQ(orSetDevice(2), orErrorUnknown); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp index 416c50a863435..f45bd2690d41e 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp @@ -29,6 +29,13 @@ TEST_F(EventTest, EventCreateWithFlagsTiming) { EXPECT_EQ(orEventDestroy(event), orSuccess); } +TEST_F(EventTest, EventCreationNullptr) { + // Creation APIs must fail fast on null handles to mirror CUDA semantics. + EXPECT_EQ(orEventCreate(nullptr), orErrorUnknown); + EXPECT_EQ( + orEventCreateWithFlags(nullptr, orEventEnableTiming), orErrorUnknown); +} + TEST_F(EventTest, EventRecordAndSynchronize) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -44,6 +51,23 @@ TEST_F(EventTest, EventRecordAndSynchronize) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(EventTest, EventRecordInvalidArgs) { + orEvent_t event = nullptr; + EXPECT_EQ(orEventCreate(&event), orSuccess); + + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + // Record/sync/destroy should validate both stream and event pointers. + EXPECT_EQ(orEventRecord(nullptr, stream), orErrorUnknown); + EXPECT_EQ(orEventRecord(event, nullptr), orErrorUnknown); + EXPECT_EQ(orEventSynchronize(nullptr), orErrorUnknown); + EXPECT_EQ(orEventDestroy(nullptr), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(event), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(EventTest, EventElapsedTime) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -70,6 +94,60 @@ TEST_F(EventTest, EventElapsedTime) { EXPECT_EQ(orEventDestroy(end), orSuccess); } +// TODO: recording events to a stream is not allowed +// if the stream and the event are not on the same device +// Uncomment this test case after the issue is fixed. +// see #167819 +TEST_F(EventTest, DISABLED_EventElapsedTimeDifferentDevicesFails) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t start = nullptr; + orEvent_t end = nullptr; + EXPECT_EQ(orEventCreateWithFlags(&start, orEventEnableTiming), orSuccess); + + EXPECT_EQ(orEventRecord(start, stream), orSuccess); + + // Switch device before creating the end event to force a mismatch. + EXPECT_EQ(orSetDevice(1), orSuccess); + EXPECT_EQ(orEventCreateWithFlags(&end, orEventEnableTiming), orSuccess); + EXPECT_EQ(orSetDevice(0), orSuccess); + + EXPECT_EQ(orEventRecord(end, stream), orSuccess); + EXPECT_EQ(orEventSynchronize(start), orSuccess); + EXPECT_EQ(orEventSynchronize(end), orSuccess); + + float elapsed_ms = 0.0f; + EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(start), orSuccess); + EXPECT_EQ(orEventDestroy(end), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + +TEST_F(EventTest, EventElapsedTimeRequiresTimingFlag) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t start = nullptr; + orEvent_t end = nullptr; + EXPECT_EQ(orEventCreate(&start), orSuccess); + EXPECT_EQ(orEventCreate(&end), orSuccess); + + EXPECT_EQ(orEventRecord(start, stream), orSuccess); + EXPECT_EQ(orEventRecord(end, stream), orSuccess); + EXPECT_EQ(orEventSynchronize(start), orSuccess); + EXPECT_EQ(orEventSynchronize(end), orSuccess); + + // Without timing-enabled events, querying elapsed time must fail. + float elapsed_ms = 0.0f; + EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(start), orSuccess); + EXPECT_EQ(orEventDestroy(end), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(EventTest, StreamWaitEvent) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -85,4 +163,19 @@ TEST_F(EventTest, StreamWaitEvent) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(EventTest, StreamWaitEventInvalidArgs) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t event = nullptr; + EXPECT_EQ(orEventCreate(&event), orSuccess); + + // Validate both stream and event inputs for wait calls. + EXPECT_EQ(orStreamWaitEvent(nullptr, event, 0), orErrorUnknown); + EXPECT_EQ(orStreamWaitEvent(stream, nullptr, 0), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(event), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp index e36ad4c0da3ee..3a5ccb54ad85a 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp @@ -26,6 +26,12 @@ TEST_F(MemoryTest, AllocateAndFreeHost) { EXPECT_EQ(orFreeHost(ptr), orSuccess); } +TEST_F(MemoryTest, FreeNullptrIsNoop) { + // Freeing a nullptr should behave like CUDA: treated as a no-op success. + EXPECT_EQ(orFree(nullptr), orSuccess); + EXPECT_EQ(orFreeHost(nullptr), orSuccess); +} + TEST_F(MemoryTest, AllocateNullptr) { EXPECT_EQ(orMalloc(nullptr, 4096), orErrorUnknown); EXPECT_EQ(orMallocHost(nullptr, 4096), orErrorUnknown); @@ -86,6 +92,48 @@ TEST_F(MemoryTest, MemcpyInvalidKind) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, MemcpyInvalidCombinations) { + void *dev_src = nullptr, *dev_dst = nullptr; + EXPECT_EQ(orMalloc(&dev_src, 8), orSuccess); + EXPECT_EQ(orMalloc(&dev_dst, 8), orSuccess); + + char host_buf[8] = {}; + + // Deliberately pass mismatched kinds to ensure validation coverage. + EXPECT_EQ( + orMemcpy(host_buf, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown); + EXPECT_EQ( + orMemcpy(dev_dst, host_buf, 4, orMemcpyDeviceToHost), orErrorUnknown); + EXPECT_EQ( + orMemcpy(dev_dst, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown); + + EXPECT_EQ(orFree(dev_src), orSuccess); + EXPECT_EQ(orFree(dev_dst), orSuccess); +} + +TEST_F(MemoryTest, MemcpyAsyncHostToDevice) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + const char host_src[] = "async"; + char host_dst[6] = {}; + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, sizeof(host_src)), orSuccess); + + // Async copies should complete once the stream is synchronized. + EXPECT_EQ( + orMemcpyAsync(dev_ptr, host_src, sizeof(host_src), orMemcpyHostToDevice, stream), + orSuccess); + EXPECT_EQ(orStreamSynchronize(stream), orSuccess); + EXPECT_EQ(orMemcpy( + host_dst, dev_ptr, sizeof(host_src), orMemcpyDeviceToHost), + orSuccess); + EXPECT_STREQ(host_dst, host_src); + + EXPECT_EQ(orFree(dev_ptr), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(MemoryTest, PointerAttributes) { void* dev_ptr = nullptr; EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess); @@ -102,6 +150,14 @@ TEST_F(MemoryTest, PointerAttributes) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, PointerAttributesInvalidArgs) { + // Attribute queries must fail on null inputs to avoid dereferencing. + char buffer[8] = {}; + orPointerAttributes attr{}; + EXPECT_EQ(orPointerGetAttributes(nullptr, buffer), orErrorUnknown); + EXPECT_EQ(orPointerGetAttributes(&attr, nullptr), orErrorUnknown); +} + TEST_F(MemoryTest, ProtectUnprotectDevice) { void* dev_ptr = nullptr; EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess); @@ -112,4 +168,24 @@ TEST_F(MemoryTest, ProtectUnprotectDevice) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, ProtectReferenceCounting) { + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess); + + // Call unprotect/protect twice to exercise the refcount transitions. + EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess); + + EXPECT_EQ(orFree(dev_ptr), orSuccess); +} + +TEST_F(MemoryTest, DoubleFreeFails) { + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess); + EXPECT_EQ(orFree(dev_ptr), orSuccess); + EXPECT_EQ(orFree(dev_ptr), orErrorUnknown); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp index e91abaa1e7fe9..fbf5cb900a811 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp @@ -21,6 +21,11 @@ TEST_F(StreamTest, StreamCreateAndDestroy) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(StreamTest, StreamCreateNullptr) { + // Creation API should reject null double-pointer inputs. + EXPECT_EQ(orStreamCreate(nullptr), orErrorUnknown); +} + TEST_F(StreamTest, StreamCreateWithInvalidPriority) { orStream_t stream = nullptr; int min_p, max_p; @@ -30,6 +35,36 @@ TEST_F(StreamTest, StreamCreateWithInvalidPriority) { EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p + 1), orErrorUnknown); } +TEST_F(StreamTest, StreamCreateWithPriorityValidBounds) { + orStream_t stream = nullptr; + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + + // Lowest priority should be accepted. + EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, min_p), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); + + // Highest priority should also be accepted. + EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + +TEST_F(StreamTest, StreamDestroyNullptr) { + // Destroying nullptr should follow CUDA error behavior. + EXPECT_EQ(orStreamDestroy(nullptr), orErrorUnknown); +} + +TEST_F(StreamTest, StreamGetPriority) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + int priority = -1; + EXPECT_EQ(orStreamGetPriority(stream, &priority), orSuccess); + EXPECT_EQ(priority, 0); + + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(StreamTest, StreamTaskExecution) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -43,6 +78,11 @@ TEST_F(StreamTest, StreamTaskExecution) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(StreamTest, AddTaskToStreamNullptr) { + // Queueing work should fail fast if the stream handle is invalid. + EXPECT_EQ(openreg::addTaskToStream(nullptr, [] {}), orErrorUnknown); +} + TEST_F(StreamTest, StreamQuery) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -76,4 +116,18 @@ TEST_F(StreamTest, DeviceSynchronize) { EXPECT_EQ(orStreamDestroy(stream2), orSuccess); } +TEST_F(StreamTest, DeviceSynchronizeWithNoStreams) { + // Even without registered streams, device sync should succeed. + EXPECT_EQ(orDeviceSynchronize(), orSuccess); +} + +TEST_F(StreamTest, StreamPriorityRange) { + int min_p = -1; + int max_p = -1; + // OpenReg currently exposes only one priority level; verify the fixed range. + EXPECT_EQ(orDeviceGetStreamPriorityRange(&min_p, &max_p), orSuccess); + EXPECT_EQ(min_p, 0); + EXPECT_EQ(max_p, 0); +} + } // namespace From a6b63835eefe7f78f09e7dc1ed33e5ee53f2ae44 Mon Sep 17 00:00:00 2001 From: Usamah Zaheer Date: Thu, 20 Nov 2025 11:38:31 +0000 Subject: [PATCH 092/230] [ARM] Improve LLM performance & mem usage using int4-bf16 KleidiAI kernels (#158250) Co-authored-by: Nikhil Gupta [nikhil.gupta2@arm.com](mailto:nikhil.gupta2@arm.com) This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch to boost LLM prefill & decode performance **This change improves decode throughput by ~15% & reduces memory required to inference the model by 50%** ### Benchmark Setup ``` Model: meta-llama/Llama-3.1-8B Test Platform: Neoverse V2 ``` ### Detailed Results | Metric | With `--compile` | Without `--compile` | |----------------------------------|---------------------------|---------------------------| | Quantization Scheme | INT4 symmetric channelwise | INT4 symmetric channelwise | | Input Precision | BF16 | BF16 | | Number of Layers Quantized | 32 | 32 | | Average Compression Ratio | 87.49% | 87.49% | | Total Quantization Time (s) | 9.62 | 10.32 | | Compile Time (First) (s) | 134.48 | 1.69 | | Compile Time (Second) (s) | 80.44 | 1.60 | | Compile Time (Subsequent) (s) | 0.19 | 0.22 | | Prefill Tokens | 54 | 54 | | Decoded Tokens | 33 | 33 | | Prefill Time (s) | 0.19 | 0.22 | | Decode Time (s) | 0.76 | 1.38 | | E2E Generation Time (s) | 0.95 | 1.60 | | Prefill Throughput (tokens/s) | 288.13 | 249.91 | | Decode Throughput (tokens/s) | 43.42 | 23.83 | Pull Request resolved: https://github.com/pytorch/pytorch/pull/158250 Approved by: https://github.com/malfet, https://github.com/aditew01, https://github.com/fadara01 Co-authored-by: Nikhil Gupta Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/native/LinearAlgebra.cpp | 4 +- aten/src/ATen/native/cpu/int4mm_kernel.cpp | 343 +++++++++++++----- aten/src/ATen/native/kleidiai/kai_kernels.cpp | 200 ++++++++-- aten/src/ATen/native/kleidiai/kai_kernels.h | 3 +- aten/src/ATen/native/kleidiai/kai_pack.h | 9 +- .../native/kleidiai/kai_ukernel_interface.cpp | 34 ++ .../native/kleidiai/kai_ukernel_interface.h | 89 ++++- test/inductor/test_torchinductor.py | 108 +++++- torch/_meta_registrations.py | 11 +- 9 files changed, 664 insertions(+), 137 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 169f340e955d6..2cc7cf913cdcb 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -3554,9 +3554,9 @@ Tensor _dyn_quant_matmul_4bit_cpu( const int64_t out_features) { auto M = inp.size(0); TORCH_CHECK( - inp.dtype() == kFloat, + inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features), __func__, - " : expect input to be 32-bit float tensor."); + " : expect input to be float32 or bfloat16 tensor."); TORCH_CHECK( block_size == in_features || (!(block_size % 32) && !(in_features % block_size)), diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index 33aae4fbf27a5..1ffaa7bcd90b7 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -793,6 +794,139 @@ bool can_use_kleidiai( } #endif +static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16( + size_t m, + size_t n, + size_t k, + const uint16_t* lhs_bf16, + const uint8_t* rhs_qs4cx, + const float* rhs_scales, + uint16_t* dst_bf16, + float scalar_min, + float scalar_max, + const float* bias) { + // Roundup lambda for internal stride calculations + auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; }; + + // Cast bfloat16 to float32 inline + auto cast_bf16_to_f32 = [](uint16_t bf16_val) { + uint32_t tmp = static_cast(bf16_val) << 16; + float f; + std::memcpy(&f, &tmp, sizeof(f)); + return f; + }; + + // Cast float32 to bfloat16 inline + auto cast_f32_to_bf16 = [](float f) { + uint32_t bits; + std::memcpy(&bits, &f, sizeof(bits)); + return static_cast(bits >> 16); + }; + + // Quantization pack lambda (channelwise QA8DX) + auto quant_pack_8bit_channelwise = + [&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + + const size_t dst_stride = + K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + for (size_t i = 0; i < M; ++i) { + const uint16_t* row_ptr = src_bf16 + i * K; + // find min/max + float mn = FLT_MAX, mx = -FLT_MAX; + for (size_t j = 0; j < K; ++j) { + float v = cast_bf16_to_f32(row_ptr[j]); + mn = std::min(mn, v); + mx = std::max(mx, v); + } + float rmin = std::min(0.0f, mn); + float rmax = std::max(0.0f, mx); + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); + float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin); + float recip = scale ? 1.0f / scale : 0.0f; + int32_t zp; + float des_min = rmin * scale; + float des_max = rmax * scale; + float err_min = qmin + des_min; + float err_max = qmax + des_max; + float zp_f = + (err_min + err_max) > 0 ? qmin - des_min : qmax - des_max; + zp_f = std::clamp(zp_f, qmin, qmax); + zp = std::lrintf(zp_f); + int8_t* out_ptr = dst_qa8dx + i * dst_stride; + // store header + *reinterpret_cast(out_ptr) = recip; + *reinterpret_cast(out_ptr + sizeof(float)) = -zp; + out_ptr += sizeof(float) + sizeof(int32_t); + // quantize + for (size_t j = 0; j < K; ++j) { + float v = cast_bf16_to_f32(row_ptr[j]); + int32_t q = static_cast(std::round(v * scale)) + zp; + q = std::clamp( + q, static_cast(kI8Min), static_cast(kI8Max)); + *out_ptr++ = static_cast(q); + } + } + }; + + // MatMul lambda (MXN x MXK -> MNXK BF16) + auto matmul_kernel = [&](size_t M, + size_t N, + size_t K, + const int8_t* lhs, + const uint8_t* rhs, + const float* scales, + uint16_t* dst, + float lo, + float hi) { + const size_t lhs_stride = + K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = roundup(K, 2) / 2; + for (size_t i = 0; i < M; ++i) { + const int8_t* lhs_row = lhs + i * lhs_stride; + for (size_t j = 0; j < N; ++j) { + int32_t acc = 0; + const int8_t* lptr = lhs_row; + const uint8_t* rptr = rhs + j * rhs_stride; + float lhs_scale = *reinterpret_cast(lptr); + int32_t lhs_off = + *reinterpret_cast(lptr + sizeof(float)); + lptr += sizeof(float) + sizeof(int32_t); + for (size_t t = 0; t < K; ++t) { + int32_t lv = static_cast(lptr[t]); + uint8_t bv = rptr[t / 2]; + int32_t rv = ((t & 1) == 0) ? (static_cast(bv & 0xF) - 8) + : (static_cast(bv >> 4) - 8); + acc += lv * rv + lhs_off * rv; + } + float res = static_cast(acc) * scales[j] * lhs_scale; + if (bias) { + res += bias[j]; + } + res = std::clamp(res, lo, hi); + *dst++ = cast_f32_to_bf16(res); + } + } + }; + + // allocate and run + std::unique_ptr packed( + new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]); + quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get()); + matmul_kernel( + m, + n, + k, + packed.get(), + rhs_qs4cx, + rhs_scales, + dst_bf16, + scalar_min, + scalar_max); +} + /** * The Int4 quantized weights must be represented as a uint8 tensor * For matrix multiplication with a weight shape of (N x K) @@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel( #if AT_KLEIDIAI_ENABLED() if (can_use_kleidiai(scales_zeros, K, block_size)) { const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type()); packed_weights.resize_({weight_packed_size}); kleidiai::kai_pack_int4_rhs( packed_weights, weights, scales_zeros, bias, N, K, block_size); } else #endif { - TORCH_CHECK( - bias.has_value() == 0, - __func__, - " : Bias is unsupported in reference implementation"); packed_weights = packed_weights.to(kFloat); - auto weight_reshaped = weights.view({-1}).to(kFloat); - auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat); - auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0); + auto weight_reshaped = weights.reshape({-1}).to(kFloat); + auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat); + std::vector tensors_to_cat = {weight_reshaped, scales_zeros_reshaped}; + if (bias.has_value()) { + tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat)); + } + auto res = at::cat(tensors_to_cat, 0); packed_weights.resize_(res.sizes()).copy_(res); } } @@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( const float* rhs_scales_f32, float* dst_f32, float scalar_min, - float scalar_max) { + float scalar_max, + const float* bias) { const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float)); auto lhs_qa8dx_buffer = std::make_unique(input_size_8bit); @@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( // required format for matmul auto input_quant_pack_8bit_channelwise = [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); @@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( } // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); const float rmin0 = std::min(0.0f, min0); const float rmax0 = std::max(0.0f, max0); @@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( zero_point0 = std::min(zero_point0, qmax); // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point0 = std::lrintf(zero_point0); int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride; @@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = std::max(v0_s32, static_cast(INT8_MIN)); - v0_s32 = std::min(v0_s32, static_cast(INT8_MAX)); + v0_s32 = std::max(v0_s32, static_cast(kI8Min)); + v0_s32 = std::min(v0_s32, static_cast(kI8Max)); dst_ptr[0] = (int8_t)v0_s32; dst_ptr += sizeof(int8_t); } @@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( main_acc = main_acc * lhs_scale; + if (bias) { + main_acc += bias[n_idx]; + } + // Clamp (min-max) operation main_acc = std::max(main_acc, scalar_min); main_acc = std::min(main_acc, scalar_max); @@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( const float* rhs_scales_fp32, float* dst_f32, float scalar_min, - float scalar_max) { + float scalar_max, + const float* bias) { // Lambda for LHS quantization auto lhs_quant_pack = [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); @@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( min0 = std::min(src0_0, min0); } - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); const float rmin0 = std::min(0.0f, min0); const float rmax0 = std::max(0.0f, max0); @@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( zero_point0 = std::max(zero_point0, qmin); zero_point0 = std::min(zero_point0, qmax); - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point0 = std::lrintf(zero_point0); int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride; @@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( const float src0_0 = src_ptr[k_idx]; int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); v0_s32 = std::max( - std::min( - v0_s32 + nudged_zero_point0, static_cast(INT8_MAX)), - static_cast(INT8_MIN)); + std::min(v0_s32 + nudged_zero_point0, static_cast(kI8Max)), + static_cast(kI8Min)); dst_ptr[0] = (int8_t)v0_s32; dst_ptr += sizeof(int8_t); } @@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( } main_acc = main_acc * lhs_scale; + + if (bias) { + main_acc += bias[col_idx]; + } + main_acc = std::max(main_acc, scalar_min); main_acc = std::min(main_acc, scalar_max); @@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( } /** - * Dynamic Input Quant 4 bit weights matmul execution flow - (INT4 Weights + FP scales + FP32 Bias) - FP32 Input Packed Buffer - | | - Quantize Cast - to INT8 to INT8 - | | - v v - INT8 Input INT8 Weights - \ / - \ / - \ / - INT8 Matrix Multiplication - | - v - FP32 Dequantized and Accumulate in FP32 - | - v - FP32 Final Output - - * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires - * Float32 Scales. If not provided, we will use fallback implementation. + * Dynamic INT4 weight-only MatMul with per-row input quantization. + * + * Execution Flow: + * + * (INT4 Weights + FP Scales [+ optional Bias]) + * + * Input (FP32 or BF16) Packed Weight Buffer + * | | + * Row-wise Quantization (INT8) | + * | | + * INT8 Input Activation INT4 Quantized Weights + Scales + * \ / + * \ / + * Quantized Matrix Multiply + * | + * Output Tensor (BF16 or FP32) + * + * Notes: + * - Groupwise kernels expect BF16 scales + * - Channelwise kernels expect FP32 scales + * - Bias is currently unsupported in fallback path */ void dyn_quant_matmul_4bit_kernel( const Tensor& output, @@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel( const int64_t block_size) { #if AT_KLEIDIAI_ENABLED() const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type()); if (weight_packed_size == packed_weights.numel()) { // KleidiAI interface internally handles the Channelwise and groupwise // distinction - kleidiai::kai_quant_pack_lhs_int4_mm( - output, inp, packed_weights, M, N, K, block_size); + kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size); } else #endif { - float* lhs_f32 = reinterpret_cast(inp.data_ptr()); - const auto weights_size = N * K / 2; - // The weights needs to be in uint8_t data type after quantization - auto extracted_weights = - (packed_weights.narrow(0, 0, weights_size)).to(kByte); - auto float32_scales = - (packed_weights.narrow( - 0, weights_size, packed_weights.size(0) - weights_size)) - .to(kFloat); - uint8_t* rhs_4bit = - reinterpret_cast(extracted_weights.data_ptr()); - float* rhs_scales_f32 = reinterpret_cast(float32_scales.data_ptr()); - float* dst_f32 = reinterpret_cast(output.data_ptr()); - if (block_size == K) { - ref_dyn_quant_matmul_4bit_channelwise_kernel( - M, - N, - K, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); - } else if (!(block_size % 32) && !(K % block_size)) { - ref_dyn_quant_matmul_4bit_groupwise_kernel( - M, - N, - K, - block_size, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); + { + void* input = inp.data_ptr(); + void* dst = output.data_ptr(); + + // Extract weights, sclaes and biases form from packed tensor + const int weights_elements = N * K / 2; + const int scale_elements = N * (K / block_size); + TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size"); + + auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte); + auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat); + auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements); + + int bias_elements = packed_weights.numel() - (weights_elements + scale_elements); + float* weight_scales = float32_scales.data_ptr(); + + void* bias_data = nullptr; + if (bias_elements) { + auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements); + TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension"); + bias_data = float32_bias.data_ptr(); + + } + // 2 elements of 4 bit weights are packed into 1 uint8 packet + uint8_t* weights_4bit = reinterpret_cast(extracted_weights.data_ptr()); + + // Dispatch to reference kernels + if (inp.scalar_type() == at::kBFloat16) { + // BF16 input, BF16 output + constexpr float BF16_MAX = 3.38953139e+38f; + constexpr float BF16_MIN = -BF16_MAX; + if (block_size == K) { + ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16( + M, N, K, + (uint16_t*)input, weights_4bit, weight_scales, + (uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data); + } else { + TORCH_CHECK(false, "Unsupported block size for BF16 fallback"); + } + } else if (inp.scalar_type() == at::kFloat) { + // FP32 input, FP32 output + if (block_size == K) { + ref_dyn_quant_matmul_4bit_channelwise_kernel( + M, N, K, + (float*)input, weights_4bit, weight_scales, + (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data); + } else if (!(block_size % 32) && !(K % block_size)) { + ref_dyn_quant_matmul_4bit_groupwise_kernel( + M, N, K, block_size, + (float*)input, weights_4bit, weight_scales, + (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data); + } else { + TORCH_CHECK(false, "Unsupported block size for FP32 fallback"); + } } else { - TORCH_CHECK( - block_size == K || (!(block_size % 32) && !(K % block_size)), - __func__, - ": Group size should be multiple 32 or in_features [", - K, - "]. Provided ", - block_size); + TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel"); } - } } - +} } // anonymous namespace - +} ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel) ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel) REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel) diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.cpp b/aten/src/ATen/native/kleidiai/kai_kernels.cpp index ce0f10bf6df1f..1313f98f90109 100644 --- a/aten/src/ATen/native/kleidiai/kai_kernels.cpp +++ b/aten/src/ATen/native/kleidiai/kai_kernels.cpp @@ -21,18 +21,27 @@ void kai_pack_int4_rhs( const int64_t n, const int64_t k, const int64_t bl) { - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - auto& params = kernel_packet.rhs_pack_params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - kai_pack_rhs_channelwise_int4( - kernel_packet, weight_packed, weight, scales, bias, n, k); + if (weight.scalar_type() == at::kBFloat16) { + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + kai_pack_rhs_channelwise_int4( + kernel_packet, weight_packed, weight, scales, bias, n, k); + } else { + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + kai_pack_rhs_channelwise_int4( + kernel_packet, weight_packed, weight, scales, bias, n, k); + } } else if (!(bl % 32) && !(k % bl)) { // Groupwise auto kernel_packet = kai_select_groupwise_matmul_ukernel( @@ -63,19 +72,29 @@ void kai_pack_int4_rhs( size_t kai_pack_rhs_int4_size( const int64_t n, const int64_t k, - const int64_t bl) { + const int64_t bl, + at::ScalarType tensor_dtype) { size_t packed_size = n * k; - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { - // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - const auto& ukernel = kernel_packet.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + if (tensor_dtype == at::kBFloat16) { + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + } else { + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + } } else if (!(bl % 32) && !(k % bl)) { // Groupwise auto kernel_packet = kai_select_groupwise_matmul_ukernel( @@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise( const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride; const int64_t m_idx = thread_id * vec_per_thread; auto lhs_packed_ptr = lhs_packed_base + - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( - m_idx, k, mr, kr, sr); + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); const int64_t vec_num = (thread_id == num_threads - 1) ? (m - vec_per_thread * thread_id) : vec_per_thread; @@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise( const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride; const int64_t m_idx = thread_id * vec_per_thread; auto lhs_packed_ptr = lhs_packed_base + - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( - m_idx, k, mr, kr, sr); + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); const int64_t vec_num = (thread_id == num_threads - 1) ? (m - vec_per_thread * thread_id) : vec_per_thread; @@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise( }); } -void kai_quant_pack_lhs_int4_mm( +static void kai_quant_pack_lhs_int4_mm_bf16_channelwise( const Tensor& output, const Tensor& input, const Tensor& weight, const int64_t m, const int64_t n, + const int64_t k) { + // Kernel IDs for GEMM and GEMV + constexpr kai_kernel_id gemm_id = + kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm; + constexpr kai_kernel_id gemv_id = + kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod; + + // Get total threads and select kernel + const int64_t total_threads = at::get_num_threads(); + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id); + if (cpuinfo_has_arm_i8mm() && m > 1) { + kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id); + } + + // Thread blocking parameters + const int64_t n_step = kernel_packet.ukernel.get_n_step(); + const size_t mr = kernel_packet.ukernel.get_mr(); + const size_t kr = kernel_packet.ukernel.get_kr(); + const size_t sr = kernel_packet.ukernel.get_sr(); + + const size_t lhs_packed_size = + kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr); + auto lhs_packed = std::make_unique(lhs_packed_size); + uint8_t* dst_act_mtx_bf16 = reinterpret_cast(output.data_ptr()); + const uint8_t* lhs_native_mtx_bf16 = + reinterpret_cast(input.data_ptr()); + const uint8_t* rhs_packed_mtx_qs4cx = + reinterpret_cast(weight.data_ptr()); + uint8_t* lhs_packed_base = lhs_packed.get(); + + constexpr int32_t element_size = sizeof(uint16_t); + const size_t lhs_stride = k * element_size; + const size_t dst_stride = n * element_size; + + // LHS quantization packing + int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr); + int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread; + const size_t src_stride = vec_per_thread * lhs_stride; + + auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) { + const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride; + const int64_t m_idx = thread_id * vec_per_thread; + auto lhs_packed_ptr = lhs_packed_base + + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); + const int64_t vec_num = (thread_id == num_threads - 1) + ? (m - vec_per_thread * thread_id) + : vec_per_thread; + + kernel_packet.kai_run_lhs_quant_pack( + vec_num, + k, + mr, + kr, + sr, + 0, + (const uint16_t*)lhs_src_ptr, + lhs_stride, + lhs_packed_ptr); + }; + + at::parallel_for( + 0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) { + for (int64_t thread_id = begin; thread_id < end; ++thread_id) { + lhs_quant_pack(thread_id); + } + }); + + // Matrix multiplication + vec_per_thread = get_vec_per_thread(n, total_threads, n_step); + num_threads = (n + vec_per_thread - 1) / vec_per_thread; + + auto mm = [=, &kernel_packet](int64_t thread_id) { + const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx + + kernel_packet.ukernel.get_rhs_packed_offset( + thread_id * vec_per_thread, k); + auto dst_ptr = dst_act_mtx_bf16 + + kernel_packet.ukernel.get_dst_offset( + 0, thread_id * vec_per_thread, dst_stride); + const int64_t vec_num = (thread_id == num_threads - 1) + ? (n - vec_per_thread * thread_id) + : vec_per_thread; + + kernel_packet.ukernel.run_matmul( + m, + vec_num, + k, + lhs_packed_base, + rhs_packed_ptr, + (uint16_t*)dst_ptr, + dst_stride, + element_size, // dst_stride_col + -FLT_MAX, + FLT_MAX); + }; + + at::parallel_for( + 0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) { + for (int64_t thread_id = begin; thread_id < end; ++thread_id) { + mm(thread_id); + } + }); +} +void kai_quant_pack_lhs_int4_mm( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const int64_t m, + const int64_t n, const int64_t k, const int64_t bl) { // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { - kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( - output, input, weight, m, n, k); - } else if (!(bl % 32) && !(k % bl)) { + const auto input_dtype = input.dtype(); + + if (input_dtype == at::kBFloat16) { + if (cpuinfo_has_arm_bf16()) { + kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise( + output, input, weight, m, n, k); + } else { + TORCH_CHECK( + false, + "BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support."); + } + } else if (input_dtype == at::kFloat) { + kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( + output, input, weight, m, n, k); + } else { + TORCH_CHECK( + false, + "Unsupported input data type: Only Bfloat16 and Float inputs are supported."); + } + } else if ((bl % 32 == 0) && (k % bl == 0)) { kleidiai::kai_quant_pack_lhs_int4_mm_groupwise( output, input, weight, m, n, k, bl); } diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.h b/aten/src/ATen/native/kleidiai/kai_kernels.h index 9b522d7f7705a..a4179cefd06cf 100644 --- a/aten/src/ATen/native/kleidiai/kai_kernels.h +++ b/aten/src/ATen/native/kleidiai/kai_kernels.h @@ -25,7 +25,8 @@ void kai_pack_int4_rhs( size_t kai_pack_rhs_int4_size( const int64_t n, const int64_t k, - const int64_t bl); + const int64_t bl, + at::ScalarType tensor_dtype = at::kFloat); /** * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) diff --git a/aten/src/ATen/native/kleidiai/kai_pack.h b/aten/src/ATen/native/kleidiai/kai_pack.h index 4ff3371ab5e2a..d9f08333591ed 100644 --- a/aten/src/ATen/native/kleidiai/kai_pack.h +++ b/aten/src/ATen/native/kleidiai/kai_pack.h @@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4( AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + float* bias_ptr = + bias.has_value() ? bias.value().to(kFloat).data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( @@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4( auto weight_packed_data = reinterpret_cast(weight_packed.data_ptr()); const auto weight_data = weight.data_ptr(); - const auto scales_data = scales.data_ptr(); + + const auto scales_data = scales.to(kFloat).data_ptr(); if (weight_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); @@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4( AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + float* bias_ptr = + bias.has_value() ? bias.value().to(kFloat).data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp index 0de198d7dc012..783133b83e670 100644 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp @@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel( const kai_kernel_id id) { return channelwise_8bit_4bit_kernels.at(id); } + +// Kernel Mapping - BF16 Channelwise +std::unordered_map + bf16_channelwise_8bit_4bit_kernels = { + {kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + {{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}}, + {kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + {{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}}; + +kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel( + const kai_kernel_id id) { + return bf16_channelwise_8bit_4bit_kernels.at(id); +} } // namespace at::native::kleidiai #endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h index 8480469cdea86..cfcf7a81ba85f 100644 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h @@ -10,21 +10,32 @@ #include #include #include +#include +#include +#include #include +#include #include #include namespace at::native::kleidiai { enum class kai_kernel_id { + // FP32 inputs, 4-bit weights, FP32 output matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = - 0, // Groupwise 4 bit GEMV + 0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD) matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = - 1, // Groupwise 4 bit GEMM + 1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM) matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = - 2, // Channelwise 4 bit GEMV + 2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD) matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = - 3 // Channelwise 4 bit GEMM + 3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM) + + // BF16 inputs, 4-bit weights, BF16 output + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod = + 4, // Channelwise 4-bit GEMV with BF16 input/output + matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm = + 5 // Channelwise 4-bit GEMM with BF16 input/output }; // Channelwise Kernel mapping @@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); kai_matmul_ukernel_f32_qa8dxp_qs4cxp( const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) @@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { kai_get_rhs_packed_size( &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){} }; struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); +// bf16 Channelwise Kernel mapping +struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp { + struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const void* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* rhs, + const float* bias, + const float* scale, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); + + kai_matmul_ukernel_bf16_qa8dxp_qs4cxp( + const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){} + }; + +struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp +kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id); + // Groupwise Kernel mapping struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; @@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); kai_matmul_ukernel_f32_qa8dxp_qs4c32p( const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) @@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { kai_get_rhs_packed_size( &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {} }; struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed65f07742945..8ecd637f4a322 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2484,7 +2484,7 @@ def fn(a, b_int8pack, b_scales, c): @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") @skipIfRocm @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") - def test__dyn_quant_pack_4bit_weight(self): + def test__dyn_quant_pack_4bit_weight_fp32(self): q_group = 32 k = 128 n = 128 @@ -2515,12 +2515,54 @@ def fn(b, in_features, out_features): self.common(fn, (b, in_features, out_features)) + @xfail_if_mps_unimplemented + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") + @skip_if_halide # bf16 + def test__dyn_quant_pack_4bit_weight_bf16(self): + k = 128 + n = 128 + q_group = 32 + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + torch.manual_seed(1) + b = torch.rand((k, n), dtype=torch.bfloat16) + in_features = b.size(0) + out_features = b.size(1) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(b, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + return b_int4pack + + self.common(fn, (b, in_features, out_features)) + @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") @skipIfRocm @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") - def test__dyn_quant_matmul_4bit(self): + def test__dyn_quant_matmul_4bit_fp32_input(self): q_group = 32 m = 32 k = 128 @@ -2560,6 +2602,68 @@ def fn(a, q_group, in_features, out_features): self.common(fn, (a, q_group, in_features, out_features)) + @skipCPUIf(IS_MACOS, "fails on M1, mismatch in bf16 support reporting") + @xfail_if_mps_unimplemented + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") + @skip_if_halide # bf16 + def test__dyn_quant_matmul_4bit_bf16_input(self): + m = 32 + k = 128 + n = 128 + q_group = k + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + torch.manual_seed(1) + a = torch.rand((m, k), dtype=torch.bfloat16) + b = torch.rand((k, n), dtype=torch.bfloat16) + + # codegen_dynamic_shape test fails without explicitly marking these dynamic + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(b, 1) + + in_features = b.size(0) + out_features = b.size(1) + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(a, q_group, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + res = torch.ops.aten._dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + return res + + self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5) + def test_expanded_reduction(self): def fn(x, y): z = x * y diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 5a629b371c766..2ed88a4ec2344 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3741,6 +3741,7 @@ def kai_roundup(a: int, b: int) -> int: def get_kai_packed_weight_size(n_bits, N, K, groupsize): if n_bits == 4: + # Works for both fp32 and bf16 Kernels if groupsize == K: # channelwise # dotprod params only [1x8x32_neon_dotprod] kai_nr = 8 @@ -3870,6 +3871,8 @@ def meta__dyn_quant_pack_4bit_weight( ) return weights.new_empty(int(packed_weight_size), dtype=torch.uint8) packed_weight_size = weights.numel() + scales_zeros.numel() + if bias is not None: + packed_weight_size += bias.numel() return weights.new_empty(packed_weight_size, dtype=torch.float) @@ -3883,8 +3886,12 @@ def meta__dyn_quant_matmul_4bit( ): torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") torch._check( - inp.dtype == torch.float32, - lambda: f"expected input to be f32, got {inp.dtype}", + (inp.dtype == torch.float32) + or (inp.dtype == torch.bfloat16 and block_size == in_features), + lambda: ( + f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), " + f"got {inp.dtype} with block_size={block_size} and in_features={in_features}" + ), ) M = inp.size(0) return inp.new_empty(M, out_features, dtype=inp.dtype) From ae142ab89fd38655a4eae56ae068ac07f06e5b79 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Thu, 20 Nov 2025 13:10:43 +0000 Subject: [PATCH 093/230] s390x: fix periodic tests build (#168001) It looks like building python_call.cpp with -O3 triggers a bug in gcc-14. As a workaround, ignore offending warning on s390x in the code. Build failure link: https://github.com/pytorch/pytorch/actions/runs/19423391774/job/55584553077 GCC bug reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115016 In addition to that, fix docker image names for s390x test workflows similar to build workflows and remove fail marks from couple of tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168001 Approved by: https://github.com/seemethere --- .github/workflows/_linux-test.yml | 5 ++++- test/dynamo/test_structured_trace.py | 3 +-- test/inductor/test_torchinductor.py | 1 - torch/csrc/distributed/rpc/python_call.cpp | 10 ++++++++++ 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index b52ec158dd6d6..2434a595f5420 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -327,6 +327,7 @@ jobs: SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }} SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + DOCKER_IMAGE_S390X: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} @@ -360,10 +361,12 @@ jobs: # if for some reason cleanup action doesn't stop container # when job is cancelled DOCKER_SHELL_CMD="sleep 12h" + USED_IMAGE="${DOCKER_IMAGE_S390X}" else SHM_OPTS="--shm-size=${SHM_SIZE}" JENKINS_USER="--user jenkins" DOCKER_SHELL_CMD= + USED_IMAGE="${DOCKER_IMAGE}" fi # detached container should get cleaned up by teardown_ec2_linux @@ -426,7 +429,7 @@ jobs: ${JENKINS_USER} \ -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ -w /var/lib/jenkins/workspace \ - "${DOCKER_IMAGE}" \ + "${USED_IMAGE}" \ ${DOCKER_SHELL_CMD} ) echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}" diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index e1e2b228062f6..33715d2cf861b 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -21,7 +21,7 @@ from torch._inductor.test_case import TestCase from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP -from torch.testing._internal.common_utils import find_free_port, xfailIfS390X +from torch.testing._internal.common_utils import find_free_port from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -1017,7 +1017,6 @@ def fn(a): logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in chromium_events)) - @xfailIfS390X @requires_tlparse @torch._dynamo.config.patch("compiled_autograd", True) def test_compiled_autograd_attribution(self): diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 8ecd637f4a322..7cfb815a93d7d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2172,7 +2172,6 @@ def fn(a): @skipCPUIf(IS_MACOS, "fails on macos") @skip_if_halide # accuracy 4.7% off - @xfailIfS390X # accuracy failure def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) diff --git a/torch/csrc/distributed/rpc/python_call.cpp b/torch/csrc/distributed/rpc/python_call.cpp index 5973903cfcf10..4770d744215c6 100644 --- a/torch/csrc/distributed/rpc/python_call.cpp +++ b/torch/csrc/distributed/rpc/python_call.cpp @@ -6,6 +6,12 @@ PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution) : serializedPyObj_(std::move(serializedPyObj)), isAsyncExecution_(isAsyncExecution) {} +#if defined(__GNUC__) && __GNUC__ == 14 +/* this warning is falsely triggered with gcc-14 in following function. */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfree-nonheap-object" +#endif + c10::intrusive_ptr PythonCall::toMessageImpl() && { std::vector payload; payload.reserve(serializedPyObj_.payload_.length() + 1); @@ -21,6 +27,10 @@ c10::intrusive_ptr PythonCall::toMessageImpl() && { MessageType::PYTHON_CALL); } +#if defined(__GNUC__) && __GNUC__ == 14 +#pragma GCC diagnostic pop +#endif + std::unique_ptr PythonCall::fromMessage(const Message& message) { TORCH_INTERNAL_ASSERT( !message.payload().empty(), From 6edf2aa8f3a694077519e6223f7e132a2071f357 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 13:27:12 +0000 Subject: [PATCH 094/230] Revert "Improve build logic in activities for kineto (#167204)" This reverts commit 6fc430644b1357fab03a03619576fef7197ac60e. Reverted https://github.com/pytorch/pytorch/pull/167204 on behalf of https://github.com/guangyey due to break xpu legacy profiler ([comment](https://github.com/pytorch/pytorch/pull/167204#issuecomment-3558053023)) --- torch/csrc/autograd/init.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 7470344cc05f7..a13cc70270ccb 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,27 +390,31 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ torch::profiler::impl::ActivityType::CPU}; -#if defined(USE_KINETO) -#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) +#if defined(USE_KINETO) && \ + (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) + if (at::hasMTIA()) { + activities.insert(torch::profiler::impl::ActivityType::MTIA); + } + if (at::hasHPU()) { + activities.insert(torch::profiler::impl::ActivityType::HPU); + } if (at::getNumGPUs() > 0) { activities.insert(torch::profiler::impl::ActivityType::CUDA); } -#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) -#if (!defined(LIBKINETO_NOXPUPTI)) +#elif defined(USE_KINETO) if (at::hasXPU()) { activities.insert(torch::profiler::impl::ActivityType::XPU); } -#endif // (!defined(LIBKINETO_NOXPUPTI)) - if (at::hasMTIA()) { - activities.insert(torch::profiler::impl::ActivityType::MTIA); - } if (at::hasHPU()) { activities.insert(torch::profiler::impl::ActivityType::HPU); } + if (at::hasMTIA()) { + activities.insert(torch::profiler::impl::ActivityType::MTIA); + } if (c10::get_privateuse1_backend() != "privateuseone") { activities.insert(torch::profiler::impl::ActivityType::PrivateUse1); } -#endif // defined(USE_KINETO) +#endif return activities; }); From 762273e3c55ed951395eefbd44d365ef4ce7a58b Mon Sep 17 00:00:00 2001 From: Vinitha Vijayan Date: Thu, 20 Nov 2025 13:40:32 +0000 Subject: [PATCH 095/230] Move pointwise_scatter optimization to joint_graph stage from post_grad (#165463) Fixes #129449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165463 Approved by: https://github.com/eellison --- torch/_inductor/fx_passes/joint_graph.py | 89 +++++++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 93 +----------------------- 2 files changed, 91 insertions(+), 91 deletions(-) diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 9db694f1d8629..021abb0d6b13b 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -957,3 +957,92 @@ def repl(inp, other): pass_dict=pass_patterns[1], extra_check=_other_is_broadcasted_in_dim, )(div_softmax_pattern) + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, torch.fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedious. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_graph_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise operation in joint graph. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. + """ + from torch._inductor import metrics + + # pyrefly: ignore # bad-assignment + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + # Create a replacement that uses torch.where for the pointwise operation + def repl_fn(shape, background_val, dim, selector, val): + # Create a tensor of indices for the scatter dimension + length = shape[dim] + indices = torch.arange(length, device=selector.device, dtype=torch.int64) + + # Reshape indices to have size 'length' at dim, then broadcast + view_shape = [1] * len(shape) + view_shape[dim] = length + indices_view = indices.view(*view_shape) + + # Broadcast selector to match full tensor shape + selector_expanded = selector.expand(shape) + + # Create a mask for where to scatter + mask = selector_expanded == indices_view + + # Use torch.where to implement the scatter pointwise operation + return torch.where(mask, val, background_val) + + # replace the scatter operation with pointwise equivalent + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl_fn, [shape, background_val, dim, selector, val]) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index e0362f2aaafd4..e403e82ff6c3b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -16,13 +16,13 @@ from torch._decomp import register_decomposition from torch._dynamo.utils import counters from torch._inductor import comms -from torch._inductor.virtualized import ops +from torch._inductor.virtualized import ops # noqa: F401 from torch._logging import trace_structured from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq from torch.utils._ordered_set import OrderedSet -from .. import config, ir, pattern_matcher +from .. import config, ir, pattern_matcher # noqa: F401 from ..codegen.common import custom_backend_passes from ..comms import remove_fsdp2_unsharded_param_graph_input_usage from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage @@ -802,95 +802,6 @@ def is_valid_mm_plus_mm(match: Match): return True -def scatter_upon_const_tensor_extra_check(m): - if not config.optimize_scatter_upon_const_tensor: - return False - full_shape = m.kwargs["shape"] - selector = m.kwargs["selector"] - dim = m.kwargs["dim"] - if dim < 0: - dim += len(full_shape) - - selector_ft = selector.meta["val"] - assert selector_ft.dim() == len(full_shape) - - for idx, select_sz, full_sz in zip( - itertools.count(), selector_ft.shape, full_shape - ): - if idx == dim: - continue - - # TODO: the pattern can be updated to support the case that index tensor - # is shorter. But that will need a more complex condition expression - # especially for multi-dimensional tensors. - # Skip it for now. - if isinstance(full_sz, fx.Node): - full_sz = full_sz.meta["val"] - if select_sz < full_sz: - return False - - # Actually we can support small size larger than 1. It would be a bit - # tedius. E.g., we load all the index values (not many) and compare - # them with the position in tensor to decide what value to return. - return selector_ft.size(dim) == 1 - - -@register_lowering_pattern( - CallFunction( - aten.scatter.value, - CallFunction( - aten.full, - KeywordArg("shape"), - KeywordArg("background_val"), - dtype=KeywordArg("dtype"), - ), - KeywordArg("dim"), - KeywordArg("selector"), - KeywordArg("val"), # scalar value - ), - extra_check=scatter_upon_const_tensor_extra_check, -) -def scatter_upon_const_tensor( - match: Match, shape, background_val, dtype, dim, selector, val -): - """ - Match the pattern of full+scatter into a pointwise. - - TODO: Right now the scatter value must be a scalar. But we could support it - when it is a tensor as well. - """ - from torch._inductor import metrics - - # Check if inputs are tensors instead of inductor IR nodes - if isinstance(selector, torch.Tensor): - # Return a fake tensor with the proper shape that this operator is intended to return - device = selector.device if hasattr(selector, "device") else torch.device("cpu") - return torch.empty(shape, dtype=dtype, device=device) - - # pyrefly: ignore [bad-assignment] - metrics.num_matches_for_scatter_upon_const_tensor += 1 - - selector_loader = selector.make_loader() - - def inner_fn(idx): - selector_idx = list(idx) - selector_idx[dim] = 0 - - selector = selector_loader(selector_idx) - return ops.where( - selector == ops.index_expr(idx[dim], torch.int64), - ops.constant(val, dtype), - ops.constant(background_val, dtype), - ) - - return ir.Pointwise.create( - device=selector.get_device(), - dtype=dtype, - inner_fn=inner_fn, - ranges=shape, - ) - - @register_lowering_pattern( CallFunction( aten.add, From bd883bb2903e1850cb4b442dbfe74538aee23dd9 Mon Sep 17 00:00:00 2001 From: Klaus Zimmermann Date: Thu, 20 Nov 2025 09:26:30 +0100 Subject: [PATCH 096/230] Add basic spin linting documentation (#167227) This adds basic documentation of the linting features for Spin added in #167226 to the CONTRIBUTING.md document. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167227 Approved by: https://github.com/atalman, https://github.com/albanD --- CONTRIBUTING.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc0b0fc9bb00f..863a685886a84 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,6 +14,10 @@ aspects of contributing to PyTorch. - [Tips and Debugging](#tips-and-debugging) - [Nightly Checkout & Pull](#nightly-checkout--pull) - [Codebase structure](#codebase-structure) +- [Spin](#spin) + - [Linting](#linting) + - [default lint](#default-lint) + - [Regenerating](#regenerating) - [Unit testing](#unit-testing) - [Python Unit Testing](#python-unit-testing) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) @@ -274,6 +278,46 @@ dependencies as well as the nightly binaries into the repo directory. * ... * [.circleci](.circleci) - CircleCI configuration management. [README](.circleci/README.md) +## Spin + +[Spin](https://github.com/scientific-python/spin) is a developer cli tool that +helps running common tasks. +To list the available tasks, run `spin --help`. +Currently, we support the following tasks with Spin: + +### Linting + +Spin helps with linting by making sure that lintrunner is installed correctly +and by isolating the lintrunner environment from the general development +environment using uv. + +|command|| +|-|-| +|`setup-lint`|update lintrunner and perform a fresh setup| +|`lazy-setup-lint`|only perform setup if the lint configuration has changed| +|`lint`|perform default lint (see below)| +|`quicklint`|perform lint on all files changed in the latest commit and the working directory| +|`quickfix`|autofix issues on all files changed in the latest commit and the working directory| + +#### default lint + +Since some linters take a long time to run, we categorize all linters as either +fast or slow. In the default lint, only the fast linters are run on all files; +the slow linters are run on the changed files only. + +### Regenerating + +Pytorch makes use of a number of code generations, which range from the version +information in `torch/version.py` over type stubs and other linter support to +github workflows. +With Spin, we offer a unified interface to these tasks. + +|command|| +|-|-| +|`regenerate-version`|regenerate `torch/version.py`| +|`regenerate-type-stubs`|regenerates type stubs for use by static type checkers| +|`regenerate-clangtidy-files`|regenerates clang related files needed for linting| + ## Unit testing ### Python Unit Testing From 9d7f9834c8cee0900207cf040a0f765070977889 Mon Sep 17 00:00:00 2001 From: Klaus Zimmermann Date: Thu, 20 Nov 2025 09:26:30 +0100 Subject: [PATCH 097/230] Add workflow regeneration to spin (#167551) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167551 Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/atalman ghstack dependencies: #167227 --- .spin/cmds.py | 7 +++++++ CONTRIBUTING.md | 1 + pyproject.toml | 1 + 3 files changed, 9 insertions(+) diff --git a/.spin/cmds.py b/.spin/cmds.py index a81717c7423be..9ed9f4a796b45 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -328,3 +328,10 @@ def quicklint(ctx, apply_patches, **kwargs): def quickfix(ctx, **kwargs): """Autofix changed files.""" ctx.invoke(quicklint, apply_patches=True) + + +@click.command() +def regenerate_github_workflows(): + """Regenerate GitHub workflows from templates.""" + cmd = [sys.executable, "scripts/generate_ci_workflows.py"] + spin.util.run(cmd, cwd="./.github") diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 863a685886a84..850753f13b63a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -317,6 +317,7 @@ With Spin, we offer a unified interface to these tasks. |`regenerate-version`|regenerate `torch/version.py`| |`regenerate-type-stubs`|regenerates type stubs for use by static type checkers| |`regenerate-clangtidy-files`|regenerates clang related files needed for linting| +|`regenerate-github-workflows`|regenerates github workflows from jinja templates| ## Unit testing diff --git a/pyproject.toml b/pyproject.toml index 9986c6a9b7b6b..d9927122352f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -391,4 +391,5 @@ package = 'torch' ".spin/cmds.py:regenerate_version", ".spin/cmds.py:regenerate_type_stubs", ".spin/cmds.py:regenerate_clangtidy_files", + ".spin/cmds.py:regenerate_github_workflows", ] From 43acddb3fd9d5b20ee884e208504031067141285 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 19 Nov 2025 22:42:59 +0200 Subject: [PATCH 098/230] Move c10/util/Deprecated.h to headeronly (#168173) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168173 Approved by: https://github.com/janeyx99 --- c10/util/Deprecated.h | 102 +---------------------------- torch/header_only_apis.txt | 6 ++ torch/headeronly/util/Deprecated.h | 102 +++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 101 deletions(-) create mode 100644 torch/headeronly/util/Deprecated.h diff --git a/c10/util/Deprecated.h b/c10/util/Deprecated.h index 88440a0242eb4..3237074feff8c 100644 --- a/c10/util/Deprecated.h +++ b/c10/util/Deprecated.h @@ -1,102 +1,2 @@ #pragma once - -/** - * This file provides portable macros for marking declarations - * as deprecated. You should generally use C10_DEPRECATED, - * except when marking 'using' declarations as deprecated, - * in which case you should use C10_DEFINE_DEPRECATED_USING - * (due to portability concerns). - */ - -// Sample usage: -// -// C10_DEPRECATED void bad_func(); -// struct C10_DEPRECATED BadStruct { -// ... -// }; - -// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses -// the "__declspec(deprecated)" implementation and not the C++14 -// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on -// MSVC, but ran into issues with some older MSVC versions. -#if (defined(__cplusplus) && __cplusplus >= 201402L) -#define C10_DEPRECATED [[deprecated]] -#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] -#elif defined(__GNUC__) -#define C10_DEPRECATED __attribute__((deprecated)) -// TODO Is there some way to implement this? -#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) - -#elif defined(_MSC_VER) -#define C10_DEPRECATED __declspec(deprecated) -#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) -#else -#warning "You need to implement C10_DEPRECATED for this compiler" -#define C10_DEPRECATED -#endif - -// Sample usage: -// -// C10_DEFINE_DEPRECATED_USING(BadType, int) -// -// which is the portable version of -// -// using BadType [[deprecated]] = int; - -// technically [[deprecated]] syntax is from c++14 standard, but it works in -// many compilers. -#if defined(__has_cpp_attribute) -#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName [[deprecated]] = TypeThingy; -#endif -#endif - -#if defined(_MSC_VER) -#if defined(__CUDACC__) -// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows; -// you get the error: -// -// error: attribute does not apply to any entity -// -// So we just turn the macro off in this case. -#if defined(C10_DEFINE_DEPRECATED_USING) -#undef C10_DEFINE_DEPRECATED_USING -#endif -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = TypeThingy; -#else -// [[deprecated]] does work in windows without nvcc, though msc doesn't support -// `__has_cpp_attribute` when c++14 is supported, otherwise -// __declspec(deprecated) is used as the alternative. -#ifndef C10_DEFINE_DEPRECATED_USING -#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName [[deprecated]] = TypeThingy; -#else -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = __declspec(deprecated) TypeThingy; -#endif -#endif -#endif -#endif - -#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__) -// nvcc has a bug where it doesn't understand __attribute__((deprecated)) -// declarations even when the host compiler supports it. We'll only use this gcc -// attribute when not cuda, and when using a GCC compiler that doesn't support -// the c++14 syntax we checked for above (available in __GNUC__ >= 5) -#if !defined(__CUDACC__) -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName __attribute__((deprecated)) = TypeThingy; -#else -// using cuda + gcc < 5, neither deprecated syntax is available so turning off. -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = TypeThingy; -#endif -#endif - -#if !defined(C10_DEFINE_DEPRECATED_USING) -#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" -#define C10_DEFINE_DEPRECATED_USING -#endif +#include diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 9f422b720d4e6..7e64efbb8b73c 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -219,3 +219,9 @@ HeaderOnlyGenericPackedTensorAccessor # HeaderOnlyTensorAccessorBase and # HeaderOnlyGenericPackedTensorAccessorBase are tested through # HeaderOnlyTensorAccessor and HeaderOnlyGenericPackedTensorAccessor + +# torch/headeronly/util/Deprecated.h +# C10_DEPRECATED, C10_DEPRECATED_MESSAGE, and +# C10_DEFINE_DEPRECATED_USING functionalities are expressed at compile +# time that have no effect to runtime. Therefore, these macros are not +# tested under test/. diff --git a/torch/headeronly/util/Deprecated.h b/torch/headeronly/util/Deprecated.h new file mode 100644 index 0000000000000..88440a0242eb4 --- /dev/null +++ b/torch/headeronly/util/Deprecated.h @@ -0,0 +1,102 @@ +#pragma once + +/** + * This file provides portable macros for marking declarations + * as deprecated. You should generally use C10_DEPRECATED, + * except when marking 'using' declarations as deprecated, + * in which case you should use C10_DEFINE_DEPRECATED_USING + * (due to portability concerns). + */ + +// Sample usage: +// +// C10_DEPRECATED void bad_func(); +// struct C10_DEPRECATED BadStruct { +// ... +// }; + +// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses +// the "__declspec(deprecated)" implementation and not the C++14 +// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on +// MSVC, but ran into issues with some older MSVC versions. +#if (defined(__cplusplus) && __cplusplus >= 201402L) +#define C10_DEPRECATED [[deprecated]] +#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] +#elif defined(__GNUC__) +#define C10_DEPRECATED __attribute__((deprecated)) +// TODO Is there some way to implement this? +#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) + +#elif defined(_MSC_VER) +#define C10_DEPRECATED __declspec(deprecated) +#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#else +#warning "You need to implement C10_DEPRECATED for this compiler" +#define C10_DEPRECATED +#endif + +// Sample usage: +// +// C10_DEFINE_DEPRECATED_USING(BadType, int) +// +// which is the portable version of +// +// using BadType [[deprecated]] = int; + +// technically [[deprecated]] syntax is from c++14 standard, but it works in +// many compilers. +#if defined(__has_cpp_attribute) +#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#endif +#endif + +#if defined(_MSC_VER) +#if defined(__CUDACC__) +// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows; +// you get the error: +// +// error: attribute does not apply to any entity +// +// So we just turn the macro off in this case. +#if defined(C10_DEFINE_DEPRECATED_USING) +#undef C10_DEFINE_DEPRECATED_USING +#endif +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#else +// [[deprecated]] does work in windows without nvcc, though msc doesn't support +// `__has_cpp_attribute` when c++14 is supported, otherwise +// __declspec(deprecated) is used as the alternative. +#ifndef C10_DEFINE_DEPRECATED_USING +#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#else +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = __declspec(deprecated) TypeThingy; +#endif +#endif +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__) +// nvcc has a bug where it doesn't understand __attribute__((deprecated)) +// declarations even when the host compiler supports it. We'll only use this gcc +// attribute when not cuda, and when using a GCC compiler that doesn't support +// the c++14 syntax we checked for above (available in __GNUC__ >= 5) +#if !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName __attribute__((deprecated)) = TypeThingy; +#else +// using cuda + gcc < 5, neither deprecated syntax is available so turning off. +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) +#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" +#define C10_DEFINE_DEPRECATED_USING +#endif From 7fff3172c54bc628a63eed5fac25b49d4b655294 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 19 Nov 2025 22:43:00 +0200 Subject: [PATCH 099/230] Revise stableivalue from/to deprecation (#168155) An alternative approach to https://github.com/pytorch/pytorch/pull/167923 to fix windows build failure to avoid massive replacement `from -> torch::stable::detail::from`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168155 Approved by: https://github.com/janeyx99 ghstack dependencies: #168173 --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 14 ++++---- torch/csrc/stable/stableivalue_conversions.h | 35 +++++--------------- 2 files changed, 15 insertions(+), 34 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 61a97fd740cbc..63bff112afee2 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2561,13 +2561,13 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: codegen_arg = codegen_arg.removeprefix("&") if codegen_arg == "nullptr": - return "from(std::nullopt)" + return "torch::stable::detail::from(std::nullopt)" var_name = f"tmp_var_{next(tmp_var_number)}" dispatch_lines.writeline( f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" ) - return f"from({var_name})" + return f"torch::stable::detail::from({var_name})" raii_var = self.create_tmp_raii_handle_var_if_needed( codegen_arg, dispatch_lines @@ -2584,11 +2584,11 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: dispatch_lines.writeline( f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" ) - return f"from({var_name})" + return f"torch::stable::detail::from({var_name})" # If the RAII tensor _is_ a temporary scoped to this fallback call, # simply release and steal the handle. - return f"from({raii_var}.release())" - return f"from({codegen_arg})" + return f"torch::stable::detail::from({raii_var}.release())" + return f"torch::stable::detail::from({codegen_arg})" codegen_args = get_args() ivalue_args = ( @@ -2609,7 +2609,7 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: if len(output_args) == 1 and (output := output_args[0]) is not None: # result is a single tensor dispatch_lines.writeline( - f"{output} = to(dispatch_vars[0]);" + f"{output} = torch::stable::detail::to(dispatch_vars[0]);" ) else: # result is a tuple of tensors @@ -2617,7 +2617,7 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: if output_arg is None: continue dispatch_lines.writeline( - f"{output_arg} = to(dispatch_vars[{idx}]);" + f"{output_arg} = torch::stable::detail::to(dispatch_vars[{idx}]);" ) dispatch_lines.writeline("}") diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 4538781594785..ed885fbe03a12 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -345,10 +346,10 @@ struct FromImpl> { TORCH_ERROR_CODE_CHECK( torch_new_list_reserve_size(val.size(), &new_list_handle)); for (const auto& elem : val) { - TORCH_ERROR_CODE_CHECK(torch_list_push_back( - new_list_handle, torch::stable::detail::from(elem))); + TORCH_ERROR_CODE_CHECK( + torch_list_push_back(new_list_handle, from(elem))); } - return torch::stable::detail::from(new_list_handle); + return from(new_list_handle); } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown @@ -779,31 +780,11 @@ HIDDEN_NAMESPACE_END(torch, stable, detail) // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -[[deprecated("Use torch::stable::detail::from instead.")]] -inline StableIValue from(T val) { - return torch::stable::detail::from(val); -} - -// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation -// note] -template -[[deprecated("Use torch::stable::detail::from instead.")]] -inline StableIValue from(const std::optional& val) { - return torch::stable::detail::from(val); -} - -// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation -// note] -[[deprecated( - "Use torch::stable::detail::from instead.")]] [[maybe_unused]] inline StableIValue -from(const torch::stable::Tensor& val) { - return torch::stable::detail::from(val); -} +C10_DEPRECATED_MESSAGE("Use torch::stable::detail::from instead.") +auto from = &torch::stable::detail::from; // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -[[deprecated("Use torch::stable::detail::to instead.")]] -inline T to(StableIValue val) { - return torch::stable::detail::to(val); -} +C10_DEPRECATED_MESSAGE("Use torch::stable::detail::to instead.") +auto to = &torch::stable::detail::to; From a01e8a2ebbbd30b99dbb086187838ddd683fbf77 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Thu, 20 Nov 2025 16:25:06 +0000 Subject: [PATCH 100/230] [BE] Update xpu driver repo for CD used almalinux 8.10 (#157356) XPU CD docker image built on `quay.io/pypa/manylinux_2_28_x86_64`, which based on almalinux 8.10 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157356 Approved by: https://github.com/EikanWang, https://github.com/malfet --- .ci/docker/common/install_xpu.sh | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index 22b7af890c1f6..a29de2cecb870 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -64,14 +64,13 @@ function install_ubuntu() { function install_rhel() { . /etc/os-release - if [[ "${ID}" == "rhel" ]]; then - if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then - echo "RHEL version ${VERSION_ID} not supported" - exit - fi - elif [[ "${ID}" == "almalinux" ]]; then - # Workaround for almalinux8 which used by quay.io/pypa/manylinux_2_28_x86_64 - VERSION_ID="8.8" + if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then + echo "RHEL version ${VERSION_ID} not supported" + exit + fi + # Using testing channel for CD build + if [[ "${ID}" == "almalinux" ]]; then + XPU_DRIVER_VERSION="/testing" fi dnf install -y 'dnf-command(config-manager)' From ba682386850c02f9443bfbc5edb0cbf6ad99b187 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 17 Nov 2025 14:10:02 -0800 Subject: [PATCH 101/230] [Inductor] Freeze layout for potentially padded strides in template autotuning (#168032) Properly freeze layouts to ensure that padded strides are reflected in templates during codegen. Otherwise, there will be a difference between the template codegen'd and the input tensors for benchmarking, leading to CUDA IMA Differential Revision: [D87273721](https://our.internmc.facebook.com/intern/diff/D87273721) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168032 Approved by: https://github.com/eellison, https://github.com/njriasan --- test/inductor/test_max_autotune.py | 60 +++++++++++++++++++++++++++++ torch/_inductor/select_algorithm.py | 9 +++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 46a63db754697..90714b58951b1 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2478,6 +2478,66 @@ def layout_checker(choices): finally: clear_preprocessing_fns(clear_defaults=False) + @config.patch( + {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "TRITON"} + ) + def test_fixed_layout_at_lowering(self): + """ + Test that max-autotune with addmm/bmm/mm_plus_mm correctly handles + padding and maintains correct output strides. Specifically, when matrix + b with shape (4608, 1490) is padded, its stride should become 1536. + """ + + def mm_func(a, b) -> torch.Tensor: + a_t = torch.permute(a, [1, 0]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + # Add .to() to make sure that mm could be potentially padded + # Strides for output are not padded + return (a_t @ b_dtype).to(torch.float32) + + def addmm_func(a, b, bias) -> torch.Tensor: + a_t = torch.permute(a, [1, 0]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + bias_dtype = bias.to(torch.bfloat16) + return torch.addmm(bias_dtype, a_t, b_dtype).to(torch.float32) + + def bmm_func(a, b) -> torch.Tensor: + a_t = torch.permute(a, [2, 0, 1]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + return torch.bmm(a_t, b_dtype).to(torch.float32) + + def mm_plus_mm_func(a1, b1, a2, b2) -> torch.Tensor: + a1_t = torch.permute(a1, [1, 0]).to(torch.bfloat16) + b1_dtype = b1.to(torch.bfloat16) + a2_t = torch.permute(a2, [1, 0]).to(torch.bfloat16) + b2_dtype = b2.to(torch.bfloat16) + return (a1_t @ b1_dtype + a2_t @ b2_dtype).to(torch.float32) + + a = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn((4608, 1490), device=GPU_TYPE) + bias = torch.randn(1490, device=GPU_TYPE) + + a_bmm = torch.randn((512, 4608, 8), device=GPU_TYPE, dtype=torch.bfloat16) + b_bmm = torch.randn((8, 4608, 1490), device=GPU_TYPE) + + # Test mm_plus_mm + a2 = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16) + b2 = torch.randn((4608, 1490), device=GPU_TYPE) + + # 1490 padded to 1536, check in template code + output_code_padding_check = "stride_bk = 1536" + funcs_and_args = [ + (mm_func, (a, b)), + (addmm_func, (a, b, bias)), + (bmm_func, (a_bmm, b_bmm)), + (mm_plus_mm_func, (a, b, a2, b2)), + ] + + for f, args in funcs_and_args: + c_f = torch.compile(f, mode="max-autotune-no-cudagraphs") + _, code_out = run_and_get_code(c_f, *args) + FileCheck().check(output_code_padding_check).run(code_out[0]) + class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d6893b07ee3d9..493ca1179fad8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -777,7 +777,7 @@ def stride(self, name, index=None): val = self.output_node.get_stride() else: assert isinstance(name, str) - val = self.named_input_nodes[name].get_stride() + val = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) if isinstance(index, int): return texpr(self.rename_indexing(val[index])) @@ -955,7 +955,6 @@ def load_input( self.template_mask = mask if mask is not None else "None" self.template_out_shape = index_shape if index_shape else "xindex" self.template_indices = indices - self.named_input_nodes[input_name].data.freeze_layout() self.cse.invalidate(OrderedSet()) template_mask = self.template_mask @@ -1412,7 +1411,7 @@ def make_load(self, name, indices, mask): assert isinstance(indices, (list, tuple)) assert isinstance(name, str) assert isinstance(mask, str) - stride = self.named_input_nodes[name].get_stride() + stride = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) indices = list(map(OpOverrides.paren, indices)) assert len(indices) == len(stride) index = " + ".join( @@ -1502,6 +1501,10 @@ def kernel_benchmark_extra_args(self) -> list[str]: ) ] + def get_stride_and_maybe_freeze_layout(self, node) -> list[int]: + node.data.freeze_layout() + return node.get_stride() + @functools.cache def _jinja2_env(): From f4382d7f9851d4f268c23affd0cf1f41a6e92fff Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Thu, 20 Nov 2025 17:27:50 +0000 Subject: [PATCH 102/230] Fixes floor divide int min overflow issue (#166127) Fixes #127804 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166127 Approved by: https://github.com/albanD --- c10/util/generic_math.h | 6 ++++++ test/test_binary_ufuncs.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index 493c03cb42e64..8770977840cb2 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -58,6 +58,12 @@ inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) template inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) { + if (C10_UNLIKELY( + std::is_signed::value && + a == std::numeric_limits::min() && b == scalar_t(-1))) { + return a; + } + if (c10::signs_differ(a, b)) { // Subtracts one from the results of truncation division if the // divisor and dividend have different sign(bit)s and the remainder of diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2b5606aec98d6..16c4deb7fbddd 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -3058,6 +3058,18 @@ def test_floor_divide_zero(self, device, dtype): with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): a // b + @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) + def test_floor_divide_int_min(self, device, dtype): + int_min = torch.iinfo(dtype).min + a = torch.tensor([int_min], dtype=dtype, device=device) + b = torch.tensor([-1], dtype=dtype, device=device) + + result = torch.floor_divide(a, b) + result_ = a // b + + self.assertEqual(result, a) + self.assertEqual(result_, a) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_muldiv_scalar(self, device, dtype): x = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None) From dd89d2c043ed7c05b234503ee1e99f1a99669026 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 19 Nov 2025 14:47:50 -0800 Subject: [PATCH 103/230] [DTensor] Document fast-path dispatch (#168192) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168192 Approved by: https://github.com/albanD --- torch/distributed/tensor/_api.py | 3 +++ torch/distributed/tensor/_dispatch.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 3946f9249d0de..fb072d8dce629 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -349,6 +349,9 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] # We just need to have an implementation here; the __torch_dispatch__ machinery # calls into a specific C++ fast path that doesn't call here. + # See #167051 for details + # python_arg_parser.cpp: dispatch_on_subclass() + # -> python_variable.cpp: dispatchDTensorOp() raise NotImplementedError( "DTensor.__torch_dispatch__ should not actually get called" ) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 630f327add3d7..f52538c0cf368 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -154,6 +154,17 @@ def __init__(self) -> None: aten.as_strided.default: as_strided_handler, } + # ******************************************************************************************** + # def dispatch(...) + # + # NOTE: this class no longer contains the top-level dispatch entrypoint! + # See #167051 for details + # + # The entrypoint has been moved to C++, and it handles common cases and then calls back into + # OpDispatcher python to handle corner cases. + # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp + # ******************************************************************************************** + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave From 32b92600c1c4d78c1322e6923b442ab039792a02 Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Thu, 20 Nov 2025 17:28:34 +0000 Subject: [PATCH 104/230] Fixes remainder and fmod operation and makes it same as cuda (#165833) Fixes #165649 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165833 Approved by: https://github.com/albanD --- aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 6 ++++++ test/test_binary_ufuncs.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index b5f3d91692b9a..26ec55c11d823 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -353,6 +353,9 @@ void remainder_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { TORCH_CHECK(b != 0, "ZeroDivisionError"); + if (a == std::numeric_limits::min() && b == scalar_t(-1)) { + return 0; + } scalar_t r = a % b; if ((r != 0) && (c10::is_negative(r) != c10::is_negative(b))) { r += b; @@ -1035,6 +1038,9 @@ void fmod_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cpu", [&]() { cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t { TORCH_CHECK(d != 0, "ZeroDivisionError"); + if (x == std::numeric_limits::min() && d == scalar_t(-1)) { + return 0; + } return x % d; }); }); diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 16c4deb7fbddd..d448f95319416 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -2756,6 +2756,25 @@ def test_fmod_remainder_by_zero_integral(self, device, dtype): value = 255 if dtype == torch.uint8 else -1 self.assertTrue(torch.all(fn(x, zero) == value)) + @onlyNativeDeviceTypes + @dtypes(*integral_types()) + def test_fmod_remainder_overflow(self, device, dtype): + fn_list = (torch.fmod, torch.remainder) + for fn in fn_list: + if dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]: + continue + + min_val = torch.iinfo(dtype).min + dividend = torch.full((2, 3), min_val, dtype=dtype, device=device) + divisor = torch.full((3,), -1, dtype=dtype, device=device) + + result = fn(dividend, divisor) + expected = torch.zeros_like(dividend) + self.assertEqual(result, expected) + + result_scalar = fn(dividend, -1) + self.assertEqual(result_scalar, expected) + @dtypes(*all_types_and(torch.half)) def test_fmod_remainder(self, device, dtype): # Use numpy as reference From 29bd2ddb312fee734714262222ed26d0b1459b59 Mon Sep 17 00:00:00 2001 From: Parshant Sharma Date: Thu, 20 Nov 2025 17:32:59 +0000 Subject: [PATCH 105/230] Fix: Remove incorrect non-negative validation for correction parameter in torch.var during export (#162254) Fixes #161083 ### Summary: This PR fixes a bug where `torch.export.export `incorrectly validates that the `correction` parameter for `torch.var` should be non-negative, causing export to fail when using negative correction values that are perfectly valid and work correctly in eager PyTorch execution. ### Fix Impact: - **Improved Consistency**: Aligns export behavior with eager PyTorch functionality - **Enhanced Compatibility:** Allows more models to be exported successfully Pull Request resolved: https://github.com/pytorch/pytorch/pull/162254 Approved by: https://github.com/isuruf --- torch/_prims_common/__init__.py | 2 -- torch/testing/_internal/common_methods_invocations.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 9ba46e8c5310c..019e9c59f2423 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1895,8 +1895,6 @@ def set_correction( # NB: we don't actually support symint here, but it's harmless to accept if not isinstance(correction, (IntLike, FloatLike)): raise ValueError("correction argument should be integer or float") - if correction < 0: - raise ValueError("correction argument should be non-negative") return sym_float(correction) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6724ab2ae739a..cf2fd54490591 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6185,6 +6185,9 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) yield SampleInput(tensor_nd(), dim=None, correction=None) + yield SampleInput(tensor_nd(), dim=None, correction=-1) + yield SampleInput(tensor_nd(), dim=None, correction=-5) + yield SampleInput(tensor_nd(), correction=0.5, keepdim=True) yield SampleInput(tensor_nd(), correction=0, keepdim=True) yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) From 88d635c54f73393f50cb795cfa13b15ba7d7339b Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 20 Nov 2025 17:53:44 +0000 Subject: [PATCH 106/230] Remove useless super() delegation (#168235) This PR removes useless super() delegations detected by pylint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168235 Approved by: https://github.com/albanD, https://github.com/zou3519 --- torch/_dynamo/exc.py | 15 +-- torch/_dynamo/variables/base.py | 3 - torch/_dynamo/variables/dicts.py | 14 --- torch/_higher_order_ops/_invoke_quant.py | 3 - .../learnedheuristic_interface.py | 6 - torch/_inductor/codegen/cpp.py | 3 - torch/_inductor/codegen/cuda/gemm_template.py | 13 -- .../ao/nn/intrinsic/qat/modules/conv_fused.py | 117 ------------------ .../data_sparsifier/benchmarks/dlrm_utils.py | 3 - .../quantizer/embedding_quantizer.py | 3 - .../quantizer/xpu_inductor_quantizer.py | 3 - torch/backends/__init__.py | 3 - torch/backends/cudnn/__init__.py | 3 - torch/backends/miopen/__init__.py | 3 - torch/backends/mkldnn/__init__.py | 3 - torch/backends/opt_einsum/__init__.py | 3 - .../_checkpoint/checkpoint_wrapper.py | 3 - torch/jit/_monkeytype_config.py | 3 - torch/jit/_recursive.py | 3 +- torch/multiprocessing/spawn.py | 8 -- torch/testing/_internal/common_utils.py | 2 - 21 files changed, 5 insertions(+), 212 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index f11c78bdaa49e..5b0e8a402dd96 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -198,24 +198,20 @@ class RecompileError(TorchDynamoException): class ArgsMismatchError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class AttributeMutationError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class InfiniteGeneratorError(Unsupported): # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class SideEffectsError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class CondOpArgsMismatchError(ArgsMismatchError): @@ -223,9 +219,6 @@ class CondOpArgsMismatchError(ArgsMismatchError): Internal error from cond() due to arguments mismatch. """ - def __init__(self, msg: str) -> None: - super().__init__(msg) - class UserErrorType(Enum): DYNAMIC_CONTROL_FLOW = auto() diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 4e248320e60b6..2d11a27bafac0 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -151,9 +151,6 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. """ - def __init__(self, typ: SourceType) -> None: - super().__init__(typ) - class AttributeMutationExisting(AttributeMutation): """ diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 24cd5007da37d..636875d85e54a 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1296,13 +1296,6 @@ def install_dict_contains_guard( class FrozensetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "frozenset()" @@ -1360,13 +1353,6 @@ def call_method( class DictKeySetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "dict_keys([])" diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py index 1fc1e1114a036..b7a9fb94b93e2 100644 --- a/torch/_higher_order_ops/_invoke_quant.py +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -26,9 +26,6 @@ class InvokeQuantUnpacked(BaseHOP): def __init__(self) -> None: super().__init__("invoke_quant") - def __call__(self, subgraph, *operands, scheme=None): - return super().__call__(subgraph, *operands, scheme=scheme) - invoke_quant = InvokeQuantUnpacked() diff --git a/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/torch/_inductor/autoheuristic/learnedheuristic_interface.py index cb2568d8a6801..84a941b076c31 100644 --- a/torch/_inductor/autoheuristic/learnedheuristic_interface.py +++ b/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -39,9 +39,6 @@ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: class LearnedHeuristicRegression(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_feedback(self, context: AHContext, choice: Choice) -> float: return 1.0 @@ -64,9 +61,6 @@ def get_decision( class LearnedHeuristicDecision(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_choice(self, idx: int) -> Optional[str]: return None diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 88f203421cc1c..18b209de94cb3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3786,9 +3786,6 @@ class TilingSelect: In the future, we can implement advanced heuristic in a subclass. """ - def __init__(self): - super().__init__() - def select_tiling( self, fn_list, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 22d0981febecd..c4b7188bd9e62 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1330,19 +1330,6 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): including those which allow flexible fusions with epilogues. """ - def __init__( - self, - input_nodes: list[Buffer], - layout: Layout, - alpha: float, - beta: float, - input_reorder: Optional[list[int]] = None, - use_fast_accum: Optional[bool] = None, - ): - super().__init__( - input_nodes, layout, alpha, beta, input_reorder, use_fast_accum - ) - @staticmethod def add_cutlass_gemm_choices( choices: list[ChoiceCaller], diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 0054e996e33ce..1e49a274e129c 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -112,9 +112,6 @@ def reset_bn_parameters(self): bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) - def reset_parameters(self): - super().reset_parameters() - def update_bn_stats(self): self.freeze_bn = False self.bn.training = True @@ -534,44 +531,6 @@ class ConvBnReLU1d(ConvBn1d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d - def __init__( - self, - # Conv1d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm1d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -735,44 +694,6 @@ class ConvBnReLU2d(ConvBn2d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d - def __init__( - self, - # Conv2d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm2d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -935,44 +856,6 @@ class ConvBnReLU3d(ConvBn3d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d - def __init__( - self, - # Conv3d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm3d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 3c146c55947a0..e2b31e0e563bf 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -19,9 +19,6 @@ class SparseDLRM(DLRM_Net): layer of the top layer. """ - def __init__(self, **args): - super().__init__(**args) - def forward(self, dense_x, lS_o, lS_i): # pyrefly: ignore [missing-attribute] x = self.apply_mlp(dense_x, self.bot_l) # dense features diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index b0f1b823b7fdb..3b8ef1030bfdc 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -41,9 +41,6 @@ def get_embedding_operators_config() -> OperatorConfig: class EmbeddingQuantizer(Quantizer): - def __init__(self) -> None: - super().__init__() - @classmethod def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: op_configs: set[QuantizationConfig] = { diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index d19968c2787f4..1c0fc48fd54fa 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -75,9 +75,6 @@ class XPUInductorQuantizer(X86InductorQuantizer): of the optimized kernels in oneDNN library. """ - def __init__(self) -> None: - super().__init__() - """ Following annotate_xx overrides the impls in base class, as no XPU implementation for these operators currently. We would diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index c02a8c36fd08b..f54a3fd6820c7 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -113,9 +113,6 @@ def inner(precision): class GenericModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - fp32_precision = ContextProp( _get_fp32_precision_getter("generic", "all"), _set_fp32_precision_setter("generic", "all"), diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 697783c01cb64..267594531db3d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -198,9 +198,6 @@ def flags( class CudnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp( torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py index 93453cc11592d..1b270b658e31a 100644 --- a/torch/backends/miopen/__init__.py +++ b/torch/backends/miopen/__init__.py @@ -37,9 +37,6 @@ def flags( class MiopenModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - immediate = ContextProp( torch._C._get_miopen_immediate, torch._C._set_miopen_immediate ) diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 2d1ce8f3bb997..58e6b2c595e98 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -110,9 +110,6 @@ def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="n class MkldnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - def is_available(self): return is_available() diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 797d847e31e5c..264be78aa9a1c 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -101,9 +101,6 @@ def flags(enabled=None, strategy=None): class OptEinsumModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - global enabled enabled = ContextProp(_get_enabled, _set_enabled) global strategy diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 3ce067f6cddc0..eae76e8cc72af 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -103,9 +103,6 @@ def _pre_load_state_dict_hook( class OffloadWrapper(ActivationWrapper): - def __init__(self, mod): - super().__init__(mod) - def forward(self, *args, **kwargs): with save_on_cpu(pin_memory=True): return self._checkpoint_wrapped_module(*args, **kwargs) diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 0f348590ea397..e5ddc1e443a29 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -85,9 +85,6 @@ def get_qualified_name(func): class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" - def __init__(self, store: CallTraceStore) -> None: - super().__init__(store) - def log(self, trace: CallTrace) -> None: # pyrefly: ignore [missing-attribute] self.traces.append(trace) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 75355cbd4b8e0..ec4bbd125119d 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -152,8 +152,7 @@ def _get_valid_constant(attr, v, owner_type): class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None: - super().__init__(source, filename, file_lineno, leading_whitespace_len) + pass def get_annotations(obj): diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index f553f7cacd753..12901df09a3c5 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -46,14 +46,6 @@ def __reduce__(self): class ProcessRaisedException(ProcessException): """Exception raised when a process failed due to an exception raised by the code.""" - def __init__( - self, - msg: str, - error_index: int, - error_pid: int, - ): - super().__init__(msg, error_index, error_pid) - class ProcessExitedException(ProcessException): """Exception raised when a process failed due to signal or exited with a specific code.""" diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index d5afc413daed8..815cc8859080f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1370,8 +1370,6 @@ class XMLTestResultVerbose(_XMLTestResult): This works with unittest_xml_reporting<=3.2.0,>=2.0.0 (3.2.0 is latest at the moment) """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) def addSkip(self, test, reason): super().addSkip(test, reason) From f97c3fc8e4f7cd90c5f1613f853e0c3c4e55c73d Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Fri, 31 Oct 2025 16:52:43 +0000 Subject: [PATCH 107/230] Re-enable ConvTranspose operator benchmarks for AArch64 (#166731) This was disabled by #165585 due to #165654 which was fixed by #165904 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166731 Approved by: https://github.com/malfet ghstack dependencies: #165904 --- benchmarks/operator_benchmark/pt/conv_test.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index eb94921989ccf..f972db3f1693e 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -43,15 +43,12 @@ def forward(self, input): Conv1dBenchmark, ) - -if not torch.backends.mkldnn.is_acl_available(): - # convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654 - op_bench.generate_pt_test( - configs.convtranspose_1d_configs_short - + configs.conv_1d_configs_short - + configs.conv_1d_configs_long, - ConvTranspose1dBenchmark, - ) +op_bench.generate_pt_test( + configs.convtranspose_1d_configs_short + + configs.conv_1d_configs_short + + configs.conv_1d_configs_long, + ConvTranspose1dBenchmark, +) """ From 53a4b49ea49b8f957e43ceca4a4fd04cd3adde16 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 12 Nov 2025 11:55:43 -0800 Subject: [PATCH 108/230] [Pipelining] Fix error log (#167668) Minor logging fix It does make the logging wider but its better than having the lines interspersed with unrelated lines due to mixed use of print and logging. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167668 Approved by: https://github.com/fegin --- torch/distributed/pipelining/schedules.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index abc007a8166db..7bdf3c65e4e8f 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -2272,9 +2272,7 @@ def _perform_action(action: _Action) -> None: time_step, action, ) - # TODO(whc) what is the best practice for printing a multiline log? - # logger will split it into multiple log lines, but this makes it hard to read (too wide) - print( + logger.error( _format_pipeline_order( self.pipeline_order_with_comms, # type: ignore[arg-type] error_step_number=time_step, From 803d94be29621ca07e1432af85e41ad240817fd8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 18:14:35 +0000 Subject: [PATCH 109/230] Revert "[dynamo][pytree][compile time] Specialize tree_is_leaf (#168070)" This reverts commit a8ccc4e84f8f99192cf94cb6ef9ea08f295ba881. Reverted https://github.com/pytorch/pytorch/pull/168070 on behalf of https://github.com/anijain2305 due to failed at D87489130 ([comment](https://github.com/pytorch/pytorch/pull/168054#issuecomment-3559399563)) --- test/dynamo/test_repros.py | 65 ---------------------------- torch/_dynamo/trace_rules.py | 2 - torch/_dynamo/variables/__init__.py | 1 - torch/_dynamo/variables/functions.py | 65 ---------------------------- 4 files changed, 133 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 24b8f4c48aa32..aab7d5268fcdc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8243,71 +8243,6 @@ def fn(a, b): # Should compile successfully with fullgraph=True self.assertEqual(cnt.frame_count, 1) - def test_pytree_tree_is_leaf_not_traced(self): - # Test that torch.utils._pytree.tree_is_leaf is not traced into - # when is_leaf parameter is None (the common case) - from torch.utils._pytree import tree_is_leaf - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt, fullgraph=True) - def fn(x, y): - # Test with various types - # Tensors are leaves - is_leaf_tensor = tree_is_leaf(x) - assert is_leaf_tensor is True - - # Lists are not leaves (they're in SUPPORTED_NODES) - is_leaf_list = tree_is_leaf([x, y]) - assert is_leaf_list is False - - # Dicts are not leaves - is_leaf_dict = tree_is_leaf({"a": x, "b": y}) - assert is_leaf_dict is False - - return x + y - - x = torch.randn(3, 4) - y = torch.randn(3, 4) - result = fn(x, y) - expected = x + y - - self.assertTrue(torch.allclose(result, expected)) - # Should compile successfully with fullgraph=True - self.assertEqual(cnt.frame_count, 1) - - def test_pytree_tree_is_leaf_with_namedtuple(self): - # Test that torch.utils._pytree.tree_is_leaf handles namedtuples correctly - from collections import namedtuple - - from torch.utils._pytree import tree_is_leaf - - Point = namedtuple("Point", ["x", "y"]) - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt, fullgraph=True) - def fn(a, b): - # Namedtuples are not leaves (they're in SUPPORTED_NODES) - point = Point(a, b) - is_leaf_namedtuple = tree_is_leaf(point) - assert is_leaf_namedtuple is False - - # But individual tensors are leaves - is_leaf_tensor = tree_is_leaf(a) - assert is_leaf_tensor is True - - return a + b - - x = torch.randn(3, 4) - y = torch.randn(3, 4) - result = fn(x, y) - expected = x + y - - self.assertTrue(torch.allclose(result, expected)) - # Should compile successfully with fullgraph=True - self.assertEqual(cnt.frame_count, 1) - instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 083c8b1f93807..36093b042002e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -65,7 +65,6 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, - PyTreeTreeIsLeafFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -381,7 +380,6 @@ "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, - "torch.utils._pytree.tree_is_leaf": PyTreeTreeIsLeafFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 439ce274b7ce6..ac0be3e5888be 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -65,7 +65,6 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, - PyTreeTreeIsLeafFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 7916187193bae..459b8e0bf6230 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -65,7 +65,6 @@ DefaultsSource, GetItemSource, SkipGuardSource, - TorchSource, TypeSource, ) from ..utils import ( @@ -119,13 +118,6 @@ _spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() -@functools.lru_cache -def get_pytree_SUPPORTED_NODES_source(): - return AttrSource( - AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" - ) - - class FunctionSpec: def __init__(self, func: FunctionType): code = func.__code__ @@ -2764,60 +2756,3 @@ def call_function( if is_namedtuple_class(python_type): return VariableTracker.build(tx, namedtuple) return VariableTracker.build(tx, python_type, source=type_source) - - -class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): - """ - `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. - - def tree_is_leaf( - tree: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - ) -> bool: - if is_leaf is not None and is_leaf(tree): - return True - return _get_node_type(tree) not in SUPPORTED_NODES - - When is_leaf is None (the common case), we can optimize by not tracing into the function. - When is_leaf is not None, we fall back to regular tracing since it requires executing user code. - """ - - def call_function( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - # tree_is_leaf(tree, is_leaf=None) - if len(args) < 1 or len(args) > 2: - raise_type_error_exc( - tx, - f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", - ) - - # Check if is_leaf parameter is provided - is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) - if len(args) == 2: - is_leaf = args[1] - - if not ( - isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None - ): - return super().call_function(tx, args, kwargs) - - # Optimize the case where is_leaf is None - # return _get_node_type(tree) not in SUPPORTED_NODES - tree = args[0] - node_type_var = PyTreeGetNodeTypeFunctionVariable( - torch.utils._pytree._get_node_type - ).call_function(tx, [tree], {}) - - # If the SUPPORTED_NODES was seen earlier and mutated, there would be a - # source and that will give us the mutated SUPPORTED_NODES. - supported_nodes_var = VariableTracker.build( - tx, - torch.utils._pytree.SUPPORTED_NODES, - source=get_pytree_SUPPORTED_NODES_source(), - ) - out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) - return ConstantVariable.create(not out.value) From 9396e69194e8e16801b08b1326e34708a859fa5f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 18:14:35 +0000 Subject: [PATCH 110/230] Revert "[dynamo][compile time] Special case for torch.utils._pytree._get_node_type (#168054)" This reverts commit 0d7ba9714ac77b2b4a446a9eff913a6ff9dfc782. Reverted https://github.com/pytorch/pytorch/pull/168054 on behalf of https://github.com/anijain2305 due to failed at D87489130 ([comment](https://github.com/pytorch/pytorch/pull/168054#issuecomment-3559399563)) --- test/dynamo/test_repros.py | 59 ---------------------------- torch/_dynamo/trace_rules.py | 2 - torch/_dynamo/variables/__init__.py | 1 - torch/_dynamo/variables/functions.py | 39 ------------------ 4 files changed, 101 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index aab7d5268fcdc..10342f56d55d1 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8184,65 +8184,6 @@ def fn(x): self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1) - def test_pytree_get_node_type_not_traced(self): - # Test that torch.utils._pytree._get_node_type is not traced into - # and doesn't cause excessive trace time overhead - from torch.utils._pytree import _get_node_type - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt, fullgraph=True) - def fn(x, y): - # Call _get_node_type which is used internally by pytree operations - node_type = _get_node_type([x, y]) - assert node_type is list - # Do some work with pytree structures - data = {"a": x, "b": y} - flat, spec = pytree.tree_flatten(data) - result = flat[0] + flat[1] - return result - - x = torch.randn(3, 4) - y = torch.randn(3, 4) - result = fn(x, y) - expected = x + y - - self.assertTrue(torch.allclose(result, expected)) - # Should compile successfully with fullgraph=True - self.assertEqual(cnt.frame_count, 1) - - def test_pytree_get_node_type_with_namedtuple(self): - # Test that torch.utils._pytree._get_node_type handles namedtuples correctly - # without being traced into, even when is_namedtuple_class is True - from collections import namedtuple - - from torch.utils._pytree import _get_node_type - - Point = namedtuple("Point", ["x", "y"]) - - cnt = torch._dynamo.testing.CompileCounter() - - @torch.compile(backend=cnt, fullgraph=True) - def fn(a, b): - # Create a namedtuple - point = Point(a, b) - # Call _get_node_type with a namedtuple instance - node_type = _get_node_type(point) - assert node_type is namedtuple - # Use pytree operations with namedtuples - flat, spec = pytree.tree_flatten(point) - result = flat[0] + flat[1] - return result - - x = torch.randn(3, 4) - y = torch.randn(3, 4) - result = fn(x, y) - expected = x + y - - self.assertTrue(torch.allclose(result, expected)) - # Should compile successfully with fullgraph=True - self.assertEqual(cnt.frame_count, 1) - instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 36093b042002e..97a3946b48bde 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -64,7 +64,6 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, - PyTreeGetNodeTypeFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -379,7 +378,6 @@ f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, - "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index ac0be3e5888be..74165b30bb2f0 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -64,7 +64,6 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, - PyTreeGetNodeTypeFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 459b8e0bf6230..e30eeeb2c2fde 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -29,7 +29,6 @@ import sys import traceback import types -from collections import namedtuple from collections.abc import Callable, Sequence from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar @@ -39,7 +38,6 @@ import torch from torch._dynamo.exc import get_stack_above_dynamo from torch._guards import Source -from torch.utils._pytree import is_namedtuple_class from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -65,7 +63,6 @@ DefaultsSource, GetItemSource, SkipGuardSource, - TypeSource, ) from ..utils import ( check_constant_args, @@ -2720,39 +2717,3 @@ def call_function( tensor=tensor, # type: ignore[arg-type] block_shape=block_shape, # type: ignore[arg-type] ) - - -class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): - """ - `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. - - def _get_node_type(tree: Any) -> Any: - node_type = type(tree) - # All namedtuple types are implicitly registered as pytree nodes. - # XXX: Other parts of the codebase expect namedtuple types always return - # `namedtuple` instead of the actual namedtuple type. Even if the type - # is explicitly registered. - if is_namedtuple_class(node_type): - return namedtuple - return node_type - """ - - def call_function( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if len(args) != 1: - raise_type_error_exc( - tx, - f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", - ) - type_source = None - if args[0].source: - install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) - type_source = TypeSource(args[0].source) - python_type = args[0].python_type() - if is_namedtuple_class(python_type): - return VariableTracker.build(tx, namedtuple) - return VariableTracker.build(tx, python_type, source=type_source) From 7bbbbcaeaa8781bbe1a68401cefcf704a5554111 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Thu, 20 Nov 2025 19:31:30 +0000 Subject: [PATCH 111/230] Fix debug assertion in autograd_not_implemented_fallback.cpp (#168280) Summary: The use_count may be two if the Tensor has a corresponding PyObject. Differential Revision: D87560032 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168280 Approved by: https://github.com/albanD --- .../autograd_not_implemented_fallback.cpp | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index c16cbb2331f07..a4a9afec1a7cc 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -51,6 +51,20 @@ void _foreach_tensor( } } +[[maybe_unused]] +size_t expected_fresh_use_count(const at::Tensor& self) { + if (!self.defined()) { + // An UndefinedTensorImpl always has a use count of 0 + return 0; + } + if (self.unsafeGetTensorImpl()->pyobj_slot()->load_pyobj() != nullptr) { + // A TensorImpl with a Python object has a use count of 2 + return 2; + } + // A fresh TensorImpl (with no PyObject) has a use count of 1 + return 1; +} + AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn; } // namespace @@ -420,8 +434,7 @@ static void autogradNotImplementedFallbackImpl( op_name == "aten::_test_optional_floatlist") return; if (!is_inplace_output[idx_ret]) - TORCH_INTERNAL_ASSERT( - t.use_count() <= 1, op_name); // Okay to return undefined tensor + TORCH_INTERNAL_ASSERT(t.use_count() == expected_fresh_use_count(t)); // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors and // each Tensor shares a storage of a hidden, intermediate 1D Tensor // created inside the CUDA implementation. This is because the From 2eccaf9ca5c627fadcf9d9f355fda21ae977de01 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 18 Nov 2025 06:08:24 -0800 Subject: [PATCH 112/230] [submodule][inductor]Fix an AMD CPU max-autotune breakage (#168079) Summary: Fix https://github.com/pytorch/pytorch/issues/138718. Bump up the cpuinfo commit to pick up https://github.com/pytorch/cpuinfo/pull/338 which fixed an AMD CPU cache size not recognized issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168079 Approved by: https://github.com/eellison --- test/inductor/test_aot_inductor.py | 8 -------- third_party/cpuinfo | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 56700bdac835f..fd962c8bea70a 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -67,12 +67,10 @@ IS_MACOS, IS_WINDOWS, MACOS_VERSION, - MI300_ARCH, parametrize, runOnRocm, skipIfMPS, skipIfRocm, - skipIfRocmArch, skipIfWindows, skipIfWindowsXPU, skipIfXpu, @@ -175,11 +173,8 @@ def get_module_ext_type(): class AOTInductorTestsTemplate: - # Temporarily skipping test as pytorch/cpuinfo not able to retrieve cache size for - # AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM Runners @common_utils.parametrize("embed_kernel_binary", [False, True]) @common_utils.parametrize("max_autotune", [False, True]) - @skipIfRocmArch(MI300_ARCH) def test_simple(self, embed_kernel_binary, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -5212,10 +5207,7 @@ def forward(self, values, offsets): ) self.assertTrue(same(model(*example_input), actual)) - # Temporarily skipping test as pytorch/cpuinfo not able to retrieve cache size for - # AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM Runners @common_utils.parametrize("max_autotune", [True, False]) - @skipIfRocmArch(MI300_ARCH) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 5e3d2445e6a84..f858c30bcb16f 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 5e3d2445e6a84d9599bee2bf78edbb4d80865e1d +Subproject commit f858c30bcb16f8effd5ff46996f0514539e17abc From 4887c46900e475f9b3623e02a94816dcbb4e43b2 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Thu, 20 Nov 2025 20:12:51 +0000 Subject: [PATCH 113/230] [ROCm] Fix HIP document url. (#168220) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168220 Approved by: https://github.com/jeffdaily --- c10/cuda/CUDAMiscFunctions.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index b305008d44f8c..49bad41dda866 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -17,8 +17,13 @@ std::string get_cuda_error_help(cudaError_t error) noexcept { default: help_text.append("\nSearch for `") .append(cudaGetErrorName(error)) +#if defined(USE_ROCM) + .append( + "' in https://rocm.docs.amd.com/projects/HIP/en/latest/index.html for more information."); +#else .append( "' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information."); +#endif break; } return help_text; From 02df234aa3ad58516211750ed6c776e67278c454 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Thu, 20 Nov 2025 21:03:20 +0000 Subject: [PATCH 114/230] [varlen attn] batch invariance testing (#167865) Tests batch invariance for varlen attention - permutations of samples in the batch should not affect fwd output and backward gradients Pull Request resolved: https://github.com/pytorch/pytorch/pull/167865 Approved by: https://github.com/drisspg --- test/test_varlen_attention.py | 153 ++++++++++++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 9 deletions(-) diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index 1a8371eee6345..3b3186a157895 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -110,6 +110,16 @@ def forward_sdpa( return self.out_proj(attn_out) +def pack_sequences(seqs, device): + x_packed = torch.cat(seqs, dim=0) + seq_lens = torch.tensor([len(s) for s in seqs], device=device) + cu_seq = torch.zeros(len(seqs) + 1, device=device, dtype=torch.int32) + cu_seq[1:] = seq_lens.cumsum(0) + max_len = seq_lens.max().item() + + return x_packed, cu_seq, max_len + + def create_variable_length_batch( shape: VarlenShape, device: torch.device, dtype: torch.dtype ): @@ -119,16 +129,15 @@ def create_variable_length_batch( seq_lengths.append(min(length, shape.max_seq_len)) seq_lengths = torch.tensor(seq_lengths, device=device) - total_tokens = seq_lengths.sum().item() - - x_packed = torch.randn( - total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True - ) - cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32) - cu_seq[1:] = seq_lengths.cumsum(0) + sequences = [ + torch.randn( + seq_len, shape.embed_dim, device=device, dtype=dtype, requires_grad=True + ) + for seq_len in seq_lengths + ] - max_len = seq_lengths.max().item() + x_packed, cu_seq, max_len = pack_sequences(sequences, device) x_padded = torch.zeros( shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype ) @@ -146,7 +155,6 @@ def create_variable_length_batch( "x_packed": x_packed, "x_padded": x_padded, "max_len": max_len, - "total_tokens": total_tokens, } @@ -428,6 +436,133 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal): start_idx = end_idx + @skipIfRocm(msg="ROCM does not support variable length attention") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + @parametrize("is_causal", [False, True]) + @parametrize("num_perms", [1, 3, 5]) + def test_batch_invariance(self, device, dtype, is_causal, num_perms): + torch.manual_seed(42) + + batch_size, max_seq_len = 4, 128 + + seq_lengths = [] + for _ in range(batch_size): + length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64 + seq_lengths.append(min(length, max_seq_len)) + + sequences_qkv = [ + [ + torch.testing.make_tensor( + (seq_len, 2, 128), device=device, dtype=dtype, requires_grad=True + ) + for _ in range(3) + ] + for seq_len in seq_lengths + ] + sequences_q, sequences_k, sequences_v = map(list, zip(*sequences_qkv)) + + q_packed_orig = torch.cat(sequences_q, dim=0) + k_packed_orig = torch.cat(sequences_k, dim=0) + v_packed_orig = torch.cat(sequences_v, dim=0) + + seq_lens = torch.tensor(seq_lengths, device=device) + cu_seq_orig = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seq_orig[1:] = seq_lens.cumsum(0) + + original_output = varlen_attn( + q_packed_orig, + k_packed_orig, + v_packed_orig, + cu_seq_orig, + cu_seq_orig, + max_seq_len, + max_seq_len, + is_causal, + ) + + original_grad_out = torch.randn_like(original_output) + original_grads = torch.autograd.grad( + outputs=original_output, + inputs=[q_packed_orig, k_packed_orig, v_packed_orig], + grad_outputs=original_grad_out, + ) + + for _ in range(num_perms): + perm = torch.randperm(batch_size) + permuted_sequences_q = [sequences_q[perm[i]] for i in range(batch_size)] + permuted_sequences_k = [sequences_k[perm[i]] for i in range(batch_size)] + permuted_sequences_v = [sequences_v[perm[i]] for i in range(batch_size)] + + q_packed_perm = torch.cat(permuted_sequences_q, dim=0) + k_packed_perm = torch.cat(permuted_sequences_k, dim=0) + v_packed_perm = torch.cat(permuted_sequences_v, dim=0) + + permuted_seq_lens = torch.tensor( + [seq_lengths[perm[i]] for i in range(batch_size)], device=device + ) + cu_seq_perm = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seq_perm[1:] = permuted_seq_lens.cumsum(0) + + permuted_output = varlen_attn( + q_packed_perm, + k_packed_perm, + v_packed_perm, + cu_seq_perm, + cu_seq_perm, + max_seq_len, + max_seq_len, + is_causal, + ) + + for i in range(batch_size): + orig_idx = perm[i].item() + + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + orig_seq_output = original_output[orig_start:orig_end] + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + perm_seq_output = permuted_output[perm_start:perm_end] + + self.assertEqual(orig_seq_output, perm_seq_output) + + permuted_grad_out = torch.zeros_like(permuted_output) + for i in range(batch_size): + orig_idx = perm[i].item() + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + + permuted_grad_out[perm_start:perm_end] = original_grad_out[ + orig_start:orig_end + ] + + permuted_grads = torch.autograd.grad( + outputs=permuted_output, + inputs=[q_packed_perm, k_packed_perm, v_packed_perm], + grad_outputs=permuted_grad_out, + ) + + for original_grad, permuted_grad in zip(original_grads, permuted_grads): + for i in range(batch_size): + orig_idx = perm[i].item() + + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + orig_seq_grad = original_grad[orig_start:orig_end] + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + perm_seq_grad = permuted_grad[perm_start:perm_end] + + self.assertEqual(orig_seq_grad, perm_seq_grad) + device_types = ("cuda",) From 6644fd7ace91cdaf2528aef52e5569b3ab257094 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Wed, 19 Nov 2025 14:38:49 -0800 Subject: [PATCH 115/230] [inductor] make mix order reduction work with dynamic shapes (#168117) Internal models have dynamic shapes for rms/layer norm. Mix order reduction would be by-passed without this PR. co-author this PR with Paul Zhang. Differential Revision: [D87378475](https://our.internmc.facebook.com/intern/diff/D87378475) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168117 Approved by: https://github.com/v0i0, https://github.com/jansel --- test/inductor/test_mix_order_reduction.py | 31 +++++++++++++++++++++++ torch/_inductor/choices.py | 8 +++--- torch/_inductor/codegen/simd.py | 21 +++++++-------- torch/_inductor/scheduler.py | 12 ++++++--- 4 files changed, 56 insertions(+), 16 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 592e42ce41735..cae48673f2332 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -382,6 +382,37 @@ def fwd_bwd(f): metrics.codegen_mix_order_reduction, ) + def test_layer_norm_bwd_with_dynamic_shape(self): + def f(x, w, eps): + return F.layer_norm(x, x.shape[-1:], weight=w, bias=None, eps=eps) + + def fwd_bwd(f): + x.grad = None + w.grad = None + out = f(x, w, eps) + out.backward(dy) + return x.grad, w.grad + + M0, M1, N = 251, 223, 128 + wbdtype = torch.float + xdtype = torch.float + x = torch.randn(M0, M1, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True) + torch._dynamo.mark_dynamic(x, 0) + w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + eps = 1e-5 + + opt_f = torch.compile(f) + + ref = fwd_bwd(f) + act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) + + self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + @parametrize("split_reductions", (False, True)) @parametrize("shape", ((32768, 768), (32769, 768))) def test_layer_norm_bwd_no_bias(self, split_reductions, shape): diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 47542cb6aef77..a5379219a2373 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -561,9 +561,11 @@ def can_fuse_horizontal( shared_data_score: int, ) -> bool: """Hook for heuristics to prevent horizontal (consumer/consumer) fusions""" - if ( - shared_data_score < config.score_fusion_memory_threshold - ) and not MixOrderReduction.can_fuse(node1, node2): + if MixOrderReduction.can_fuse(node1, node2): + # For mix order reduction, we disregard shared data or + # distance. + return True + if shared_data_score < config.score_fusion_memory_threshold: WhyNoFuse(node1, node2)("score_fusion_memory_threshold") return False if scheduler.are_long_distant_nodes(node1, node2): diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 65e8f88b1c425..1706f53cb2927 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1610,10 +1610,7 @@ def benchmark_codegened_module( def _codegen_mix_order_reduction(self, node1, node2): numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1) - if not V.graph.sizevars.statically_known_gt( - numel, - rnumel, - ): + if not V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)): return self._codegen_mix_order_reduction(node2, node1) def _pick_split_size(): @@ -1625,7 +1622,10 @@ def _pick_split_size(): device_prop = DeviceProperties.create(node1.get_device()) num_sm = device_prop.multi_processor_count estimated_num_splits = num_sm * 8 - split_size = max(next_power_of_2(numel // estimated_num_splits), 16) + + # split_size is decided based on hint + numel_hint = V.graph.sizevars.size_hint(numel) + split_size = max(next_power_of_2(numel_hint // estimated_num_splits), 16) split_size = min(split_size, 128) return split_size @@ -1634,10 +1634,7 @@ def _pick_split_size(): # pyrefly: ignore [bad-assignment] metrics.codegen_mix_order_reduction += 1 - assert V.graph.sizevars.statically_known_gt( - numel, - rnumel, - ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)) # split epilogue out of node2 node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue( @@ -1726,6 +1723,8 @@ def _bench(candidate_split_size): if node.get_outputs()[0].node.get_name() not in rename: node.mark_run() + V.graph.wrapper_code.make_comment("# Call mix order reduction kernel") + self.codegen_comment(node_schedule, None) # workspace args is still needed after the call kernel.call_kernel(kernel.kernel_name, deallocate_ws=False) V.graph.removed_buffers |= kernel.removed_buffers @@ -1733,7 +1732,9 @@ def _bench(candidate_split_size): # a extra round of reduction assert len(converted_nodes) == len(kernel.saved_partial_accumulate) - nsplit = (numel + split_size - 1) // split_size + nsplit = V.graph.wrapper_code.codegen_python_sizevar( + (numel + split_size - 1) // split_size + ) for idx, partial_accum in enumerate(kernel.saved_partial_accumulate): buffer_name = partial_accum.buffer_name diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 45cf9e409b656..e5bd34ea977e7 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -253,19 +253,23 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: # small workload. When a workload is small enough, data can be # fully cached by L2 size_thres = 5 * 2**20 - if not V.graph.sizevars.statically_known_geq(nrow * ncol, size_thres): + + # Call evaluate_expr rather than statically_known_geq since nrow can + # have dynamic shape in real models. + # Don't use hint directly since hint can be non-representative. + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow * ncol, size_thres)): return False # We require more more row than columns since # 1, we prefer doing persistent reduction for each row # 2, we will split the reduction across the rows - if not V.graph.sizevars.statically_known_geq(nrow, ncol * 2): + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, ncol * 2)): return False # When nrow is small, ncol should also be small (due to the check # above). Thus the entire tensor should be well cached in L2. # Mix order reduction is less beneficial. - if not V.graph.sizevars.statically_known_geq(nrow, 4096): + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, 4096)): return False contiguous_node, other_node = ( @@ -301,6 +305,8 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: return False # rnumel so large that we will not generated persistent reduction + # We don't see real use cases with dynamic ncol. But if we do, + # we should call evaluete_expr here which adds guards. if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): return False From f7fc6346b0d06ecd77b3e905328c57766b4474ea Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 22:04:03 +0000 Subject: [PATCH 116/230] Revert "Allow BlockDescriptorOptions classes to be overridden In TritonKernel (#165899)" This reverts commit 13cda9b89e2f4f6a420ec048260cec61ff4649bf. Reverted https://github.com/pytorch/pytorch/pull/165899 on behalf of https://github.com/jansel due to See #167892 ([comment](https://github.com/pytorch/pytorch/pull/165899#issuecomment-3560236176)) --- torch/_inductor/codegen/triton.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4ac481478196a..9b718f0c780c1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2262,10 +2262,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): kexpr: Callable[[sympy.Expr], str] = texpr allow_block_ptr = True tma_compatibility_checker_cls = TMACompatibilityChecker - block_ptr_options_cls: type[BlockPtrOptions] = BlockPtrOptions - tensor_descriptor_options_cls: type[TensorDescriptorOptions] = ( - TensorDescriptorOptions - ) def __init__( self, @@ -2731,9 +2727,9 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: self.filter_masks(mask_vars) options_class = ( - self.block_ptr_options_cls + BlockPtrOptions if config.triton.use_block_ptr - else self.tensor_descriptor_options_cls + else TensorDescriptorOptions ) nonlocal tma_compatibility_checker if config.triton.use_block_ptr: @@ -2757,7 +2753,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: can_lift=can_lift, transpose_contiguous=transpose_contiguous, ) - if isinstance(options_class, TensorDescriptorOptions): + if options_class == TensorDescriptorOptions: tma_compatibility_checker = cast( TMACompatibilityChecker, tma_compatibility_checker ) From b4f5472307a289fcd2fb1237677a31ce79470e4a Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Thu, 20 Nov 2025 22:04:30 +0000 Subject: [PATCH 117/230] [BE][Inductor] Move mm templates into separate files (#168179) Summary: `mm.py` currently embeds multiple Jinja templates inline, making the file harder to read and maintain. This change switches to `load_kernel_template()`, placing each template in its own file and restoring proper Jinja syntax highlighting. To add a new template named, for example, `new_mm`, place the jinja code in `_inductor/kernel/templates/new_mm.py.jinja`, then just call `load_template("new_mm")`. Test Plan: CI Differential Revision: D87461102 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168179 Approved by: https://github.com/njriasan --- torch/_inductor/kernel/mm.py | 826 +----------------- ...kwell_ws_persistent_device_tma_mm.py.jinja | 107 +++ .../triton_epilogue_scaled_mm.py.jinja | 194 ++++ .../triton_main_loop_scaled_mm.py.jinja | 212 +++++ .../kernel/templates/triton_mm.py.jinja | 72 ++ .../kernel/templates/triton_mm_rocm.py.jinja | 71 ++ .../triton_persistent_tma_mm.py.jinja | 129 +++ 7 files changed, 799 insertions(+), 812 deletions(-) create mode 100644 torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja create mode 100644 torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja create mode 100644 torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja create mode 100644 torch/_inductor/kernel/templates/triton_mm.py.jinja create mode 100644 torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja create mode 100644 torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 986ceb4405a14..5b57c458f46e6 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -55,6 +55,7 @@ ) from .mm_common import ( _is_static_problem, + load_kernel_template, mm_args, mm_grid, persistent_mm_grid, @@ -75,162 +76,18 @@ aten = torch.ops.aten prims = torch.ops.prims +# We define each template kernel in a separate file which is the name of the input to load_kernel_template +# (e.g. triton_mm for templates/triton_mm.py.jinja). +# If you are adding a new template, please follow that pattern and add a new file with your implementation in the templates folder. mm_template = TritonTemplate( name="mm", grid=mm_grid, - source=( - r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): - offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - else: - offs_a_m = rm % M - if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): - offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - else: - offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for k_idx in range(0, tl.cdiv(K, BLOCK_K)): - {% if not EVEN_K %} - a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) - b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) - {% endif %} - a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) - b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) - - idx_m = offs_a_m[:, None] - idx_n = a_k_idx_vals - {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", - indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} - - idx_m = b_k_idx_vals - idx_n = offs_b_n[None, :] - {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", - indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} - - {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% endif %} - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} -""" - if (torch.version.hip is None) or triton_version >= "3.3.0" - # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 - # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. - # See more details in https://github.com/pytorch/pytorch/pull/146293 - else r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): - offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - else: - offs_a_m = rm % M - if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): - offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - else: - offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for k_idx in range(0, tl.cdiv(K, BLOCK_K)): - {% if not EVEN_K %} - a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) - b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) - {% endif %} - a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) - b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) - - idx_m = offs_a_m[:, None] - idx_n = a_k_idx_vals - {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", - indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} - - idx_m = b_k_idx_vals - idx_n = offs_b_n[None, :] - {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", - indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} - {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% endif %} - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} -""" - ), + source=load_kernel_template("triton_mm") + if (torch.version.hip is None) or triton_version >= "3.3.0" + # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 + # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. + # See more details in https://github.com/pytorch/pytorch/pull/146293 + else load_kernel_template("triton_mm_rocm"), cache_codegen_enabled_for_template=True, prologue_loads_all_inputs=True, ) @@ -238,682 +95,27 @@ persistent_tma_mm_template = TritonTemplate( name="mm_persistent_tma", grid=persistent_mm_grid, - source=r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - start_pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = grid_m * grid_n - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - width = GROUP_M * grid_n - rk_for_mask = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - {%- if TMA_EXPERIMENTAL_API %} - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - global_size=[M, K] if A_ROW_MAJOR else [K, M], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - global_size=[K, N] if B_ROW_MAJOR else [N, K], - element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - {%- else %} - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K] if A_ROW_MAJOR else [K, M], - strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], - block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[K, N] if B_ROW_MAJOR else [N, K], - strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], - block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - ) - {%- endif %} - - pid_m = 0 - pid_n = 0 - rm = 0 - rn = 0 - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - # re-order program ID for better L2 performance - group_id = tile_id // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // (group_size) - - rm = pid_m * BLOCK_M - rn = pid_n * BLOCK_N - - rk = ki * BLOCK_K - - {%- if TMA_EXPERIMENTAL_API %} - a = tl._experimental_descriptor_load( - a_desc_ptr, - [rm, rk] if A_ROW_MAJOR else [rk, rm], - [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - A.dtype.element_ty, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [rk, rn] if B_ROW_MAJOR else [rn, rk], - [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - B.dtype.element_ty, - ) - {%- else %} - a = tl.load_tensor_descriptor( - a_desc, - [rm, rk] if A_ROW_MAJOR else [rk, rm], - ) - b = tl.load_tensor_descriptor( - b_desc, - [rk, rn] if B_ROW_MAJOR else [rn, rk], - ) - {%- endif %} - acc += tl.dot( - a if A_ROW_MAJOR else a.T, - b if B_ROW_MAJOR else b.T, - allow_tf32=ALLOW_TF32, - ) - - if ki == k_tiles - 1: - # inductor generates a suffix - {%- if TMA_EXPERIMENTAL_API %} - # rematerialize rm and rn to save registers - rcm = rm + tl.arange(0, BLOCK_M) - rcn = rn + tl.arange(0, BLOCK_N) - idx_m = rcm[:, None] - idx_n = rcn[None, :] - mask = (idx_m < M) & (idx_n < N) - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} - {%- else %} - {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} - {%- endif %} - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - -""", + source=load_kernel_template("triton_persistent_tma_mm"), ) -load_scales = r""" -@triton.jit -def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): - if SCALE_RECIPE == 0: - return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values - else: - return scale_ptr # For all other scaling recipes, we'll return the pointers -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALE_RECIPE_A: tl.constexpr, - SCALE_RECIPE_B: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: # (ScalingType.TensorWise, ScalingType.TensorWise) - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -scaled_mm_device_tma_epilogue_scaling = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALE_RECIPE_A == 1: # ScalingType.RowWise - stride_a_scale_m = 1 - else: - stride_a_scale_m = 0 - - if SCALE_RECIPE_B == 1: # ScalingType.RowWise - stride_b_scale_n = 1 - else: - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - {%- if TMA_EXPERIMENTAL_API %} - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - {%- else %} - stride_am = {{stride("A", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K], - strides=[stride_am, 1], - block_shape=[BLOCK_M, BLOCK_K], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[N, K], - strides=[stride_bn, 1], - block_shape=[BLOCK_N, BLOCK_K], - ) - {%- endif %} - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) - b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - {%- if TMA_EXPERIMENTAL_API %} - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - {%- else %} - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) - {%- endif %} - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALE_RECIPE_A, - SCALE_RECIPE_B, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - # inductor generates a suffix - {%- if TMA_EXPERIMENTAL_API %} - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} - {%- else %} - {{store_output( - ("offs_am", "offs_bn"), - "accumulator", - indent_width=12, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True, - )}} - {%- endif %} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( name="scaled_mm_device_tma_epilogue_scaling", grid=persistent_mm_grid, - source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling, + source=load_kernel_template("triton_epilogue_scaled_mm"), ) -blockwise1xTILESIZE_scaling = r""" -@triton.jit -def blockwise1xTILESIZE_scaling( - pid, - scale, - ki, - lhs_size, - lhs_blocks, - k_blocks, - BLOCK_lhs: tl.constexpr, - BLOCK_K: tl.constexpr, - MIN_BLOCK_TILE_K: tl.constexpr, - TILE_SIZE: tl.constexpr, -): - row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) - col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) - ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] - mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) - scale_block = tl.load(ptrs, mask=mask, other=1.0) - - scale_expanded = scale_block[:, :, None] - scale_expanded = tl.broadcast_to( - scale_expanded, - (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) - ) - scale_expanded = scale_expanded.reshape( - BLOCK_lhs, - ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K - ) - - return scale_expanded -""" - -blockwise128x128_scaling = r""" -@triton.jit -def blockwise128x128_scaling( - pid, - scale, - ki, - lhs_blocks, - k_blocks, - BLOCK_lhs: tl.constexpr, - BLOCK_K: tl.constexpr, - MIN_BLOCK_TILE_lhs: tl.constexpr, - MIN_BLOCK_TILE_K: tl.constexpr, -): - row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) - col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) - ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] - mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks) - scale_block = tl.load(ptrs, mask=mask, other=1.0) - - scale_expanded = scale_block[:, :, None, None] - scale_expanded = tl.broadcast_to( - scale_expanded, - ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) - ) - scale_expanded = scale_expanded.reshape( - ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, - ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K - ) - - return scale_expanded -""" - -scaled_mm_device_tma_main_loop_scaling = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_bn = {{stride("B", 1)}} - - start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K], - strides=[stride_am, 1], - block_shape=[BLOCK_M, BLOCK_K], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[N, K], - strides=[stride_bn, 1], - block_shape=[BLOCK_N, BLOCK_K], - ) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) - b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) - - am_blocks = tl.cdiv(M, TILE_SIZE_A) - ak_blocks = tl.cdiv(K, TILE_SIZE_A) - bn_blocks = tl.cdiv(N, TILE_SIZE_B) - bk_blocks = tl.cdiv(K, TILE_SIZE_B) - - {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 - scale_a_block = blockwise128x128_scaling( - pid_m, - a_scale, - ki, - am_blocks, - ak_blocks, - BLOCK_M, - BLOCK_K, - MIN_BLOCK_TILE_AM, - MIN_BLOCK_TILE_AK, - ) - {%- else %} # ScalingType.Blockwise1xTILESIZE - scale_a_block = blockwise1xTILESIZE_scaling( - pid_m, - a_scale, - ki, - M, - am_blocks, - ak_blocks, - BLOCK_M, - BLOCK_K, - MIN_BLOCK_TILE_AK, - TILE_SIZE_A, - ) - {%- endif %} - - {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 - scale_b_block = blockwise128x128_scaling( - pid_n, - b_scale, - ki, - bn_blocks, - bk_blocks, - BLOCK_N, - BLOCK_K, - MIN_BLOCK_TILE_BN, - MIN_BLOCK_TILE_BK, - ) - {%- else %} # ScalingType.Blockwise1xTILESIZE - scale_b_block = blockwise1xTILESIZE_scaling( - pid_n, - b_scale, - ki, - N, - bn_blocks, - bk_blocks, - BLOCK_N, - BLOCK_K, - MIN_BLOCK_TILE_BK, - TILE_SIZE_B, - ) - {%- endif %} - - a_scaled = a * scale_a_block - b_scaled = b * scale_b_block - accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) - - if ki == k_tiles - 1: - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - - # inductor generates a suffix - {{store_output( - ("offs_am", "offs_bn"), - "accumulator", - indent_width=12, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True, - )}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate( name="scaled_mm_device_tma_main_loop_scaling", grid=persistent_mm_grid, - source=scaled_mm_device_tma_main_loop_scaling - + load_scales - + blockwise1xTILESIZE_scaling - + blockwise128x128_scaling, + source=load_kernel_template("triton_main_loop_scaled_mm"), ) -_compute_blackwell_pid = r""" -@triton.jit -def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - GROUP_M = min(grid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % GROUP_M) - pid_n = (tile_id % num_pid_in_group) // GROUP_M - return pid_m, pid_n -""" - -_blackwell_ws_persistent_device_tma = r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - start_pid = tl.program_id(0) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = grid_m * grid_n - - # Note: We require TMA_EXPERIMENTAL_API == False, which - # we will check before invoking this template. - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K] if A_ROW_MAJOR else [K, M], - strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], - block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[K, N] if B_ROW_MAJOR else [N, K], - strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], - block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - ) - - # tile_id_c is used in the epilogue to break the dependency between - # the prologue and the epilogue - tile_id_c = start_pid - NUM_SMS - num_pid_in_group = GROUP_M * grid_n - - for tile_id in tl.range( - start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE - ): - pid_m, pid_n = _compute_pid( - tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS - ) - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for ki in range(k_tiles): - offs_k = ki * BLOCK_K - a = tl.load_tensor_descriptor( - a_desc, - [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], - ) - b = tl.load_tensor_descriptor( - b_desc, - [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], - ) - accumulator += tl.dot( - a if A_ROW_MAJOR else a.T, - b if B_ROW_MAJOR else b.T, - allow_tf32=ALLOW_TF32, - ) - - tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid( - tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS - ) - offs_cm = pid_m * BLOCK_M - offs_cn = pid_n * BLOCK_N - {%- if EPILOGUE_SUBTILE %} - tl.static_assert(BLOCK_N % 2 == 0) - acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - {{store_output( - ("offs_cm", "offs_cn"), - "acc0", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N // 2"), - block_indexing=True - )}} - offs_cn2 = offs_cn + BLOCK_N // 2 - {{store_output( - ("offs_cm", "offs_cn2"), - "acc1", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N // 2"), - block_indexing=True - )}} - {%- else %} - {{store_output( - ("offs_cm", "offs_cn"), - "accumulator", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True - )}} - {%- endif %} -""" - blackwell_ws_persistent_device_tma_mm_template = TritonTemplate( name="blackwell_ws_persistent_device_tma", grid=persistent_mm_grid, - source=_blackwell_ws_persistent_device_tma + _compute_blackwell_pid, + source=load_kernel_template("triton_blackwell_ws_persistent_device_tma_mm"), ) diff --git a/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja new file mode 100644 index 0000000000000..34ff2d69793c0 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja @@ -0,0 +1,107 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + + # Note: We require TMA_EXPERIMENTAL_API == False, which + # we will check before invoking this template. + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_M * grid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_K + a = tl.load_tensor_descriptor( + a_desc, + [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], + ) + accumulator += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_M + offs_cn = pid_n * BLOCK_N + {%- if EPILOGUE_SUBTILE %} + tl.static_assert(BLOCK_N % 2 == 0) + acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + {{store_output( + ("offs_cm", "offs_cn"), + "acc0", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + offs_cn2 = offs_cn + BLOCK_N // 2 + {{store_output( + ("offs_cm", "offs_cn2"), + "acc1", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + {%- else %} + {{store_output( + ("offs_cm", "offs_cn"), + "accumulator", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True + )}} + {%- endif %} + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + GROUP_M = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % GROUP_M) + pid_n = (tile_id % num_pid_in_group) // GROUP_M + return pid_m, pid_n diff --git a/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja new file mode 100644 index 0000000000000..56ef18b7a91e3 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja @@ -0,0 +1,194 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALE_RECIPE_A == 1: # ScalingType.RowWise + stride_a_scale_m = 1 + else: + stride_a_scale_m = 0 + + if SCALE_RECIPE_B == 1: # ScalingType.RowWise + stride_b_scale_n = 1 + else: + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A, + SCALE_RECIPE_B, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + {%- endif %} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A: tl.constexpr, + SCALE_RECIPE_B: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: # (ScalingType.TensorWise, ScalingType.TensorWise) + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale diff --git a/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja new file mode 100644 index 0000000000000..171340a2c9233 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja @@ -0,0 +1,212 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + + am_blocks = tl.cdiv(M, TILE_SIZE_A) + ak_blocks = tl.cdiv(K, TILE_SIZE_A) + bn_blocks = tl.cdiv(N, TILE_SIZE_B) + bk_blocks = tl.cdiv(K, TILE_SIZE_B) + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_a_block = blockwise128x128_scaling( + pid_m, + a_scale, + ki, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AM, + MIN_BLOCK_TILE_AK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_a_block = blockwise1xTILESIZE_scaling( + pid_m, + a_scale, + ki, + M, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AK, + TILE_SIZE_A, + ) + {%- endif %} + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_b_block = blockwise128x128_scaling( + pid_n, + b_scale, + ki, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BN, + MIN_BLOCK_TILE_BK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_b_block = blockwise1xTILESIZE_scaling( + pid_n, + b_scale, + ki, + N, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BK, + TILE_SIZE_B, + ) + {%- endif %} + + a_scaled = a * scale_a_block + b_scaled = b * scale_b_block + accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) + + if ki == k_tiles - 1: + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + + # inductor generates a suffix + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def blockwise1xTILESIZE_scaling( + pid, + scale, + ki, + lhs_size, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, + TILE_SIZE: tl.constexpr, +): + row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) + col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + BLOCK_lhs, + ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K + ) + + return scale_expanded + + +@triton.jit +def blockwise128x128_scaling( + pid, + scale, + ki, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_lhs: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, +): + row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) + col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, + ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K + ) + + return scale_expanded diff --git a/torch/_inductor/kernel/templates/triton_mm.py.jinja b/torch/_inductor/kernel/templates/triton_mm.py.jinja new file mode 100644 index 0000000000000..2da348f3e767c --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_mm.py.jinja @@ -0,0 +1,72 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja new file mode 100644 index 0000000000000..42b99c70d5cbd --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja @@ -0,0 +1,71 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja new file mode 100644 index 0000000000000..38fe092c25780 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja @@ -0,0 +1,129 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} + {%- endif %} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) From da7c6095710560075f78a56ec43a6896315c3de2 Mon Sep 17 00:00:00 2001 From: Oleksandr Stashuk Date: Thu, 20 Nov 2025 22:09:10 +0000 Subject: [PATCH 118/230] [pytorch] Make clamp kernel branchless (#167889) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: In the old clamp implementation, there was a control divergence with the if statements. This PR reduces the branching for kernel. See attached bench example with source code in linked stack (for correctness + tests on tensor). Branchless implementation shows consistent performance gains across all data types and NaN ratios ======================================== ``` Test Plan: ``` nvcc -O3 -o clamp_bench_comprehensive clamp_benchmark_standalone.cu && ./clamp_bench_comprehensive 2>&1 ``` On H100 ``` ./clamp_bench_comprehensive ======================================== CUDA Clamp Kernel Comprehensive Benchmark ======================================== Elements: 16777216 (64.0 MB per tensor) Performance Results: ------------------------------------------------------------ Float32 0% NaN : Orig= 241± 10 ns Branch= 231± 0 ns Speedup= +3.9% Float32 1% NaN : Orig= 238± 5 ns Branch= 231± 1 ns Speedup= +3.1% Float32 10% NaN : Orig= 223± 0 ns Branch= 215± 1 ns Speedup= +3.2% Float16 0% NaN : Orig= 219± 3 ns Branch= 214± 1 ns Speedup= +2.2% Float16 1% NaN : Orig= 225± 2 ns Branch= 220± 0 ns Speedup= +2.2% Float16 10% NaN : Orig= 304± 89 ns Branch= 217± 4 ns Speedup=+28.5% BFloat16 0% NaN : Orig= 216± 1 ns Branch= 212± 3 ns Speedup= +1.6% BFloat16 1% NaN : Orig= 216± 1 ns Branch= 212± 0 ns Speedup= +1.9% BFloat16 10% NaN : Orig= 217± 2 ns Branch= 211± 0 ns Speedup= +2.6% ``` On B200 ``` ./clamp_bench_comprehensive ======================================== CUDA Clamp Kernel Comprehensive Benchmark ======================================== Elements: 16777216 (64.0 MB per tensor) Performance Results: ------------------------------------------------------------ Float32 0% NaN : Orig=104331± 53 ns Branch= 59445± 17 ns Speedup=+43.0% Float32 1% NaN : Orig=104520± 13 ns Branch= 59439± 13 ns Speedup=+43.1% Float32 10% NaN : Orig=104493± 17 ns Branch= 59440± 6 ns Speedup=+43.1% Float16 0% NaN : Orig= 98249± 16 ns Branch= 53278± 16 ns Speedup=+45.8% Float16 1% NaN : Orig= 98313± 9 ns Branch= 53287± 13 ns Speedup=+45.8% Float16 10% NaN : Orig= 98335± 9 ns Branch= 53287± 19 ns Speedup=+45.8% BFloat16 0% NaN : Orig= 98492± 47 ns Branch= 55312± 8 ns Speedup=+43.8% BFloat16 1% NaN : Orig= 99783± 69 ns Branch= 55321± 9 ns Speedup=+44.6% BFloat16 10% NaN : Orig=100284± 37 ns Branch= 55329± 26 ns Speedup=+44.8% ======================================== Differential Revision: D86561069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167889 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/cuda/TensorCompare.cu | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index ab38c1975d147..031e5b3c4f14e 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -44,16 +44,13 @@ void isneginf_kernel_impl(TensorIteratorBase &iter) { void clamp_kernel_impl(TensorIteratorBase& iter) { AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_cuda", [&] { gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { - // Propagate nan, which doesn't propagate automatically for ROCm - if (at::_isnan(v)) { - return v; - } if (at::_isnan(lower)) { - return lower; - } if (at::_isnan(upper)) { - return upper; - } else { - return ::min(::max(v, lower), upper); - } + scalar_t result = ::min(::max(v, lower), upper); + + result = at::_isnan(upper) ? upper : result; + result = at::_isnan(lower) ? lower : result; + result = at::_isnan(v) ? v : result; + + return result; }); }); } From e7a85200da063a6dc6df0a3554e9e3c2426d0491 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 22:09:33 +0000 Subject: [PATCH 119/230] Revert "Remove useless super() delegation (#168235)" This reverts commit 88d635c54f73393f50cb795cfa13b15ba7d7339b. Reverted https://github.com/pytorch/pytorch/pull/168235 on behalf of https://github.com/yangw-dev due to breaks distributed tests related to error_pid https://github.com/pytorch/pytorch/actions/runs/19546407616/job/55967118407#step:27:3191 ([comment](https://github.com/pytorch/pytorch/pull/168235#issuecomment-3560268701)) --- torch/_dynamo/exc.py | 15 ++- torch/_dynamo/variables/base.py | 3 + torch/_dynamo/variables/dicts.py | 14 +++ torch/_higher_order_ops/_invoke_quant.py | 3 + .../learnedheuristic_interface.py | 6 + torch/_inductor/codegen/cpp.py | 3 + torch/_inductor/codegen/cuda/gemm_template.py | 13 ++ .../ao/nn/intrinsic/qat/modules/conv_fused.py | 117 ++++++++++++++++++ .../data_sparsifier/benchmarks/dlrm_utils.py | 3 + .../quantizer/embedding_quantizer.py | 3 + .../quantizer/xpu_inductor_quantizer.py | 3 + torch/backends/__init__.py | 3 + torch/backends/cudnn/__init__.py | 3 + torch/backends/miopen/__init__.py | 3 + torch/backends/mkldnn/__init__.py | 3 + torch/backends/opt_einsum/__init__.py | 3 + .../_checkpoint/checkpoint_wrapper.py | 3 + torch/jit/_monkeytype_config.py | 3 + torch/jit/_recursive.py | 3 +- torch/multiprocessing/spawn.py | 8 ++ torch/testing/_internal/common_utils.py | 2 + 21 files changed, 212 insertions(+), 5 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 5b0e8a402dd96..f11c78bdaa49e 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -198,20 +198,24 @@ class RecompileError(TorchDynamoException): class ArgsMismatchError(Unsupported): - pass + def __init__(self, msg: str) -> None: + super().__init__(msg) class AttributeMutationError(Unsupported): - pass + def __init__(self, msg: str) -> None: + super().__init__(msg) class InfiniteGeneratorError(Unsupported): # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT - pass + def __init__(self, msg: str) -> None: + super().__init__(msg) class SideEffectsError(Unsupported): - pass + def __init__(self, msg: str) -> None: + super().__init__(msg) class CondOpArgsMismatchError(ArgsMismatchError): @@ -219,6 +223,9 @@ class CondOpArgsMismatchError(ArgsMismatchError): Internal error from cond() due to arguments mismatch. """ + def __init__(self, msg: str) -> None: + super().__init__(msg) + class UserErrorType(Enum): DYNAMIC_CONTROL_FLOW = auto() diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 2d11a27bafac0..4e248320e60b6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -151,6 +151,9 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. """ + def __init__(self, typ: SourceType) -> None: + super().__init__(typ) + class AttributeMutationExisting(AttributeMutation): """ diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 636875d85e54a..24cd5007da37d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1296,6 +1296,13 @@ def install_dict_contains_guard( class FrozensetVariable(SetVariable): + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(items, **kwargs) + def debug_repr(self) -> str: if not self.items: return "frozenset()" @@ -1353,6 +1360,13 @@ def call_method( class DictKeySetVariable(SetVariable): + def __init__( + self, + items: list[VariableTracker], + **kwargs: Any, + ) -> None: + super().__init__(items, **kwargs) + def debug_repr(self) -> str: if not self.items: return "dict_keys([])" diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py index b7a9fb94b93e2..1fc1e1114a036 100644 --- a/torch/_higher_order_ops/_invoke_quant.py +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -26,6 +26,9 @@ class InvokeQuantUnpacked(BaseHOP): def __init__(self) -> None: super().__init__("invoke_quant") + def __call__(self, subgraph, *operands, scheme=None): + return super().__call__(subgraph, *operands, scheme=scheme) + invoke_quant = InvokeQuantUnpacked() diff --git a/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/torch/_inductor/autoheuristic/learnedheuristic_interface.py index 84a941b076c31..cb2568d8a6801 100644 --- a/torch/_inductor/autoheuristic/learnedheuristic_interface.py +++ b/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -39,6 +39,9 @@ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: class LearnedHeuristicRegression(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + def get_feedback(self, context: AHContext, choice: Choice) -> float: return 1.0 @@ -61,6 +64,9 @@ def get_decision( class LearnedHeuristicDecision(LearnedHeuristic): + def __init__(self) -> None: + super().__init__() + def get_choice(self, idx: int) -> Optional[str]: return None diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 18b209de94cb3..88f203421cc1c 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3786,6 +3786,9 @@ class TilingSelect: In the future, we can implement advanced heuristic in a subclass. """ + def __init__(self): + super().__init__() + def select_tiling( self, fn_list, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index c4b7188bd9e62..22d0981febecd 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1330,6 +1330,19 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): including those which allow flexible fusions with epilogues. """ + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + ): + super().__init__( + input_nodes, layout, alpha, beta, input_reorder, use_fast_accum + ) + @staticmethod def add_cutlass_gemm_choices( choices: list[ChoiceCaller], diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 1e49a274e129c..0054e996e33ce 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -112,6 +112,9 @@ def reset_bn_parameters(self): bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) + def reset_parameters(self): + super().reset_parameters() + def update_bn_stats(self): self.freeze_bn = False self.bn.training = True @@ -531,6 +534,44 @@ class ConvBnReLU1d(ConvBn1d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d + def __init__( + self, + # Conv1d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm1d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + def forward(self, input): return F.relu(self._forward(input)) @@ -694,6 +735,44 @@ class ConvBnReLU2d(ConvBn2d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d + def __init__( + self, + # Conv2d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm2d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + def forward(self, input): return F.relu(self._forward(input)) @@ -856,6 +935,44 @@ class ConvBnReLU3d(ConvBn3d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d + def __init__( + self, + # Conv3d args + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=None, + padding_mode="zeros", + # BatchNorm3d args + # num_features: out_channels + eps=1e-05, + momentum=0.1, + # affine: True + # track_running_stats: True + # Args for this module + freeze_bn=False, + qconfig=None, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + eps, + momentum, + freeze_bn, + qconfig, + ) + def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index e2b31e0e563bf..3c146c55947a0 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -19,6 +19,9 @@ class SparseDLRM(DLRM_Net): layer of the top layer. """ + def __init__(self, **args): + super().__init__(**args) + def forward(self, dense_x, lS_o, lS_i): # pyrefly: ignore [missing-attribute] x = self.apply_mlp(dense_x, self.bot_l) # dense features diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index 3b8ef1030bfdc..b0f1b823b7fdb 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -41,6 +41,9 @@ def get_embedding_operators_config() -> OperatorConfig: class EmbeddingQuantizer(Quantizer): + def __init__(self) -> None: + super().__init__() + @classmethod def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: op_configs: set[QuantizationConfig] = { diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index 1c0fc48fd54fa..d19968c2787f4 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -75,6 +75,9 @@ class XPUInductorQuantizer(X86InductorQuantizer): of the optimized kernels in oneDNN library. """ + def __init__(self) -> None: + super().__init__() + """ Following annotate_xx overrides the impls in base class, as no XPU implementation for these operators currently. We would diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index f54a3fd6820c7..c02a8c36fd08b 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -113,6 +113,9 @@ def inner(precision): class GenericModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + fp32_precision = ContextProp( _get_fp32_precision_getter("generic", "all"), _set_fp32_precision_setter("generic", "all"), diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 267594531db3d..697783c01cb64 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -198,6 +198,9 @@ def flags( class CudnnModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp( torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py index 1b270b658e31a..93453cc11592d 100644 --- a/torch/backends/miopen/__init__.py +++ b/torch/backends/miopen/__init__.py @@ -37,6 +37,9 @@ def flags( class MiopenModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + immediate = ContextProp( torch._C._get_miopen_immediate, torch._C._set_miopen_immediate ) diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 58e6b2c595e98..2d1ce8f3bb997 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -110,6 +110,9 @@ def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="n class MkldnnModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + def is_available(self): return is_available() diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 264be78aa9a1c..797d847e31e5c 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -101,6 +101,9 @@ def flags(enabled=None, strategy=None): class OptEinsumModule(PropModule): + def __init__(self, m, name): + super().__init__(m, name) + global enabled enabled = ContextProp(_get_enabled, _set_enabled) global strategy diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index eae76e8cc72af..3ce067f6cddc0 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -103,6 +103,9 @@ def _pre_load_state_dict_hook( class OffloadWrapper(ActivationWrapper): + def __init__(self, mod): + super().__init__(mod) + def forward(self, *args, **kwargs): with save_on_cpu(pin_memory=True): return self._checkpoint_wrapped_module(*args, **kwargs) diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index e5ddc1e443a29..0f348590ea397 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -85,6 +85,9 @@ def get_qualified_name(func): class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" + def __init__(self, store: CallTraceStore) -> None: + super().__init__(store) + def log(self, trace: CallTrace) -> None: # pyrefly: ignore [missing-attribute] self.traces.append(trace) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index ec4bbd125119d..75355cbd4b8e0 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -152,7 +152,8 @@ def _get_valid_constant(attr, v, owner_type): class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): - pass + def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None: + super().__init__(source, filename, file_lineno, leading_whitespace_len) def get_annotations(obj): diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 12901df09a3c5..f553f7cacd753 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -46,6 +46,14 @@ def __reduce__(self): class ProcessRaisedException(ProcessException): """Exception raised when a process failed due to an exception raised by the code.""" + def __init__( + self, + msg: str, + error_index: int, + error_pid: int, + ): + super().__init__(msg, error_index, error_pid) + class ProcessExitedException(ProcessException): """Exception raised when a process failed due to signal or exited with a specific code.""" diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 815cc8859080f..d5afc413daed8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1370,6 +1370,8 @@ class XMLTestResultVerbose(_XMLTestResult): This works with unittest_xml_reporting<=3.2.0,>=2.0.0 (3.2.0 is latest at the moment) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def addSkip(self, test, reason): super().addSkip(test, reason) From 05b11198fd4781f0c223c5b6e7dee054858a4d1d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 20 Nov 2025 22:32:32 +0000 Subject: [PATCH 120/230] Revert "conv: refactor for lookup table support (#167179)" This reverts commit 90c57aa3b3adc7f52c73090db6e2e7e8caebb762. Reverted https://github.com/pytorch/pytorch/pull/167179 on behalf of https://github.com/yangw-dev due to internall error: expected `Optional[Dict[str, Union[Sequence[Union[bool, float, int]], bool, float, int]]]` but got `Optional[Dict[str, Union[float, int]]]` ([comment](https://github.com/pytorch/pytorch/pull/167179#issuecomment-3560397543)) --- test/inductor/test_lookup_table.py | 156 +--------- torch/_inductor/kernel/conv.py | 130 ++++---- torch/_inductor/kernel_inputs.py | 125 +------- .../_inductor/template_heuristics/__init__.py | 2 +- torch/_inductor/template_heuristics/conv.py | 287 ------------------ 5 files changed, 89 insertions(+), 611 deletions(-) delete mode 100644 torch/_inductor/template_heuristics/conv.py diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py index 32be3e730a6fb..250a822267833 100644 --- a/test/inductor/test_lookup_table.py +++ b/test/inductor/test_lookup_table.py @@ -2,18 +2,14 @@ import re import unittest from functools import partial -from typing import Any, Optional +from typing import Any, Optional, Union from unittest.mock import patch import torch import torch.nn as nn from torch._inductor import config as inductor_config from torch._inductor.choices import InductorChoices -from torch._inductor.kernel_inputs import ( - ConvKernelInputs, - MMKernelInputs, - SerializableValue, -) +from torch._inductor.kernel_inputs import MMKernelInputs from torch._inductor.lookup_table.choices import LookupTableChoices from torch._inductor.select_algorithm import ( add_preprocessing_fn, @@ -58,7 +54,7 @@ class MockMMKernelInputs(MMKernelInputs): def __init__( self, tensors: list[torch.Tensor], - scalars: Optional[dict[str, SerializableValue]] = None, + scalars: Optional[dict[str, Union[float, int]]] = None, mat1_idx: int = -2, mat2_idx: int = -1, ): @@ -84,37 +80,6 @@ def device_type(self) -> Optional[str]: return self.tensors[0].device.type -class MockConvKernelInputs(ConvKernelInputs): - """Mock ConvKernelInputs that subclasses the real class and uses real tensors""" - - def __init__( - self, - tensors: list[torch.Tensor], - scalars: Optional[dict[str, SerializableValue]] = None, - x_idx: int = 0, - weight_idx: int = 1, - bias_idx: Optional[int] = None, - ): - """Initialize with real tensors, creating mock nodes for the base class""" - mock_nodes = [MockTensorNode(t) for t in tensors] - super().__init__( - mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx - ) - self.tensors = tensors # Keep reference to original tensors - - def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: - """Delegate to symbolic since real tensors already have int shapes""" - return self.shapes_symbolic() - - def strides_hinted(self) -> tuple[tuple[int, ...], ...]: - """Delegate to symbolic since real tensors already have int strides""" - return self.strides_symbolic() # pyre-ignore - - @property - def device_type(self) -> Optional[str]: - return self.tensors[0].device.type - - class BaseLookupTableTest(TestCase): """Base class for lookup table tests with common setup and utilities""" @@ -138,7 +103,7 @@ def create_mock_mm_kernel_inputs( shapes: Optional[list[tuple[int, ...]]] = None, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.float32, - scalars: Optional[dict[str, SerializableValue]] = None, + scalars: Optional[dict[str, Union[float, int]]] = None, ) -> MockMMKernelInputs: """Create MockMMKernelInputs with real tensors""" if shapes is None: @@ -1090,119 +1055,6 @@ def test_template_hash_filtering_e2e(self): with patch.object(inductor_config.lookup_table, "check_src_hash", True): self.run_model("mm", tensors) - @fresh_cache() - def test_conv2d_lookup_table_entry_e2e(self): - """Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results""" - import torch._inductor.kernel.conv - - # Create input tensors with specific shapes for conv2d - # Input: [batch=2, in_channels=3, height=32, width=32] - # Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3] - # Make them channels-last to match what conv lowering uses - x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to( - memory_format=torch.channels_last - ) - weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to( - memory_format=torch.channels_last - ) - - # Define conv parameters - use these SAME values everywhere - stride = (1, 1) - padding = (1, 1) - dilation = (1, 1) - groups = 1 - - # Create MockConvKernelInputs using the SAME tensors and SAME scalar values - mock_scalars = { - "stride": stride, - "padding": padding, - "dilation": dilation, - "transposed": False, - "output_padding": (0, 0), - "groups": groups, - } - mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars) - - # Create lookup key for "convolution" operation - choices_handler = LookupTableChoices() - lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution") - - # Get the exact template UID from conv2d_template - template_uid = torch._inductor.kernel.conv.conv2d_template.uid - - # Create a precisely configured conv2d config - # IMPORTANT: Only include per-config tunable parameters! - # Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are - # automatically generated by get_extra_kwargs() and should NOT be in the lookup table - conv2d_config = { - "template_id": template_uid, - # Per-config tunable parameters only (what you'd tune via autotuning) - "BLOCK_M": 64, - "BLOCK_N": 64, - "BLOCK_K": 32, - "num_stages": 2, - "num_warps": 4, - } - - # Setup lookup table - inductor_config.lookup_table.table = {lookup_key: [conv2d_config]} - - def validate_conv_choice(choices): - assert len(choices) == 1, ( - f"Expected 1 choice from lookup table, got {len(choices)}" - ) - assert isinstance(choices[0], TritonTemplateCaller), ( - f"Expected TritonTemplateCaller, got {type(choices[0])}" - ) - assert "convolution2d" in choices[0].name, ( - f"Expected 'convolution2d' in name, got {choices[0].name}" - ) - return choices - - add_preprocessing_fn(validate_conv_choice) - - # Create and compile the model using the SAME weight tensor - class SimpleConv2d(nn.Module): - def __init__(self, weight): - super().__init__() - self.register_buffer("weight", weight) - - def forward(self, x): - return torch.conv2d( - x, - self.weight, - bias=None, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - model = SimpleConv2d(weight).to(self.device) - - with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}): - compiled_model = torch.compile(model) - result = compiled_model(x) # Use the SAME x tensor - - # Output shape: [batch=2, out_channels=64, out_h=32, out_w=32] - # (same spatial dims due to padding=1, stride=1, kernel=3) - expected_shape = (2, 64, 32, 32) - self.assertEqual( - result.shape, - expected_shape, - f"Expected shape {expected_shape}, got {result.shape}", - ) - - self.assertFalse( - torch.isnan(result).any().item(), - "Output contains NaN values", - ) - - self.assertFalse( - torch.isinf(result).any().item(), - "Output contains Inf values", - ) - if __name__ == "__main__": from torch._inductor.utils import is_big_gpu diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 2179364c7d0c2..8e5a2aa09d4ea 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -8,7 +8,6 @@ from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate from .. import config, ir -from ..kernel_inputs import ConvKernelInputs from ..lowering import ( add_layout_constraint, constrain_to_fx_strides, @@ -17,9 +16,7 @@ ) from ..select_algorithm import ( autotune_select_algorithm, - ChoiceCaller, ExternKernelChoice, - KernelTemplate, SymbolicGridFn, TritonTemplate, ) @@ -545,40 +542,34 @@ def channels_last_conv(): x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment] weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment] - # Create ConvKernelInputs for unified template configuration - # Only include bias in input_nodes when it's not None - # - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight] - # - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present - if bias is not None: + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] bias.realize() bias.freeze_layout() V.graph.sizevars.guard_int_seq(bias.get_size()) - input_nodes = [x, weight, bias] - bias_idx = 2 - else: - input_nodes = [x, weight] - bias_idx = None - - kernel_inputs = ConvKernelInputs( - input_nodes, - scalars={ - "stride": stride, - "padding": padding, - "dilation": dilation, - "transposed": transposed, - "output_padding": output_padding, - "groups": groups, - }, - x_idx=0, - weight_idx=1, - bias_idx=bias_idx, - ) - - # Build list of templates to try - templates: list[ExternKernelChoice | KernelTemplate] = [] + choices = [] if torch._inductor.utils._use_conv_autotune_backend("ATEN"): - templates.append(aten_convolution) + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] if ( torch._inductor.utils._use_conv_autotune_backend("TRITON") @@ -596,23 +587,60 @@ def channels_last_conv(): and is_zeros(padding) and groups == 1 ): - templates.append(aten_conv1x1_via_mm) - - # Add appropriate template based on ndim - if ndim == 2: - templates.append(conv2d_template) - elif ndim == 3: - templates.append(conv3d_template) - - # Initialize choices list and extend with template configs - choices: list[ChoiceCaller] = [] - choices.extend( - V.choices.get_template_configs( - kernel_inputs, - templates, - "convolution", - ) - ) + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + conv_configs = V.choices.get_conv_configs(device_type) + + dtype_size = x.get_dtype().itemsize + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + dtype_size=dtype_size, + ): + if ndim == 2: + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + elif ndim == 3: + conv3d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_D=kernel_shape[0], + KERNEL_H=kernel_shape[1], + KERNEL_W=kernel_shape[2], + STRIDE_D=stride[0], + STRIDE_H=stride[1], + STRIDE_W=stride[2], + PADDING_D=padding[0], + PADDING_H=padding[1], + PADDING_W=padding[2], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) if use_ck_conv_template(layout): CKGroupedConvFwdTemplate.add_ck_conv_choices( choices, @@ -624,9 +652,7 @@ def channels_last_conv(): groups=groups, n_spatial_dimensions=ndim, ) - return autotune_select_algorithm( - "convolution", choices, kernel_inputs.nodes(), layout - ) + return autotune_select_algorithm("convolution", choices, args, layout) @register_lowering(aten._convolution) diff --git a/torch/_inductor/kernel_inputs.py b/torch/_inductor/kernel_inputs.py index 9e585a4880106..c579cf7565772 100644 --- a/torch/_inductor/kernel_inputs.py +++ b/torch/_inductor/kernel_inputs.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence from typing import Any, Optional, TYPE_CHECKING, Union import torch @@ -13,11 +12,9 @@ if TYPE_CHECKING: - import sympy + from collections.abc import Sequence -# Type aliases for serializable scalar values -Serializable = Union[int, float, bool] -SerializableValue = Union[Serializable, Sequence[Serializable]] + import sympy class KernelInputs(ABC): @@ -30,7 +27,7 @@ class KernelInputs(ABC): def __init__( self, input_nodes: list[Any], - scalars: Optional[dict[str, SerializableValue]] = None, + scalars: Optional[dict[str, Union[float, int]]] = None, out_dtype: Optional[torch.dtype] = None, ): """ @@ -186,7 +183,7 @@ def out_dtype(self) -> torch.dtype: The output dtype """ - def get_scalar(self, name: str) -> SerializableValue: + def get_scalar(self, name: str) -> Union[float, int]: """ Get the scalar value for a given name. @@ -194,7 +191,7 @@ def get_scalar(self, name: str) -> SerializableValue: name: Name of the scalar to get Returns: - The scalar value (can be int, float, bool, or tuple of these types) + The scalar value """ assert name in self._scalars, f"Scalar {name} not found, but required" return self._scalars[name] @@ -219,7 +216,7 @@ class MMKernelInputs(KernelInputs): def __init__( self, input_nodes: list[Any], - scalars: Optional[dict[str, SerializableValue]] = None, + scalars: Optional[dict[str, Union[float, int]]] = None, out_dtype: Optional[torch.dtype] = None, mat1_idx: int = -2, mat2_idx: int = -1, @@ -339,113 +336,3 @@ def mnk_hinted(self) -> tuple[int, int, int]: assert k == k_check, f"K dimensions don't match: {k} vs {k_check}" return (m, n, k) - - -class ConvKernelInputs(KernelInputs): - """ - Specialized KernelInputs for convolution operations. - Stores input tensor, weight tensor, and optional bias, along with conv parameters. - """ - - def __init__( - self, - input_nodes: list[Any], - scalars: Optional[dict[str, SerializableValue]] = None, - out_dtype: Optional[torch.dtype] = None, - x_idx: int = 0, - weight_idx: int = 1, - bias_idx: Optional[int] = None, - ): - """ - Initialize with convolution input nodes. - - Args: - input_nodes: List containing [x, weight] or [x, weight, bias] - scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding) - out_dtype: Optional output dtype - x_idx: Index of input tensor (default: 0) - weight_idx: Index of weight tensor (default: 1) - bias_idx: Index of bias tensor if present (default: None) - """ - super().__init__(input_nodes, scalars, out_dtype) - assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)" - - self._x_idx = x_idx - self._weight_idx = weight_idx - self._bias_idx = bias_idx - - # Validate that required scalars are present - required_scalars = [ - "stride", - "padding", - "dilation", - "transposed", - "output_padding", - "groups", - ] - for key in required_scalars: - assert key in self._scalars, f"Conv requires scalar '{key}'" - - def out_dtype(self) -> torch.dtype: - """ - Get the output dtype, whether passed in or inferred from the nodes - - Returns: - The output dtype - """ - if self._out_dtype is not None: - return self._out_dtype - return self._input_nodes[self._x_idx].get_dtype() - - def output_layout(self, flexible: bool = True) -> Layout: - """ - Handle output layout generation for convolution. - - Args: - flexible: If True, return FlexibleLayout, otherwise FixedLayout - - Returns: - Layout for the convolution output - """ - from torch._inductor.kernel.conv import conv_layout - - x = self._input_nodes[self._x_idx] - weight = self._input_nodes[self._weight_idx] - bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None - - # Use existing conv_layout function - # We know the types here because conv requires these specific scalar types - layout = conv_layout( - x, - weight, - bias, - self._scalars["stride"], # type: ignore[arg-type] - self._scalars["padding"], # type: ignore[arg-type] - self._scalars["dilation"], # type: ignore[arg-type] - self._scalars["transposed"], # type: ignore[arg-type] - self._scalars["output_padding"], # type: ignore[arg-type] - self._scalars["groups"], # type: ignore[arg-type] - ) - - # TODO: Handle flexible vs fixed based on config if needed - return layout - - def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]: - """ - Get x, weight, and optional bias nodes. - - Returns: - Tuple of (x, weight, bias) where bias may be None - """ - bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None - return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias - - def spatial_dims(self) -> tuple[Any, ...]: - """ - Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D). - - Returns: - Tuple of spatial dimension sizes - """ - x_shape = self._input_nodes[self._x_idx].get_size() - return x_shape[2:] # Skip batch and channel dims diff --git a/torch/_inductor/template_heuristics/__init__.py b/torch/_inductor/template_heuristics/__init__.py index 8b980816c56dc..eb3d731525ea8 100644 --- a/torch/_inductor/template_heuristics/__init__.py +++ b/torch/_inductor/template_heuristics/__init__.py @@ -1,6 +1,6 @@ # NOTE: add new template heuristics here, so they get imported and registered # TODO: write a simple glob if there are many heuristics to auto import them in the right order -from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton +from . import aten, base, contiguous_mm, decompose_k, registry, triton # expose the entry function from .registry import get_template_heuristic diff --git a/torch/_inductor/template_heuristics/conv.py b/torch/_inductor/template_heuristics/conv.py deleted file mode 100644 index 7333b5a679bd8..0000000000000 --- a/torch/_inductor/template_heuristics/conv.py +++ /dev/null @@ -1,287 +0,0 @@ -from __future__ import annotations - -from typing import Any, cast, TYPE_CHECKING - -import torch - -from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template -from ..kernel_inputs import ConvKernelInputs -from ..utils import is_ones, sympy_product -from ..virtualized import V -from .base import TemplateConfigHeuristics -from .registry import register_template_heuristic -from .triton import ( - CPUConfigHeuristic, - CUDAConfigHeuristic, - MTIAConfigHeuristic, - ROCmConfigHeuristic, - XPUConfigHeuristic, -) - - -if TYPE_CHECKING: - from collections.abc import Generator - - from ..kernel_inputs import KernelInputs - - -class ConvTemplateConfigMixin(TemplateConfigHeuristics): - """ - Mixin for conv templates that converts config lists to template kwargs. - Similar to MMTemplateConfigMixin but for convolutions. - - This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.) - and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps). - """ - - # Type hint for methods from BaseConfigHeuristic - get_conv_configs: Any - - def get_extra_kwargs( - self, - kernel_inputs: KernelInputs, - op_name: str, - ) -> dict[str, Any]: - """ - Return template kwargs that don't change per-config. - These are derived from kernel_inputs and must include all template parameters. - - Args: - kernel_inputs: ConvKernelInputs containing input tensors and conv params - op_name: Operation name (e.g., "convolution") - - Returns: - Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.) - """ - assert isinstance(kernel_inputs, ConvKernelInputs), ( - f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}" - ) - - x, weight, bias = kernel_inputs.get_x_weight_bias() - - # Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape] - weight_size = V.graph.sizevars.guard_int_seq(weight.get_size()) - kernel_shape = weight_size[2:] # Skip out_chan, in_chan - ndim = len(kernel_shape) - - # Extract scalars - stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride")) - padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding")) - groups = cast(int, kernel_inputs.get_scalar("groups")) - - # Check if we should unroll (only for 1x1 kernels) - unroll = is_ones(kernel_shape) - - # Build kwargs dict based on ndim - kwargs: dict[str, Any] = { - "GROUPS": groups, - "UNROLL": unroll, - "ALLOW_TF32": torch.backends.cudnn.allow_tf32, - } - - if ndim == 2: - kwargs.update( - { - "KERNEL_H": kernel_shape[0], - "KERNEL_W": kernel_shape[1], - "STRIDE_H": stride[0], - "STRIDE_W": stride[1], - "PADDING_H": padding[0], - "PADDING_W": padding[1], - } - ) - elif ndim == 3: - kwargs.update( - { - "KERNEL_D": kernel_shape[0], - "KERNEL_H": kernel_shape[1], - "KERNEL_W": kernel_shape[2], - "STRIDE_D": stride[0], - "STRIDE_H": stride[1], - "STRIDE_W": stride[2], - "PADDING_D": padding[0], - "PADDING_H": padding[1], - "PADDING_W": padding[2], - } - ) - - return kwargs - - def _get_template_configs_impl( - self, - kernel_inputs: KernelInputs, - op_name: str, - ) -> Generator[dict[str, Any], None, None]: - """ - Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps). - - Args: - kernel_inputs: ConvKernelInputs containing input tensors - op_name: Operation name - - Yields: - Dict of per-config kwargs for each configuration to try - """ - assert isinstance(kernel_inputs, ConvKernelInputs), ( - "ConvTemplateConfigMixin requires ConvKernelInputs" - ) - - x, weight, bias = kernel_inputs.get_x_weight_bias() - - # Calculate dimensions for heuristics - weight_size = weight.get_size() - out_chan = weight_size[0] - in_chan = weight_size[1] - - # Batch * spatial dimensions product - x_size = x.get_size() - batch_spatial_product = sympy_product([x_size[0], *x_size[2:]]) - - # Get conv config generator from self (which is a BaseConfigHeuristic subclass) - conv_configs_generator = self.get_conv_configs() - - dtype_size = x.get_dtype().itemsize - - # Generate configs (reusing mm preprocess_mm_configs machinery) - for c in conv_configs_generator( - batch_spatial_product, - out_chan, - in_chan, - dtype_size=dtype_size, - op_name="conv", - ): - # Yield per-config kwargs - yield { - "BLOCK_M": c.kwargs.get("BLOCK_M"), - "BLOCK_N": c.kwargs.get("BLOCK_N"), - "BLOCK_K": c.kwargs.get("BLOCK_K"), - "num_stages": c.num_stages, - "num_warps": c.num_warps, - } - - -# ATEN convolution heuristic (no per-config tuning) -@register_template_heuristic(aten_convolution.uid, None) -class ATenConvConfigHeuristic(TemplateConfigHeuristics): - """ - Pseudo heuristic for ATen convolution. - ATen doesn't have configs to tune - it's a single choice. - """ - - def _get_template_configs_impl( - self, - kernel_inputs: KernelInputs, - op_name: str, - ) -> Generator[dict[str, Any], None, None]: - # ATen doesn't have per-config kwargs to tune - yield dict() - - def get_extra_kwargs( - self, - kernel_inputs: KernelInputs, - op_name: str, - ) -> dict[str, Any]: - """ - ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel. - """ - assert isinstance(kernel_inputs, ConvKernelInputs) - - # Extract scalar values from kernel_inputs - stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride")) - padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding")) - dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation")) - transposed = cast(bool, kernel_inputs.get_scalar("transposed")) - output_padding = cast( - tuple[int, ...], kernel_inputs.get_scalar("output_padding") - ) - groups = cast(int, kernel_inputs.get_scalar("groups")) - - # Check if bias is None to match old behavior - # When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list - # When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs - x, weight, bias = kernel_inputs.get_x_weight_bias() - - kwargs: dict[str, Any] = { - "stride": stride, - "padding": padding, - "dilation": dilation, - "transposed": transposed, - "output_padding": output_padding, - "groups": groups, - } - - if bias is None: - # When bias is None, torch.convolution expects it as a kwarg - kwargs["bias"] = None - kwargs["ordered_kwargs_for_cpp_kernel"] = [ - "bias", - "stride", - "padding", - "dilation", - "transposed", - "output_padding", - "groups", - ] - else: - # When bias is present, it's passed as a positional arg (3rd in input_nodes) - kwargs["ordered_kwargs_for_cpp_kernel"] = [ - "stride", - "padding", - "dilation", - "transposed", - "output_padding", - "groups", - ] - - return kwargs - - -# CUDA Conv2D/Conv3D heuristics -@register_template_heuristic( - conv2d_template.uid, - "cuda", - register=torch.version.hip is None, -) -@register_template_heuristic( - conv3d_template.uid, - "cuda", - register=torch.version.hip is None, -) -class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic): - """Conv template heuristic for CUDA.""" - - -# ROCm Conv2D/Conv3D heuristics -@register_template_heuristic( - conv2d_template.uid, - "cuda", - register=torch.version.hip is not None, -) -@register_template_heuristic( - conv3d_template.uid, - "cuda", - register=torch.version.hip is not None, -) -class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic): - """Conv template heuristic for ROCm.""" - - -# CPU Conv2D/Conv3D heuristics -@register_template_heuristic(conv2d_template.uid, "cpu") -@register_template_heuristic(conv3d_template.uid, "cpu") -class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic): - """Conv template heuristic for CPU.""" - - -# XPU Conv2D/Conv3D heuristics -@register_template_heuristic(conv2d_template.uid, "xpu") -@register_template_heuristic(conv3d_template.uid, "xpu") -class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic): - """Conv template heuristic for XPU.""" - - -# MTIA Conv2D/Conv3D heuristics -@register_template_heuristic(conv2d_template.uid, "mtia") -@register_template_heuristic(conv3d_template.uid, "mtia") -class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic): - """Conv template heuristic for MTIA.""" From 0ea545b24168fb9f127704640f34f47b033394cd Mon Sep 17 00:00:00 2001 From: zhangfei Date: Thu, 20 Nov 2025 22:46:22 +0000 Subject: [PATCH 121/230] Add support to enable the oneDNN backend for RISC-V (#166602) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Currently, oneDNN can be successfully compiled and run on RISC-V : [oneDNN riscv](https://github.com/uxlfoundation/oneDNN/pull/2929) , and it has also introduced extensive RVV optimizations: [oneDNN/src/cpu/rv64](https://github.com/uxlfoundation/oneDNN/tree/main/src/cpu/rv64) . Therefore, this PR adds support for enabling the oneDNN backend on the RISC-V architecture. Although the current dependency package ideep uses a relatively lower version of oneDNN. and not yet sufficient for successful compilation, this is only a matter of time—once ideep is updated, compilation will work as expected. Below are my test results on RISC-V: ```bash $ uname -m riscv64 $ cat test.py import torch print(torch.backends.mkldnn.is_available()) $ python3 test.py True ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166602 Approved by: https://github.com/malfet --- CMakeLists.txt | 7 +++++-- cmake/Modules/FindMKLDNN.cmake | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f1d391ab6dbf9..877ed9fafd3b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,6 +171,7 @@ endif() set(CPU_AARCH64 OFF) set(CPU_INTEL OFF) set(CPU_POWER OFF) +set(CPU_RISCV OFF) if(CMAKE_SYSTEM_PROCESSOR MATCHES "(AMD64|x86_64)") set(CPU_INTEL ON) @@ -178,6 +179,8 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)") set(CPU_AARCH64 ON) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)") set(CPU_POWER ON) +elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv64)") + set(CPU_RISCV ON) endif() # For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not @@ -327,8 +330,8 @@ cmake_dependent_option(USE_ITT "Use Intel(R) VTune Profiler ITT functionality" # Ensure that an MKLDNN build is the default for x86 CPUs but optional for # AArch64 (dependent on -DUSE_MKLDNN). cmake_dependent_option( - USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, AArch64, and ppc64le." - "${CPU_INTEL}" "CPU_INTEL OR CPU_AARCH64 OR CPU_POWER" OFF) + USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, AArch64, ppc64le and riscv64." + "${CPU_INTEL}" "CPU_INTEL OR CPU_AARCH64 OR CPU_POWER OR CPU_RISCV" OFF) cmake_dependent_option( USE_MKLDNN_ACL "Use Compute Library for the Arm architecture." OFF "USE_MKLDNN AND CPU_AARCH64" OFF) diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 2018d5ec9370b..0349b09119cae 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -85,7 +85,7 @@ IF(NOT MKLDNN_FOUND) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") - IF(CPU_POWER) + IF(CPU_POWER OR CPU_RISCV) SET(DNNL_EXPERIMENTAL_UKERNEL OFF CACHE BOOL "" FORCE) ELSE() MESSAGE("-- Will build oneDNN UKERNEL") From c81f6968c5aeace0c954f2dd67f715260c3b5fe7 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 20 Nov 2025 11:16:14 -0800 Subject: [PATCH 122/230] Skip _assert_scalar in default partitioner (#168289) Fixes https://github.com/pytorch/torchtitan/issues/2069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168289 Approved by: https://github.com/yiming0416 --- torch/_functorch/partitioners.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index f22b274be41ab..e7c665d8df9d1 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1109,6 +1109,8 @@ def is_impure(node): for node in joint_module.graph.nodes: if node.name not in forward_node_names: continue + if node.target is torch.ops.aten._assert_scalar.default: + continue if is_sym_node(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx From a64613ab6f4479358ea07a09b72056570a6a67c8 Mon Sep 17 00:00:00 2001 From: AI Date: Thu, 20 Nov 2025 22:49:45 +0000 Subject: [PATCH 123/230] [doc] README add cmake prefix for non-conda env (#167714) Add the CMAKE_PREFIX_PATH for non-conda Python env Pull Request resolved: https://github.com/pytorch/pytorch/pull/167714 Approved by: https://github.com/XuehaiPan, https://github.com/albanD --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index a0c9b54c95a8b..c2f15d88738da 100644 --- a/README.md +++ b/README.md @@ -292,8 +292,13 @@ python tools/amd_build/build_amd.py Install PyTorch ```bash +# the CMake prefix for conda environment export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" python -m pip install --no-build-isolation -v -e . + +# the CMake prefix for non-conda environment, e.g. Python venv +# call following after activating the venv +export CMAKE_PREFIX_PATH="${VIRTUAL_ENV}:${CMAKE_PREFIX_PATH}" ``` **On macOS** From ddde4b771bcb2e953592250796ad6ba7d3da89c9 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 20 Nov 2025 11:57:20 -0800 Subject: [PATCH 124/230] [user-streams] Refactor out event insertion for record_stream handling (#168228) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168228 Approved by: https://github.com/williamwen42, https://github.com/anijain2305 --- test/dynamo/test_streams.py | 2 +- torch/_functorch/_aot_autograd/streams.py | 50 ++++++++++++++--------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 967bedb9ebaae..7a40ae926a527 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -27,7 +27,7 @@ def remove_file_comment(gm_str: str) -> str: def print_graph(graph: torch.fx.GraphModule) -> str: - return remove_file_comment(graph.print_readable()) + return remove_file_comment(graph.print_readable(print_output=False)) class TestStreams(torch._dynamo.test_case.TestCase): diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index 1b4f5ded051e3..1fc8a965740fd 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -7,6 +7,7 @@ Node: TypeAlias = torch.fx.Node +Graph: TypeAlias = torch.fx.Graph def is_gradient_acc(node: Node) -> bool: @@ -43,8 +44,32 @@ def set_stream(node: Node, ind: int) -> None: node.meta["custom"] = {"stream": ind} +def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> None: + with graph.inserting_after(node): + node = graph.call_function( + torch.ops.streams.record_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + +def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> None: + with graph.inserting_before(node): + node = graph.call_function( + torch.ops.streams.wait_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + def insert_sync( - graph: torch.fx.Graph, + graph: Graph, consumer: Node, producer: Node, node_to_wait_event_ind: dict[Node, int], @@ -52,25 +77,10 @@ def insert_sync( if producer not in node_to_wait_event_ind: node_to_wait_event_ind[producer] = new_event() - with graph.inserting_after(producer): - node = graph.call_function( - torch.ops.streams.record_event.default, - ( - node_to_wait_event_ind[producer], - get_stream_or_current_stream(producer), - ), - ) - node.meta["partitioner_tag"] = "must_be_in_backward" - - with graph.inserting_before(consumer): - node = graph.call_function( - torch.ops.streams.wait_event.default, - ( - node_to_wait_event_ind[producer], - get_stream_or_current_stream(consumer), - ), - ) - node.meta["partitioner_tag"] = "must_be_in_backward" + insert_record_event_after_node( + graph, producer, node_to_wait_event_ind[producer] + ) + insert_wait_event_before_node(graph, consumer, node_to_wait_event_ind[producer]) def assign_backward_streams(gm: torch.fx.GraphModule) -> None: From 2ca51b7a544cb05c95cad8bb5d2511900bf11f40 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 20 Nov 2025 11:57:20 -0800 Subject: [PATCH 125/230] [user-streams] Refactor runtime estimation to reuse internal functions (#168229) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168229 Approved by: https://github.com/sanketpurandare ghstack dependencies: #168228 --- torch/distributed/_tools/runtime_estimator.py | 145 +++++++++--------- 1 file changed, 72 insertions(+), 73 deletions(-) diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index b897e51cac9f3..bee54e0454d5d 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -80,6 +80,78 @@ __all__ = ["RuntimeEstimator"] +def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert len(out_dtypes) == 1, ( + f"Only support single out dtype got {out_dtypes} for {func_packet}" + ) + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + +def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem_consumed + + +def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time + + class RuntimeEstimator(TorchDispatchMode): """ Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. @@ -297,79 +369,6 @@ def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: i "Roofline estimation needs to access CUDA capabilities to make estimations" ) - def get_num_bytes(t: torch.Tensor) -> int: - """ - Calculates the memory consumption of a tensor. - - Args: - t (torch.Tensor): The input tensor. - - Returns: - int: The memory consumption of the tensor in bytes. - """ - num_bytes = t.untyped_storage().nbytes() - mem_consumed = ( - math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE - ) - return mem_consumed - - def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] - """ - Estimates the compute time of an aten operator. - - Args: - func_packet: The operator overload packet. - args: The arguments to the operator. - kwargs: The keyword arguments to the operator. - out: The output of the operator. - out_dtypes: The output data types. - - Returns: - float: The estimated compute time in nanoseconds. - """ - if func_packet in flop_registry: - assert len(out_dtypes) == 1, ( - f"Only support single out dtype got {out_dtypes} for {func_packet}" - ) - dtype = out_dtypes.pop() - # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s - peak_gpu_flops = get_device_tflops(dtype) * 1e15 - # We can expect to achieve 75% of theoretical peak flops - factor = 0.75 - peak_empirical_flops = factor * peak_gpu_flops - flop_count_func = flop_registry[func_packet] - # We divide by a factor of 2 to get the MACs (multiply and accumulate) - flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 - # We multiply by 1e9 to get the time in nano seconds - compute_time = (flop_count / peak_empirical_flops) * 1e9 - return compute_time - return 0.0 - - def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] - """ - Estimates the memory transfer time of input and output tensors. - - Args: - flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. - flat_outs (List[torch.Tensor]): The flat list of outputs. - - Returns: - float: The estimated memory transfer time in nanoseconds. - """ - gpu_memory_bandwidth = get_gpu_dram_gbps() - read_bytes = sum( - get_num_bytes(t) - for t in flat_args_kwargs - if isinstance(t, torch.Tensor) - ) - write_bytes = sum( - get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) - ) - counted_bytes = read_bytes + write_bytes - # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds - transfer_time = counted_bytes / gpu_memory_bandwidth - return transfer_time - # Roofline Cost Model Explanation # The roofline cost model estimates the execution time of an operator based on From 31451777c195be25e1e127ca1c6eccd9ae0e45d2 Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Thu, 20 Nov 2025 10:59:33 -0800 Subject: [PATCH 126/230] Fix smoke test failure due to numpy import in Local Tensor (#168271) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168271 Approved by: https://github.com/atalman --- torch/distributed/_local_tensor/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index dbb0071d86ec7..194127b725fa0 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -1027,9 +1027,7 @@ def __torch_dispatch__( # type: ignore[override] with LocalTensorMode(local_tensor._ranks): return func(*args, **kwargs) - def numpy( - self, *, force: bool = False - ) -> np.ndarray: # pyrefly: ignore # missing-attribute + def numpy(self, *, force: bool = False) -> Any: if HAS_NUMPY: return self.reconcile().numpy(force=force) else: From 0bd3f5177d4aa38d694a75b15d544971ca2f9bdb Mon Sep 17 00:00:00 2001 From: Arsh Zahed Date: Thu, 20 Nov 2025 23:31:16 +0000 Subject: [PATCH 127/230] [3.14] Fix module 'torch' has no attribute 'f' (#168152) Similar to https://github.com/pytorch/pytorch/commit/4316df857c9e7f301142eb54d06a85a43f8d617b, the pickle `_getattribute` signature changed, but not all uses were caught. This PR applies a similar fix to the mentioned commit, but since we require `parent` instead of `obj` here there is some additional logic for getting the parent. **Test Plan:** Manually ran the following tests with 3.14 ``` 'test/test_package.py::TestImporter::test_package_importer_whichmodule_no_dunder_module', 'test/test_package.py::TestPackageScript::test_load_shared_tensors_repackaged', 'test/test_package.py::TestPackageScript::test_saving_and_scripting_packaged_mod', 'test/test_package.py::TestRepackage::test_repackage_import_indirectly_via_parent_module', 'test/test_package.py::TestSaveLoad::test_exporting_mismatched_code' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168152 Approved by: https://github.com/williamwen42, https://github.com/fxdawnn --- torch/package/_package_pickler.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index a66c14adfe86f..a4d8e7f752505 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs # pyrefly: ignore [missing-module-attribute] +import sys from pickle import ( # type: ignore[attr-defined] _compat_pickle, _extension_registry, @@ -64,7 +65,19 @@ def save_global(self, obj, name=None): raise PicklingError(f"Can't pickle {obj}: {str(err)}") from err module = self.importer.import_module(module_name) - _, parent = _getattribute(module, name) + if sys.version_info >= (3, 14): + # pickle._getattribute signature changes in 3.14 + # to take iterable and return just the object (not tuple) + # We need to get the parent object that contains the attribute + name_parts = name.split(".") + if "" in name_parts: + raise PicklingError(f"Can't pickle local object {obj!r}") + if len(name_parts) == 1: + parent = module + else: + parent = _getattribute(module, name_parts[:-1]) + else: + _, parent = _getattribute(module, name) # END CHANGED if self.proto >= 2: # type: ignore[attr-defined] From 247f8228ce8b7c6927a4178ba04a12d2e47b7fe7 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Thu, 20 Nov 2025 23:48:43 +0000 Subject: [PATCH 128/230] [Fix] Add generator and tensor variant signatures for `rand*_like()` functions (#167824) As the title stated. PR #166160 add generator variants to `rand*_like()` functions, this PR is for finding out staled signatures and fixing them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167824 Approved by: https://github.com/cyyever, https://github.com/albanD --- test/cpp/jit/test_alias_analysis.cpp | 8 +++++++- torch/csrc/jit/passes/shape_analysis.cpp | 8 +++++++- torch/csrc/utils/schema_info.cpp | 8 +++++++- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index a58ac596d7cb2..5a5ea4a69f7c7 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1669,12 +1669,18 @@ TEST(NonDeterminismBackwardsCompatibility, BackwardsCompatibility) { "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; for (const std::string& op : nondeterministic_ops) { const c10::FunctionSchema& schema = torch::jit::parseSchema(op); diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 2ae32f5fc5082..57dc2552c661c 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1464,9 +1464,15 @@ class ShapePropagator : public PropertyPropBase { "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", }, [](Node* node) -> type_vec_t { diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp index fb628bec8c654..e3176854f14d2 100644 --- a/torch/csrc/utils/schema_info.cpp +++ b/torch/csrc/utils/schema_info.cpp @@ -250,12 +250,18 @@ std::vector SchemaInfo::getNonDeterministicOps() { "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; std::vector nondeterministic_ops; From 45253402819c28bd64a0069ec23d91318553cf50 Mon Sep 17 00:00:00 2001 From: Arsh Zahed Date: Thu, 20 Nov 2025 23:50:36 +0000 Subject: [PATCH 129/230] [3.14] Use refcount difference for TestNumPyInterop.test_from_numpy_no_leak_on_invalid_dtype (#168191) Updates TestNumPyInterop.test_from_numpy_no_leak_on_invalid_dtype to use difference in refcount instead of a hardcoded value. This is needed because Python 3.14 changed sys.getrefcount behavior with optimizations that returns a lower number of refcounts for inside functions/methods vs module-level. Instead, this PR updates the test to check if there is any change to the number of recounts. **Test Plan:** Ran locally with Python 3.14 ``` python3 test/test_numpy_interop.py TestNumPyInteropCPU.test_from_numpy_no_leak_on_invalid_dtype_cpu ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168191 Approved by: https://github.com/williamwen42, https://github.com/fxdawnn --- test/test_numpy_interop.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index c30ace4a70f5f..6ed34f2559a18 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -301,12 +301,18 @@ def test_from_numpy_no_leak_on_invalid_dtype(self): # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary # object. See https://github.com/pytorch/pytorch/issues/121138 x = np.array(b"value") + initial_refcount = sys.getrefcount(x) for _ in range(1000): try: torch.from_numpy(x) except TypeError: pass - self.assertTrue(sys.getrefcount(x) == 2) + final_refcount = sys.getrefcount(x) + self.assertEqual( + final_refcount, + initial_refcount, + f"Memory leak detected: refcount increased from {initial_refcount} to {final_refcount}", + ) @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.") @onlyCPU From ed6d5ff841fe28ad54e86c29c436b6d682486c59 Mon Sep 17 00:00:00 2001 From: bobrenjc93 Date: Thu, 20 Nov 2025 09:23:57 -0800 Subject: [PATCH 130/230] [precompile] nicer error message when caches are disabled (#168274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I commonly use the following alias for generating tlparses ``` function tlp { setopt rm_star_silent rm -rf ~/tmp/trace_logs/* TORCH_TRACE=~/tmp/trace_logs TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 "$@" tlparse ~/tmp/trace_logs --slug "$USER/$(uuidgen)/custom" --overwrite-manifold } ``` and was surprised to see the following failure ``` (/home/bobren/local/a/pytorch-env) [9:18] devgpu009:/home/bobren/local/a/pytorch [130] ❯ tlp python pc.py /home/bobren/local/a/pytorch/torch/_dynamo/pgo.py:539: UserWarning: dynamo_pgo force disabled by torch.compiler.config.force_disable_caches warn_once( Traceback (most recent call last): File "/home/bobren/local/a/pytorch/pc.py", line 24, in compiled_fn.save_compiled_function(path) File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 128, in save_compiled_function f.write(type(self).serialize(self)) File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 143, in serialize type(compiled_fn).serialize_compile_artifacts(compiled_fn), File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile_types.py", line 50, in serialize_compile_artifacts result = pickle.dumps(fn.compiled_fn.serialize()) TypeError: 'NoneType' object is not callable ``` which goes away if i remove TORCH_INDUCTOR_FORCE_DISABLE_CACHE=1 This PR adds a nicer error message ``` ❯ tlp python pc.py /home/bobren/local/a/pytorch/torch/_dynamo/pgo.py:539: UserWarning: dynamo_pgo force disabled by torch.compiler.config.force_disable_caches warn_once( Traceback (most recent call last): File "/home/bobren/local/a/pytorch/pc.py", line 24, in compiled_fn.save_compiled_function(path) File "/home/bobren/local/a/pytorch/torch/_dynamo/aot_compile.py", line 128, in save_compiled_function raise RuntimeError( RuntimeError: Cannot precompile with torch._inductor.config.force_disable_caches=True; caching is required. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168274 Approved by: https://github.com/zhxchen17 --- torch/_dynamo/eval_frame.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 0956facde2559..43bc570841239 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -787,6 +787,11 @@ def get_compiler_config() -> Any: def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: from torch._dynamo.aot_compile import aot_compile_fullgraph + if torch._inductor.config.force_disable_caches: + raise RuntimeError( + "Cannot precompile with torch._inductor.config.force_disable_caches=True; caching is required." + ) + if not self.fullgraph: raise RuntimeError( "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)." From b4f3c527fce3a2c42ed0261874dfc5542b4002c0 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 20 Nov 2025 10:18:29 -0800 Subject: [PATCH 131/230] [dynamo][compile time] Special case for torch.utils._pytree._get_node_type (#168054) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/168054 Approved by: https://github.com/XuehaiPan, https://github.com/zou3519, https://github.com/mlazos --- test/dynamo/test_repros.py | 59 ++++++++++++++++++++++++++++ torch/_dynamo/guards.py | 8 ++++ torch/_dynamo/source.py | 27 +++++++++++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/functions.py | 41 +++++++++++++++++++ 6 files changed, 138 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 10342f56d55d1..aab7d5268fcdc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8184,6 +8184,65 @@ def fn(x): self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1) + def test_pytree_get_node_type_not_traced(self): + # Test that torch.utils._pytree._get_node_type is not traced into + # and doesn't cause excessive trace time overhead + from torch.utils._pytree import _get_node_type + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Call _get_node_type which is used internally by pytree operations + node_type = _get_node_type([x, y]) + assert node_type is list + # Do some work with pytree structures + data = {"a": x, "b": y} + flat, spec = pytree.tree_flatten(data) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_get_node_type_with_namedtuple(self): + # Test that torch.utils._pytree._get_node_type handles namedtuples correctly + # without being traced into, even when is_namedtuple_class is True + from collections import namedtuple + + from torch.utils._pytree import _get_node_type + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Create a namedtuple + point = Point(a, b) + # Call _get_node_type with a namedtuple instance + node_type = _get_node_type(point) + assert node_type is namedtuple + # Use pytree operations with namedtuples + flat, spec = pytree.tree_flatten(point) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a75118f9e5032..77db6ec52d54d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -130,6 +130,7 @@ ChainedSource, ClosureSource, CodeSource, + CollectionsSource, ConstantSource, ConstDictKeySource, CurrentStreamSource, @@ -1442,6 +1443,13 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CollectionsSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: collections, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, TorchFunctionModeStackSource): out = root_guard_manager.lambda_manager( python_lambda=lambda _: get_torch_function_mode_stack_at( diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 5be6b8ccbf41d..a5a69cd177c27 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1003,6 +1003,33 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL +@dataclasses.dataclass(frozen=True) +class CollectionsSource(Source): + """Points to the actual `collections` module - used instead of GlobalSource + in case the user has overridden `collections` in their local namespace""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + from .guards import GuardBuilder, install_guard + + install_guard(self.make_guard(GuardBuilder.ID_MATCH)) + + def name(self) -> str: + return "__import__('collections')" + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.extend_output( + [ + codegen.create_load_const(0), # level + create_build_tuple(0), # fromlist + codegen.create_import_name("collections"), + ] + ) + + def guard_source(self) -> GuardSource: + return GuardSource.GLOBAL + + @dataclasses.dataclass(frozen=True) class TorchFunctionModeStackSource(Source): ind: int diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 97a3946b48bde..36093b042002e 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -64,6 +64,7 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -378,6 +379,7 @@ f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 74165b30bb2f0..ac0be3e5888be 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -64,6 +64,7 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index e30eeeb2c2fde..1eaf58ee95dea 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -29,6 +29,7 @@ import sys import traceback import types +from collections import namedtuple from collections.abc import Callable, Sequence from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar @@ -38,6 +39,7 @@ import torch from torch._dynamo.exc import get_stack_above_dynamo from torch._guards import Source +from torch.utils._pytree import is_namedtuple_class from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -59,10 +61,12 @@ from ..source import ( AttrSource, ClosureSource, + CollectionsSource, ConstantSource, DefaultsSource, GetItemSource, SkipGuardSource, + TypeSource, ) from ..utils import ( check_constant_args, @@ -2717,3 +2721,40 @@ def call_function( tensor=tensor, # type: ignore[arg-type] block_shape=block_shape, # type: ignore[arg-type] ) + + +class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. + + def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if len(args) != 1: + raise_type_error_exc( + tx, + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + ) + type_source = None + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) + type_source = TypeSource(args[0].source) + python_type = args[0].python_type() + if is_namedtuple_class(python_type): + type_source = AttrSource(CollectionsSource(), "namedtuple") + return VariableTracker.build(tx, namedtuple, type_source) + return VariableTracker.build(tx, python_type, source=type_source) From 63ce1fb4d7c647f2941b8602d86edc79da74fce9 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 20 Nov 2025 21:02:27 +0000 Subject: [PATCH 132/230] Improve build logic in activities for kineto (#167204) # Motivation Thanks to @KarhouTam for finding the issue mentioned in #167172 This PR aims to improve the build logic in activities for kineto. # Additional Context Fix #167172 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167204 Approved by: https://github.com/EikanWang, https://github.com/ezyang --- torch/csrc/autograd/init.cpp | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index a13cc70270ccb..36a8806d281ed 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,31 +390,25 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ torch::profiler::impl::ActivityType::CPU}; -#if defined(USE_KINETO) && \ - (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) - if (at::hasMTIA()) { - activities.insert(torch::profiler::impl::ActivityType::MTIA); - } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#if defined(USE_KINETO) +#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::getNumGPUs() > 0) { activities.insert(torch::profiler::impl::ActivityType::CUDA); } -#elif defined(USE_KINETO) +#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::hasXPU()) { activities.insert(torch::profiler::impl::ActivityType::XPU); } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } if (at::hasMTIA()) { activities.insert(torch::profiler::impl::ActivityType::MTIA); } + if (at::hasHPU()) { + activities.insert(torch::profiler::impl::ActivityType::HPU); + } if (c10::get_privateuse1_backend() != "privateuseone") { activities.insert(torch::profiler::impl::ActivityType::PrivateUse1); } -#endif +#endif // defined(USE_KINETO) return activities; }); From c4a9414bb46da3eb261a77ce314bdcf14a52f523 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 20 Nov 2025 14:01:40 -0800 Subject: [PATCH 133/230] overlap on non mms (#167864) use non-mms for overlap. for now, i'm only using this if there is a custom estimator, since i think we probably want fusion groups to have accurate aten estimation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167864 Approved by: https://github.com/IvanKobzarev --- .../test_aten_comm_compute_reordering.py | 56 ++++++ .../_inductor/fx_passes/overlap_scheduling.py | 159 ++++++++++++------ 2 files changed, 159 insertions(+), 56 deletions(-) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index a60d3868e4f82..0e76da0dbe9c0 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -444,6 +444,62 @@ def func(a): self.assertTrue(same(out, correct)) self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0) + @torch._inductor.config.patch(get_patches()) + def test_custom_estimator_for_non_compute_nodes(self): + """Test that non-compute nodes with custom runtime estimates can trigger collective prefetching.""" + + def custom_estimator_with_relu(fx_node, override_size=None): + """Custom estimator that provides runtime for relu.""" + # Collective ops + if "c10" in str(fx_node.target): + return 1.0 + # Non-compute ops that we want to overlap + elif fx_node.target == aten.relu.default: + return 1.0 # relu has same time as collective + else: + return None + + def func(a, b): + c = torch.relu(a) + d = torch.mm(c, c) + + # Collective that is independent and should be prefetched during relu + ar = _functional_collectives.all_reduce(b, "sum", "0") + + # Use both results + return d * ar + + patches = { + **get_patches(), + "aten_distributed_optimizations.custom_runtime_estimation": custom_estimator_with_relu, + } + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + inputs_a = ( + torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank + ) + inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2 + + with torch._inductor.config.patch(patches): + out, aten_graph_str = run_and_get_aten_graph( + torch.compile(func), inputs_a, inputs_b + ) + + # Verify that all_reduce is prefetched to run concurrently with relu + # The collective should start before relu completes to enable perfect overlap + FileCheck().check("all_reduce").check("relu").check("wait_tensor").run( + aten_graph_str + ) + + correct = func(inputs_a, inputs_b) + self.assertTrue(same(out, correct)) + self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0) + def get_bucket_patches(compute_multiplier=1.0): estimate_aten_runtime_part = functools.partial( diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index b7617038f4e6a..4e70be861938e 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -279,6 +279,7 @@ def __init__( self.wait_to_start: dict[fx.Node, fx.Node] = {} self._identify_collectives() + self.wasted_compute = 0.0 self.compute_index_domination = self._calculate_compute_node_domination_index() self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] @@ -530,16 +531,14 @@ def run(self) -> torch.fx.GraphModule: if node in self.scheduled: continue - if is_compute_node(node): - self._handle_compute(node) + if node.op == "placeholder": + self._schedule(node) elif node in self.collective_info: self._handle_collective_start(node) elif _schedulable_wait_node(node): self._handle_wait(node) - elif node.op == "placeholder": - self._schedule(node) else: - self._handle_other(node) + self._handle_compute_or_other(node) self._reorder_graph() @@ -578,9 +577,58 @@ def _add_effect_tokens_for_overlap(self) -> None: if additional_deps: preserve_node_ordering(self.graph, additional_deps) - def _handle_other(self, node: fx.Node) -> None: + def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None: + """Get runtime estimation for a node in ms. Returns None if no estimation is available.""" + + # TODO: non custom estimation of aten nodes, potentially requires notion of fusion group + if is_compute_node(node): + return benchmark_node(node, self.custom_runtime_estimation) + + if self.custom_runtime_estimation is None: + return None + + return self.custom_runtime_estimation(node, None) + + def _reduce_exposed_time_of_in_flight_collectives( + self, node: fx.Node, available_compute: float + ) -> float: + """Reduce exposed time of in-flight collectives using available compute time and return available time""" + + # TODO: separate overlap time per process group + for info in self.in_flight.values(): + if info.exposed_time_ms == 0: + continue + overlap_amount = min(info.exposed_time_ms, available_compute) + info.exposed_time_ms -= overlap_amount + available_compute -= overlap_amount + info.hiding_nodes.add(node) + if available_compute == 0: + break + return available_compute + + def _handle_compute_or_other(self, node: fx.Node) -> None: + """Handle scheduling compute or other nodes and attempt to overlap with collectives.""" + runtime_estimate = self.get_non_collective_runtime_estimate(node) + + # TODO: we could consider skipping overlapping for overlapable, unary chains to collectives. + # using these nodes for overlap prevents bucketing. potentially if chain time < latency + if runtime_estimate is None: + assert not is_compute_node(node), "should have estimate for compute nodes" + self._schedule(node) + return + + available_compute = runtime_estimate * self.compute_overlap_multipler + initial_compute = available_compute # Track initial compute time for wasted compute/path calculations + + available_compute = self._reduce_exposed_time_of_in_flight_collectives( + node, available_compute + ) + self._schedule_collectives_for_overlap(node, available_compute, initial_compute) self._schedule(node) + if is_compute_node(node): + self.current_compute_index += 1 + def _schedule(self, node: fx.Node) -> None: """Schedule a node.""" assert node not in self.scheduled @@ -699,37 +747,21 @@ def _handle_wait(self, node: fx.Node) -> None: del self.in_flight[coll_start] self._schedule(node) - def _handle_compute(self, node: fx.Node) -> None: - """Handle scheduling compute and finding overlaps.""" - - compute_time = benchmark_node(node, self.custom_runtime_estimation) - available_compute = compute_time * self.compute_overlap_multipler - - # TODO: separate overlap time per process group - # First reduce exposed time of in-flight collectives - for info in self.in_flight.values(): - if info.exposed_time_ms == 0: - continue - overlap_amount = min(info.exposed_time_ms, available_compute) - info.exposed_time_ms -= overlap_amount - available_compute -= overlap_amount - info.hiding_nodes.add(node) - if available_compute == 0: - break - - # Then, look for unscheduled collectives we can overlap - if available_compute: - self._schedule_collectives_for_overlap(node, available_compute) - - self._schedule(node) - self.current_compute_index += 1 - def _schedule_collectives_for_overlap( - self, compute_node: fx.Node, available_compute_time: float + self, compute_node: fx.Node, available_compute_time: float, initial_time: float ) -> None: """Opportunistically schedule collectives that can be hidden by compute.""" + if available_compute_time == 0: + return + compute_ancestors = self.node_ancestors[compute_node] + # Track how much time we've already used for hiding in-flight collectives + # This allows us to add back some time from pre-fetched paths + reduced_time = ( + initial_time - available_compute_time + ) # How much of initial time we've already used + # Filter collectives by distance and compute index domination possible_collectives = [] for collective in self.unscheduled_collectives: @@ -743,9 +775,12 @@ def _schedule_collectives_for_overlap( # pre-fetched memory before memory peak, and adjust allowed collective mem. if not self.off_compute_path(collective): if ( - self.compute_index_domination[collective] - - self.current_compute_index - ) > self.max_compute_pre_fetch: + abs( + self.compute_index_domination[collective] + - self.current_compute_index + ) + > self.max_compute_pre_fetch + ): continue possible_collectives.append(collective) @@ -799,6 +834,14 @@ def _schedule_collectives_for_overlap( self.current_compute_index, ) + # Track compute runtime of nodes we must schedule to reach collective and + # add back available overlap time corresponding to prior in-flight collectives + path_estimates = [self.get_non_collective_runtime_estimate(p) for p in path] + path_time = sum(p for p in path_estimates if p is not None) + additional_time = min(path_time, reduced_time) + reduced_time -= additional_time + available_compute_time += additional_time + # Schedule path to this collective self._schedule_path_to_collective(path, compute_node) self._handle_collective_start(collective) @@ -810,6 +853,8 @@ def _schedule_collectives_for_overlap( info.hiding_nodes.add(compute_node) available_compute_time -= overlap_amount + self.wasted_compute += available_compute_time + def _find_schedulable_path( self, target: fx.Node, curr_compute_node: fx.Node | None ) -> OrderedSet[fx.Node] | None: @@ -911,13 +956,21 @@ def _reorder_graph(self) -> None: if c.exposed_time_ms == c.estimated_time_ms ] - potentially_hidden_collectives = self.compute_potential_hidden_collectives( - limit_coll_per_compute=True - ) + potentially_hidden_collectives = self.compute_potential_hidden_collectives() bad_exposed = [ c for c in exposed if c.start_node in potentially_hidden_collectives ] + # Compute total exposed and potential exposed time + total_exposed = sum(c.exposed_time_ms for c in self.collective_info.values()) + hideable_exposed_ms = sum( + self.collective_info[c].exposed_time_ms + for c in potentially_hidden_collectives + ) + total_potential_exposed = sum( + c.estimated_time_ms for c in self.collective_info.values() + ) + counters["inductor"]["overlap_scheduling_exposed"] += len(exposed) counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed) counters["inductor"]["overlap_scheduling_potentially_hidden"] += len( @@ -928,12 +981,18 @@ def _reorder_graph(self) -> None: log.info( "Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, " - "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes", + "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes, " + "total_exposed_ms=%.2f, hideable_exposed_ms=%.2f, total_potential_exposed_ms=%.2f, " + "wasted_compute_ms=%.2f", len(exposed), len(bad_exposed), len(potentially_hidden_collectives), self.original_peak_memory, self.memory_tracker.peak_memory, + total_exposed, + hideable_exposed_ms, + total_potential_exposed, + self.wasted_compute, ) self.reorder_graph() @@ -955,24 +1014,18 @@ def _bucket_collectives(self) -> None: bucketer.bucket_collectives() def compute_potential_hidden_nodes( - self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False + self, nodes_to_check: Iterable[fx.Node] ) -> dict[fx.Node, fx.Node]: """ Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node """ - used_compute_nodes: OrderedSet[fx.Node] = OrderedSet() - def could_be_hidden(start: fx.Node) -> fx.Node | None: for compute_node in self.compute_nodes: - if limit_coll_per_compute and compute_node in used_compute_nodes: - continue if ( start not in self.node_ancestors[compute_node] and compute_node not in self.node_ancestors[start] ): - if limit_coll_per_compute: - used_compute_nodes.add(compute_node) return compute_node return None @@ -988,20 +1041,14 @@ def could_be_hidden(start: fx.Node) -> fx.Node | None: return potentially_hidden - def compute_potential_hidden_collectives( - self, limit_coll_per_compute: bool = False - ) -> dict[fx.Node, fx.Node]: + def compute_potential_hidden_collectives(self) -> dict[fx.Node, fx.Node]: """Compute which collective operations could be hidden by compute.""" - return self.compute_potential_hidden_nodes( - self.collective_info.keys(), limit_coll_per_compute - ) + return self.compute_potential_hidden_nodes(self.collective_info.keys()) - def compute_potential_hidden_waits( - self, limit_coll_per_compute: bool = False - ) -> dict[fx.Node, fx.Node]: + def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]: """Compute which wait operations could be hidden by compte.""" wait_nodes = [info.wait_node for info in self.collective_info.values()] - return self.compute_potential_hidden_nodes(wait_nodes, limit_coll_per_compute) + return self.compute_potential_hidden_nodes(wait_nodes) def schedule_overlap_bucketing( From 7641553e06bc05960f3e4903c8c174888ddc1c4d Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 20 Nov 2025 14:01:40 -0800 Subject: [PATCH 134/230] better use of mem tracking (#168121) Previously we had guards on memory by limiting pre fetching. Now, simulate the memory for the entire model, record the original peak memory for each compute node, record cumulative pre fetched memory, and block pre fetching if it would increase memory beyond allowed amount. Allow users to pass in both absolute mem increase and ratio increase and take max of these. Making some of the defaults more aggressive accordingly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168121 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #167864 --- torch/_inductor/config.py | 5 + .../fx_passes/overlap_manual_scheduling.py | 2 + .../fx_passes/overlap_preserving_bucketer.py | 18 +- .../_inductor/fx_passes/overlap_scheduling.py | 255 +++++++++++++----- torch/_inductor/fx_passes/post_grad.py | 2 + 5 files changed, 205 insertions(+), 77 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e4660f90e1eb4..af466dc61031a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -950,6 +950,11 @@ class aten_distributed_optimizations: # "benchmark": Use CUDA events with power-of-2 rounding and interpolation collective_estimator: Literal["analytical", "benchmark"] = "analytical" + # Maximum memory increase above baseline for prefetch operations + # Uses minimum of absolute cap and ratio of baseline + max_memory_increase_gb: Optional[float] = None # Absolute cap in GB + max_memory_increase_ratio: Optional[float] = None # Ratio of baseline peak memory + def parallel_compile_enabled_internally() -> bool: """ diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index f5c131a7eab96..c8af70dc598f4 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -172,6 +172,8 @@ def __init__( max_coll_distance=0, custom_runtime_estimation=None, collective_estimator="analytical", + max_memory_increase_gb=None, + max_memory_increase_ratio=None, ) self.module_bucket_plans = module_bucket_plans self.nodes_in_subgraph: list[list[fx.Node]] = [] diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index b6cbf32bfba8e..e306641ac1d8d 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -326,7 +326,7 @@ def _find_buckets( # Sort collectives by node index for efficient distance checking sorted_collectives = sorted(collective_group, key=lambda n: self.node_idx[n]) - for start_node in sorted_collectives: + for i, start_node in enumerate(sorted_collectives): if start_node in processed: continue @@ -336,25 +336,17 @@ def _find_buckets( total_bytes=self.collective_info[start_node].size_bytes, ) processed.add(start_node) - start_node_idx = self.node_idx[start_node] # Check candidates in sorted order, break when beyond max distance - for candidate in sorted_collectives: + for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: if candidate in processed: continue - candidate_idx = self.node_idx[candidate] - # Check if candidate is within max distance from the bucket start - distance = abs(candidate_idx - start_node_idx) - if distance > self.max_coll_distance: - # Since sorted, all remaining candidates will be too far - if candidate_idx > start_node_idx: - break - continue - candidate_bytes = self.collective_info[candidate].size_bytes + # proxy on memory use, if we see a too large bucket, + # dont look for another, later bucket if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: - continue + break if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 4e70be861938e..436a3ab0db81b 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -13,11 +13,7 @@ from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor -from torch._inductor.fx_passes.memory_estimator import ( - _is_releasable, - build_memory_profile, - MemoryTracker, -) +from torch._inductor.fx_passes.memory_estimator import MemoryTracker from torch.fx.operator_schemas import normalize_function from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import _disable_current_modes @@ -30,6 +26,27 @@ from ..pattern_matcher import stable_topological_sort +@dataclass +class WhyNoOverlap: + """Track reasons why a collective cannot overlap with compute.""" + + compute_name: str + collective_name: str + + def __init__(self, compute_node: fx.Node, collective_node: fx.Node) -> None: + self.compute_name = compute_node.name + self.collective_name = collective_node.name + + def __call__(self, reason: str, *args: Any) -> None: + if log.isEnabledFor(logging.DEBUG): + log.debug( + "cannot overlap %s with %s: " + reason, # noqa: G003 + self.collective_name, + self.compute_name, + *args, + ) + + def get_group_name(n: fx.Node) -> str: """Extract the group name from a collective operation node.""" opt_args_kwargs = normalize_function( @@ -247,6 +264,8 @@ def __init__( max_coll_distance: int, custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None, collective_estimator: Literal["analytical", "benchmark"], + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, ): self.gm = gm self.graph = gm.graph @@ -271,10 +290,40 @@ def __init__( self.collective_info: dict[fx.Node, CollectiveInfo] = {} self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet() - # Memory tracking using abstracted MemoryTracker - self.original_peak_memory = max( - build_memory_profile(self.graph, _is_releasable) + # Identify compute nodes early (needed for baseline memory computation) + self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] + self.current_compute_index = 0 + + # Compute baseline memory profile from original schedule + self.original_mem_before_compute_index: list[int] = [] + self.original_peak_memory = self._compute_baseline_memory() + + # Maximum allowed peak memory = baseline + max(absolute, ratio * baseline) + # When both limits are specified, use the more permissive one + memory_increase_bytes = None + if max_memory_increase_gb is not None: + memory_increase_bytes = gb_to_bytes(max_memory_increase_gb) + if max_memory_increase_ratio is not None: + ratio_increase = int(self.original_peak_memory * max_memory_increase_ratio) + memory_increase_bytes = ( + max(memory_increase_bytes, ratio_increase) + if memory_increase_bytes is not None + else ratio_increase + ) + if memory_increase_bytes is None: + memory_increase_bytes = 0 + + self.allowed_peak_memory_bytes = ( + self.original_peak_memory + memory_increase_bytes ) + + # Track cumulative prefetch memory at each compute index + # When we prefetch a collective at compute index i that will be used at index j, + # it adds memory from i to j, so we need to track this cumulative effect + self.cumulative_prefetch_mem_by_compute_index: list[int] = [ + 0 for _ in range(len(self.compute_nodes)) + ] + self.memory_tracker = MemoryTracker(self.graph) self.wait_to_start: dict[fx.Node, fx.Node] = {} @@ -282,8 +331,6 @@ def __init__( self.wasted_compute = 0.0 self.compute_index_domination = self._calculate_compute_node_domination_index() - self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] - self.current_compute_index = 0 # Scheduling state self.potentially_hidden_collectives = ( @@ -312,6 +359,88 @@ def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: return ancestors + def _compute_baseline_memory(self) -> int: + """ + Simulate the original schedule to compute baseline memory profile. + Returns the peak memory observed during simulation. + """ + baseline_tracker = MemoryTracker(self.graph) + + last_compute_max_memory = 0 + peak_memory = 0 + + for node in self.nodes: + baseline_tracker.schedule_node(node) + current_mem = baseline_tracker.current_memory_bytes + + # Record the max memory between this and previous compute node + last_compute_max_memory = max(last_compute_max_memory, current_mem) + + if is_compute_node(node): + self.original_mem_before_compute_index.append(last_compute_max_memory) + last_compute_max_memory = current_mem + + peak_memory = max(peak_memory, current_mem) + + return peak_memory + + def _prefetch_would_exceed_memory_budget(self, start_node: fx.Node) -> bool: + """ + Check if prefetching this collective would exceed memory budget at ANY compute node + between now and when it's used. + """ + info = self.collective_info[start_node] + size = info.size_bytes + + domination_index = self.compute_index_domination[start_node] + + # If off-path, assume it doesn't increase memory + if domination_index == sys.maxsize: + return False + + # check current mem + if ( + self.memory_tracker.current_memory_bytes + size + > self.allowed_peak_memory_bytes + ): + return True + + start_index = self.current_compute_index + + # then, check future mem + for compute_idx in range(start_index, domination_index): + cumulative_prefetch = self.cumulative_prefetch_mem_by_compute_index[ + compute_idx + ] + + # Check 1: Would cumulative prefetch exceed in-flight limit? + if (cumulative_prefetch + size) > self.max_in_flight_bytes: + return True + + # Check 2: Would total memory (baseline + cumulative prefetch) exceed budget? + baseline_mem = self.original_mem_before_compute_index[compute_idx] + projected = baseline_mem + cumulative_prefetch + size + + if projected > self.allowed_peak_memory_bytes: + return True + + return False + + def _update_cumulative_prefetch_memory( + self, collective: fx.Node, info: CollectiveInfo + ) -> None: + """ + Update cumulative prefetch memory for all compute indices this collective will be live. + """ + domination_index = self.compute_index_domination[collective] + if domination_index == sys.maxsize: + return + + for compute_idx in range(self.current_compute_index, domination_index): + self.cumulative_prefetch_mem_by_compute_index[compute_idx] += ( + info.size_bytes + ) + def off_compute_path(self, n: fx.Node) -> bool: """Check if a node is off the compute path (doesn't block any compute).""" return self.compute_index_domination[n] == sys.maxsize @@ -699,9 +828,8 @@ def _should_force_wait_for_memory(self) -> bool: """Check if we need to force a wait due to memory pressure""" if not self.in_flight: return False - return self.in_flight_bytes >= self.max_in_flight_bytes or ( - self.memory_tracker.current_memory_bytes - self.original_peak_memory - ) > gb_to_bytes(1.0) + + return self.in_flight_bytes >= self.max_in_flight_bytes def _force_oldest_wait(self) -> None: """Schedule the oldest in flight wait""" @@ -754,61 +882,47 @@ def _schedule_collectives_for_overlap( if available_compute_time == 0: return + reduced_time = initial_time - available_compute_time compute_ancestors = self.node_ancestors[compute_node] - # Track how much time we've already used for hiding in-flight collectives - # This allows us to add back some time from pre-fetched paths - reduced_time = ( - initial_time - available_compute_time - ) # How much of initial time we've already used - - # Filter collectives by distance and compute index domination - possible_collectives = [] - for collective in self.unscheduled_collectives: - distance = abs(self.node_idx[compute_node] - self.node_idx[collective]) - if distance > self.max_node_distance: + # Compile-time filtering: limit candidates by distance to bound O(compute * collectives) cost + candidates = [] + for i, collective in enumerate(self.unscheduled_collectives): + if i > self.max_node_distance: break - # Skip collectives that are too far ahead in compute index, but allow scheduling - # collectives which are off compute path (which typically release memory) - # TODO: we could potentially be more strict about limiting the amount of - # pre-fetched memory before memory peak, and adjust allowed collective mem. - if not self.off_compute_path(collective): - if ( - abs( - self.compute_index_domination[collective] - - self.current_compute_index - ) - > self.max_compute_pre_fetch - ): - continue + if ( + not self.off_compute_path(collective) + and self.compute_index_domination[collective] + - self.current_compute_index + > self.max_compute_pre_fetch + ): + continue - possible_collectives.append(collective) + candidates.append(collective) - possible_collectives = sorted( - possible_collectives, + candidates = sorted( + candidates, key=lambda n: (self.compute_index_domination[n], self.node_idx[n]), ) - log.debug( - "Scheduling collectives for overlap: compute_node=%s, available_time=%.2f ms, candidates=%d, current_memory=%d bytes", - compute_node.name, - available_compute_time, - len(possible_collectives), - self.memory_tracker.current_memory_bytes, - ) - - for collective in possible_collectives: + for collective in candidates: if available_compute_time == 0: break + why = WhyNoOverlap(compute_node, collective) info = self.collective_info[collective] - # Skip if compute depends on collective or vice versa if ( collective in compute_ancestors or compute_node in self.node_ancestors[collective] ): + why("dependency conflict") + continue + + # Check if prefetching would exceed memory budget + if self._prefetch_would_exceed_memory_budget(collective): + why("prefetch would exceed memory budget") continue while ( @@ -819,10 +933,11 @@ def _schedule_collectives_for_overlap( self._force_oldest_wait() if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes: + why("in-flight memory limit") continue # Check if we can reach this collective without scheduling compute, other collectives, or waits - path = self._find_schedulable_path(collective, compute_node) + path = self._find_schedulable_path(collective, compute_node, why) if path is None: continue @@ -842,12 +957,11 @@ def _schedule_collectives_for_overlap( reduced_time -= additional_time available_compute_time += additional_time - # Schedule path to this collective self._schedule_path_to_collective(path, compute_node) self._handle_collective_start(collective) + self._update_cumulative_prefetch_memory(collective, info) - # Update the exposed time for this newly scheduled collective - # after scheduling, which will account for latency reduction of bucketing + # Update exposed time overlap_amount = min(available_compute_time, info.exposed_time_ms) info.exposed_time_ms -= overlap_amount info.hiding_nodes.add(compute_node) @@ -856,19 +970,20 @@ def _schedule_collectives_for_overlap( self.wasted_compute += available_compute_time def _find_schedulable_path( - self, target: fx.Node, curr_compute_node: fx.Node | None + self, target: fx.Node, curr_compute_node: fx.Node | None, why: WhyNoOverlap ) -> OrderedSet[fx.Node] | None: """Find path to target by collecting unscheduled dependencies.""" - - # TODO - following path faster than doing set difference here + # Get unscheduled ancestors unscheduled_ancestors = self.node_ancestors[target] - self.scheduled # only schedule non distributed, non compute nodes for node in unscheduled_ancestors: if is_compute_node(node): + why("path blocked by compute node %s", node.name) return None if node in self.unscheduled_collectives: + why("path blocked by unscheduled collective %s", node.name) return None # if we schedule a wait tensor whose start collective is hidden by the @@ -880,8 +995,13 @@ def _find_schedulable_path( if _schedulable_wait_node(node): info = self.collective_info[self.wait_to_start[node]] if info.hiding_nodes and curr_compute_node not in info.hiding_nodes: + why( + "path blocked by wait node %s with different hiding compute", + node.name, + ) continue elif node not in self.potentially_hidden_waits: + why("path blocked by wait node %s that could be hidden", node.name) continue return None @@ -1007,7 +1127,7 @@ def _bucket_collectives(self) -> None: collective_info=self.collective_info, node_ancestors=self.node_ancestors, scheduled=self.scheduled, - max_bucket_memory_gb=1.0, # Could make this configurable + max_bucket_memory_gb=2.0, # Could make this configurable max_coll_distance=self.max_node_distance, insert_overlap_deps=self.insert_overlap_deps, ) @@ -1053,15 +1173,17 @@ def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]: def schedule_overlap_bucketing( gm: torch.fx.GraphModule, - max_in_flight_gb: float = 2.0, - max_compute_pre_fetch: int = 5, + max_in_flight_gb: float = 5, + max_compute_pre_fetch: int = 200, collective_bucketing: bool = False, insert_overlap_deps: bool = False, compute_overlap_multipler: float = 1.0, - max_coll_distance: int = 1000, + max_coll_distance: int = 200, custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None = None, collective_estimator: Literal["analytical", "benchmark"] = "analytical", + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. @@ -1069,19 +1191,22 @@ def schedule_overlap_bucketing( gm: Input graph module to optimize. max_in_flight_gb: Maximum GB of concurrent collective data. Too much in flight memory can cause memory fragmentation within the CUDA Caching Allocator. - max_compute_pre_fetch: Maximum compute node prefetch distance. + max_compute_pre_fetch: Maximum mm nodes to pre fetch. Note: should already be limited by max_in_flight_gb and + max_memory_increase_gb collective_bucketing: Enable overlap-preserving collective bucketing. insert_overlap_deps: Insert overlap dependencies using control deps operator. This should only be used if compiling with inductor, or for subsequent passes before removing the ops prior to execution. compute_overlap_multipler: Scale factor for compute time used to hide collectives. This can be used to address over or under aggressive overlapping. - max_coll_distance: Maximum node distance for overlap or bucketing. Mostly intended to reduce compile time. + max_coll_distance: Maximum pre fetch or bucketing candidates. Mainly intended for compile time custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node. If None, uses default estimations. This is currently limited to collectives and compute nodes. collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas, "benchmark" uses CUDA events with power-of-2 rounding and interpolation. + max_memory_increase_gb: Maximum GB increase above baseline memory (absolute cap). If None, no absolute limit. + max_memory_increase_ratio: Maximum increase as ratio of baseline peak memory. If None, no ratio limit. + Uses minimum of absolute and ratio limits when both are specified. """ - return OverlapScheduler( gm, compute_overlap_multipler=compute_overlap_multipler, @@ -1092,4 +1217,6 @@ def schedule_overlap_bucketing( collective_bucketing=collective_bucketing, insert_overlap_deps=insert_overlap_deps, collective_estimator=collective_estimator, + max_memory_increase_gb=max_memory_increase_gb, + max_memory_increase_ratio=max_memory_increase_ratio, ).run() diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index e403e82ff6c3b..a21e78821e52b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -303,6 +303,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): "custom_runtime_estimation", "insert_overlap_deps", "collective_estimator", + "max_memory_increase_gb", + "max_memory_increase_ratio", ) for key in config_keys: if (val := getattr(dist_opts, key)) is not None: From 1328a02d2eb26ee67346048ee242327ea90d6315 Mon Sep 17 00:00:00 2001 From: eellison Date: Thu, 20 Nov 2025 14:01:41 -0800 Subject: [PATCH 135/230] bucketing compile time improve (#168122) Strict compile time improvement. We always maintain that start -> hiding nodes -> wait. Add start, to hiding nodes ancestors, and hiding nodes to wait ancestors, to minimize repeated graph searches by precomputing the dependencies. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168122 Approved by: https://github.com/IvanKobzarev ghstack dependencies: #167864, #168121 --- .../test_overlap_bucketing_unit.py | 48 ++------------ .../fx_passes/overlap_manual_scheduling.py | 1 - .../fx_passes/overlap_preserving_bucketer.py | 65 +++++++++++++++---- .../_inductor/fx_passes/overlap_scheduling.py | 1 - 4 files changed, 59 insertions(+), 56 deletions(-) diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index c0c4c31cc1a81..8dd937a31c240 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -93,28 +93,6 @@ def build_collective_info(graph, hiding_annotations): return collective_info -def compute_ancestors(graph): - """Compute ancestor sets for all nodes in the graph.""" - node_ancestors = {} - - for node in graph.nodes: - ancestors = OrderedSet() - stack = list(node.all_input_nodes) - visited = set() - - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - ancestors.add(current) - stack.extend(current.all_input_nodes) - - node_ancestors[node] = ancestors - - return node_ancestors - - @requires_accelerator_dist_backend() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @instantiate_parametrized_tests @@ -190,9 +168,8 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -203,7 +180,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -278,9 +254,8 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -291,7 +266,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -381,9 +355,8 @@ def func(a, b, c): if final_mm_hidden: hiding_annotations[rs] = mm2 - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing logic to find buckets (without applying them, which would require process groups) @@ -394,7 +367,6 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) @@ -467,7 +439,6 @@ def func(a, b): # Build collective info collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -478,7 +449,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -550,9 +520,8 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing with multidtype mode @@ -563,7 +532,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, bucket_mode="custom_ops_multidtype", ) @@ -635,9 +603,8 @@ def func(a, b): ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3 } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Verify hiding_nodes are correctly set @@ -656,7 +623,6 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -729,9 +695,8 @@ def func(a, b, c): ag3: mm, } - # Build collective info and ancestors + # Build collective info and scheduled collective_info = build_collective_info(traced.graph, hiding_annotations) - node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -742,7 +707,6 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, - node_ancestors, scheduled, ) bucketer.bucket_collectives() diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index c8af70dc598f4..d2c8b588d2011 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -182,7 +182,6 @@ def __init__( self.bucketer = ManualOverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, - node_ancestors=self.node_ancestors, node_users=self.node_users, scheduled=OrderedSet(self.graph.nodes), ) diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index e306641ac1d8d..ed37c0902c325 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -1,3 +1,4 @@ +import itertools import logging from collections import defaultdict from dataclasses import dataclass @@ -130,7 +131,6 @@ def __init__( self, graph: fx.Graph, collective_info: dict[fx.Node, CollectiveInfo], - node_ancestors: dict[fx.Node, OrderedSet[fx.Node]], scheduled: OrderedSet[fx.Node], max_bucket_memory_gb: float = 1.0, max_coll_distance: int = 1000, @@ -139,19 +139,46 @@ def __init__( ): self.graph = graph self.collective_info = collective_info - self.node_ancestors = node_ancestors self.scheduled = scheduled self.max_bucket_memory_gb = max_bucket_memory_gb self.node_idx = {n: i for i, n in enumerate(scheduled)} - self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) self.max_coll_distance = max_coll_distance self.insert_overlap_deps = insert_overlap_deps self.bucket_mode = bucket_mode self.node_to_event: dict[fx.Node, PGEvent] = {} - self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() + # Compute ancestors including original graph edges and hiding interval dependencies + self.node_ancestors = self._compute_node_ancestors() + self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) + + # Build timelines and add constraints to aug_graph + self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() self._add_hiding_interval_constraints() + def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """ + Compute ancestor sets for all nodes including: + 1. Original graph edges + 2. Hiding interval deps: collective_start -> hiding_node -> wait + """ + augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for start, info in self.collective_info.items(): + if info.is_exposed: + continue + for hiding_node in info.hiding_nodes: + augmented_inputs[hiding_node].add(start) + augmented_inputs[info.wait_node].add(hiding_node) + + node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + for node in self.scheduled: + for input_node in itertools.chain( + augmented_inputs[node], node.all_input_nodes + ): + node_ancestors[node].add(input_node) + node_ancestors[node] |= node_ancestors[input_node] + + return node_ancestors + def build_timelines(self) -> dict[str, Optional[PGEvent]]: "Construct each process groups ordered series of event" all_pgs: OrderedSet[str] = OrderedSet() @@ -337,21 +364,30 @@ def _find_buckets( ) processed.add(start_node) + # Greedy optimization: stop after consecutive failures + consecutive_failures = 0 + max_consecutive_failures = 20 + # Check candidates in sorted order, break when beyond max distance for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: - if candidate in processed: - continue - candidate_bytes = self.collective_info[candidate].size_bytes # proxy on memory use, if we see a too large bucket, # dont look for another, later bucket if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: break + if candidate in processed: + continue + if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) bucket_info.total_bytes += candidate_bytes processed.add(candidate) + consecutive_failures = 0 # Reset on success + else: + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + break if len(bucket_info.collectives) > 1: buckets.append(bucket_info) @@ -656,23 +692,28 @@ def _has_ancestor_conflicts( candidate_wait = candidate_info.wait_node for coll in bucket_info.collectives: - # Check if collectives are ancestors of each other - if self._ancestor_dep(coll, candidate): + if ( + coll in self.node_ancestors[candidate] + or candidate in self.node_ancestors[coll] + ): return True # Check if waits are ancestors of each other coll_wait = self.collective_info[coll].wait_node - if self._ancestor_dep(candidate_wait, coll_wait): + if ( + coll_wait in self.node_ancestors[candidate_wait] + or candidate_wait in self.node_ancestors[coll_wait] + ): return True # Check if existing hiding node conflicts with candidate wait for old_hiding_node in self.collective_info[coll].hiding_nodes: - if self._ancestor_dep(old_hiding_node, candidate_wait): + if candidate_wait in self.node_ancestors[old_hiding_node]: return True # Check if candidate hiding node conflicts with existing wait for new_hiding_node in candidate_info.hiding_nodes: - if self._ancestor_dep(new_hiding_node, coll_wait): + if coll_wait in self.node_ancestors[new_hiding_node]: return True return False diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 436a3ab0db81b..14555c84b43ce 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -1125,7 +1125,6 @@ def _bucket_collectives(self) -> None: bucketer = OverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, - node_ancestors=self.node_ancestors, scheduled=self.scheduled, max_bucket_memory_gb=2.0, # Could make this configurable max_coll_distance=self.max_node_distance, From 5cb57184c22fdb57cc7b8c4db9399c586323588c Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 20 Nov 2025 22:54:25 +0000 Subject: [PATCH 136/230] Add public grouped_mm (#168298) Fixes #166651 One thing I found was that, nvm I rediscovered: https://github.com/NVIDIA/cutlass/issues/2674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168298 Approved by: https://github.com/ngimel, https://github.com/slayton58 --- docs/source/nn.functional.rst | 1 + test/distributed/tensor/test_matrix_ops.py | 10 ++--- test/inductor/test_cutedsl_grouped_mm.py | 5 ++- test/inductor/test_torchinductor.py | 2 +- test/test_matmul_cuda.py | 13 +++--- test/test_scaled_matmul_cuda.py | 13 ++++-- torch/nn/functional.py | 46 ++++++++++++++++++++++ torch/nn/functional.pyi.in | 11 ++++++ torch/overrides.py | 1 + 9 files changed, 85 insertions(+), 17 deletions(-) diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst index 015d1d9ffda1a..ba39d80700f28 100644 --- a/docs/source/nn.functional.rst +++ b/docs/source/nn.functional.rst @@ -227,5 +227,6 @@ Low-Precision functions ScalingType SwizzleType + grouped_mm scaled_mm scaled_grouped_mm diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index 6e3dd23c44210..65c8cc6f36af4 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -549,7 +549,7 @@ def test_tensordot_shampoo(self): ], ) def test_grouped_mm(self, kwargs): - # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) + # TODO: torch.nn.functional.grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) # More tests need to be added. device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() @@ -574,8 +574,8 @@ def test_grouped_mm(self, kwargs): ) offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32) - h = torch._grouped_mm(inp, w1, offs=offs) - out = torch._grouped_mm(h, w2, offs=offs) + h = F.grouped_mm(inp, w1, offs=offs) + out = F.grouped_mm(h, w2, offs=offs) dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"]) # colwise sharded @@ -585,8 +585,8 @@ def test_grouped_mm(self, kwargs): dist_offs = distribute_tensor(offs, device_mesh, [Replicate()]) with comm_mode: - dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs) - dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs) + dist_h = F.grouped_mm(dist_inp, dist_w1, offs=dist_offs) + dist_out = F.grouped_mm(dist_h, dist_w2, offs=dist_offs) self.assertEqual( comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"] ) diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py index c26def3a54099..bd7221adc4065 100644 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -9,6 +9,7 @@ from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch from torch._inductor.test_case import run_tests, TestCase as InductorTestCase from torch._inductor.utils import ensure_cute_available +from torch.nn import functional as F from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -59,7 +60,7 @@ def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) + return F.grouped_mm(A_packed, B_batched, offs=offs) # Eager execution c_eager = grouped_gemm_fn(A, B, offsets) @@ -126,7 +127,7 @@ def test_grouped_gemm_assorted_layouts( assert B.stride(0) == 0 def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) + return F.grouped_mm(A_packed, B_batched, offs=offs) # --- eager --- c_eager = grouped_gemm_fn(A, B, offsets) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 7cfb815a93d7d..a120c5b394f01 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15165,7 +15165,7 @@ def forward( def test_grouped_mm(self): @torch.compile(fullgraph=True) def f(a, b, offs, out_dtype): - return torch._grouped_mm( + return F.grouped_mm( a, b.transpose(-2, -1), offs=offs, out_dtype=out_dtype ) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 7a6585f3b63a8..ec1fc41547f83 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -8,6 +8,7 @@ from collections.abc import Callable import torch +import torch.nn.functional as F from torch.quantization._quantized_conversions import ( pack_int4_to_int8, @@ -404,7 +405,7 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): b.requires_grad_(True) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.t(), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) out.backward(gO) @@ -456,7 +457,7 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): if check_zero_size: offs[0] = offs[1] - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) if not check_zero_size: @@ -501,7 +502,7 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), out_dtype=dtype) gO = torch.rand_like(out) out.backward(gO) @@ -541,7 +542,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): if check_zero_size: offs[0] = offs[1] - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) if not check_zero_size: @@ -559,7 +560,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - # TODO(future PR): enable compile for torch._grouped_mm fallback path + # TODO(future PR): enable compile for torch.nn.functional.grouped_mm fallback path @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90") @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) @parametrize("a_row_major", [False, True]) @@ -572,7 +573,7 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) align = 16 // dtype_AB.itemsize - f_ref = torch._grouped_mm + f_ref = F.grouped_mm options = {} if max_autotune: diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 94d6ece0f6369..25c4efe35a1ab 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -11,7 +11,14 @@ import torch -from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType +from torch.nn.functional import ( + grouped_mm, + pad, + scaled_mm, + scaled_grouped_mm, + ScalingType, + SwizzleType, +) from torch.testing._internal.common_cuda import ( IS_SM90, _get_torch_cuda_version, @@ -785,7 +792,7 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): ) # bf16 reference output - y_bf16 = torch._grouped_mm( + y_bf16 = grouped_mm( # Note: Reference result should be on reconstructed, not original values. # as-in float(fp4(t)) not t itself. xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16 @@ -931,7 +938,7 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): # Compute reference bf16 grouped gemm. # Note: Reference result should be on reconstructed, not original values. # as-in float(fp4(t)) not t itself. - y_bf16 = torch._grouped_mm( + y_bf16 = grouped_mm( xh, wh.transpose(-2, -1), offs=input_group_end_offsets, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 07fec131d618a..0ee7b0f964fee 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -6640,6 +6640,52 @@ def multi_head_attention_forward( return attn_output, None +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + *, + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + r""" + grouped_mm(mat_a, mat_b, *, offs=None, bias=None, out_dtype=None) + + Computes a grouped matrix multiply that shares weight shapes across experts but + allows jagged token counts per expert, which is common in Mixture-of-Experts + (MoE) layers. Both ``mat_a`` and ``mat_b`` must be 2D or 3D tensors that already + satisfy the physical layout restrictions of grouped GEMM kernels (e.g., row-major + ``mat_a`` and column-major ``mat_b`` for FP8 inputs). Inputs are currently + expected to be ``torch.bfloat16`` values on CUDA devices with :math:`SM \ge 80`. + + Args: + mat_a: Left operand. When 2D, its leading dimension is sliced into groups + according to ``offs``. When 3D, its first dimension enumerates the groups + directly and ``offs`` must be ``None``. + mat_b: Right operand. When both operands are 2D (e.g., MoE weight-gradient + updates), the trailing dimension of ``mat_a`` and the leading dimension of + ``mat_b`` are partitioned according to the same ``offs`` tensor. For the + common forward pass (``out = input @ weight.T``) ``mat_b`` is 3D with + shape ``(num_groups, N, K)``. + offs: Optional 1D tensor of monotonically increasing ``int32`` offsets that + delimit the jagged dimension of any 2D operand. ``offs[i]`` marks the end + of group ``i`` and ``offs[-1]`` must be strictly less than the total + length of that operand's sliced dimension; elements beyond ``offs[-1]`` + are ignored. + bias: Optional tensor that is added to the grouped outputs. Bias is not + jagged and must be broadcastable to the result shape of each group. + out_dtype: Optional dtype that controls the accumulation/output dtype. + Passing ``torch.float32`` accumulates BF16 inputs in FP32 while keeping + the grouped GEMM API non-differentiable. + + Returns: + A tensor containing the concatenated results of each per-group GEMM with + shape inferred from the operands and ``offs``. + """ + + return torch._grouped_mm(mat_a, mat_b, offs=offs, bias=bias, out_dtype=out_dtype) + + def scaled_mm( mat_a: Tensor, mat_b: Tensor, diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 5a3e24b115df7..58e2d65a81175 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -733,6 +733,17 @@ def scaled_mm( __all__ += ["scaled_mm"] +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + *, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... + +__all__ += ["grouped_mm"] + class SwizzleType(Enum): NO_SWIZZLE = 0 SWIZZLE_32_4_4 = 1 diff --git a/torch/overrides.py b/torch/overrides.py index dea75f69ea49b..22dfb67b825cc 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -252,6 +252,7 @@ def get_ignored_functions() -> set[Callable]: torch.nn.functional.has_torch_function_unary, torch.nn.functional.has_torch_function_variadic, torch.nn.functional.handle_torch_function, + torch.nn.functional.grouped_mm, torch.nn.functional.scaled_grouped_mm, torch.nn.functional.scaled_mm, torch.nn.functional.sigmoid, From 64904c29b0cf5abed48f45d3904b1cffc220b760 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 21 Nov 2025 01:17:00 +0000 Subject: [PATCH 137/230] [7/N] Use Python 3.10 typing (#167790) This PR applies new Union typing syntax to some python files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167790 Approved by: https://github.com/albanD --- torchgen/_autoheuristic/ah_tree.py | 4 ++-- torchgen/context.py | 8 ++++---- torchgen/gen.py | 3 +-- torchgen/gen_aoti_c_shim.py | 15 +++++++-------- torchgen/gen_functionalization_type.py | 4 ++-- torchgen/gen_schema_utils.py | 8 ++++---- torchgen/model.py | 4 ++-- torchgen/static_runtime/gen_static_runtime_ops.py | 4 ++-- 8 files changed, 24 insertions(+), 26 deletions(-) diff --git a/torchgen/_autoheuristic/ah_tree.py b/torchgen/_autoheuristic/ah_tree.py index c2ec2b8d94788..0afc8751e6b82 100644 --- a/torchgen/_autoheuristic/ah_tree.py +++ b/torchgen/_autoheuristic/ah_tree.py @@ -7,8 +7,8 @@ class DecisionTreeNode: def __init__( self, - feature: Optional[str] = None, - threshold: Optional[float] = None, + feature: str | None = None, + threshold: float | None = None, left: Optional["DecisionTreeNode"] = None, right: Optional["DecisionTreeNode"] = None, class_probs: Any = None, diff --git a/torchgen/context.py b/torchgen/context.py index e3725d66b9643..a99d7119c656f 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -2,7 +2,7 @@ import contextlib import functools -from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, TYPE_CHECKING, TypeVar import torchgen.local as local from torchgen.model import ( @@ -26,15 +26,15 @@ NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, - Union[NativeFunction, NativeFunctionsGroup], - Union[NativeFunction, NativeFunctionsViewGroup], + NativeFunction | NativeFunctionsGroup, + NativeFunction | NativeFunctionsViewGroup, ) F2 = TypeVar( "F2", NativeFunction, NativeFunctionsGroup, - Optional[NativeFunction], + NativeFunction | None, bool, str, ) diff --git a/torchgen/gen.py b/torchgen/gen.py index ae0e4b52a0fc8..2bc9ed6996705 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -97,7 +97,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from typing import Optional T = TypeVar("T") @@ -2218,7 +2217,7 @@ def gen_source_files( per_operator_headers: bool, skip_dispatcher_op_registration: bool, update_aoti_c_shim: bool, - aoti_backends: set[Optional[DispatchKey]], + aoti_backends: set[DispatchKey | None], extend_aoti_c_shim: bool, ) -> None: extra_cuda_headers = """\ diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index ead2a2a1cf4cc..e0724f6c3959b 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -31,7 +31,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Optional base_type_to_c_type = { @@ -393,7 +392,7 @@ def gen_static_dispatch_backend_call_signature( def gen_static_dispatch_backend_call( f: NativeFunction, - backend_index: Optional[BackendIndex] = None, + backend_index: BackendIndex | None = None, ) -> str: sig = DispatcherSignature.from_schema(f.func) cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) @@ -421,7 +420,7 @@ def gen_static_dispatch_backend_call( def get_backend_index_for_aoti( func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> BackendIndex | None: @@ -463,7 +462,7 @@ def get_backend_index_for_aoti( def get_header_for_aoti( func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> str | None: @@ -490,7 +489,7 @@ def gen_c_shim( func: NativeFunction, version_info: dict[str, list[str]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -528,7 +527,7 @@ def gen_c_shim( class ShimGenerator: inductor_fallback_ops: dict[str, dict[str, list[str]]] func_group_mapping: dict[OperatorName, NativeFunctionsGroup] - dispatch_key: Optional[DispatchKey] + dispatch_key: DispatchKey | None backend_indices: dict[DispatchKey, BackendIndex] header: bool # True to generate .h and False to generate .cpp extend_aoti_c_shim: bool @@ -555,7 +554,7 @@ def gen_aoti_c_shim( native_functions: Sequence[NativeFunction], inductor_fallback_ops: dict[str, dict[str, list[str]]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -646,7 +645,7 @@ def gen_aoti_c_shim( def gen_aoti_c_shim_files( aoti_fm: FileManager, - aoti_backends: set[Optional[DispatchKey]], + aoti_backends: set[DispatchKey | None], native_functions: Sequence[NativeFunction], backend_indices: dict[DispatchKey, BackendIndex], structured_native_functions: Sequence[NativeFunctionsGroup], diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 1cb681ba19d34..0ef91332df9ff 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torchgen.api import cpp, dispatcher, functionalization from torchgen.api.translate import translate @@ -928,7 +928,7 @@ def new(self, out_index: str = "0") -> str: def map( g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]] ) -> list[str]: - def maybe_run(f: Optional[NativeFunction]) -> list[str]: + def maybe_run(f: NativeFunction | None) -> list[str]: if f is None: return [] with native_function_manager(f): diff --git a/torchgen/gen_schema_utils.py b/torchgen/gen_schema_utils.py index b81c91527baa1..1238a5a5a3933 100644 --- a/torchgen/gen_schema_utils.py +++ b/torchgen/gen_schema_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any from torchgen.model import ( Annotation, @@ -29,7 +29,7 @@ class TypeGen: } @staticmethod - def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]: + def from_example(obj: Any) -> BaseType | ListType | CustomClassType: import torch if isinstance(obj, torch.fx.GraphModule): @@ -61,7 +61,7 @@ def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]: class ReturnGen: @staticmethod def from_example( - name: Optional[str], obj: Any, annotation: Optional[Annotation] + name: str | None, obj: Any, annotation: Annotation | None ) -> Return: return Return(name, TypeGen.from_example(obj), annotation) @@ -69,7 +69,7 @@ def from_example( class ArgumentGen: @staticmethod def from_example( - name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation] + name: str, obj: Any, default: str | None, annotation: Annotation | None ) -> Argument: return Argument( name, TypeGen.from_example(obj), default=default, annotation=annotation diff --git a/torchgen/model.py b/torchgen/model.py index 906b61e2f19cc..7971b893e7585 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -5,7 +5,7 @@ import re from dataclasses import dataclass from enum import auto, Enum -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from typing_extensions import assert_never from torchgen.utils import NamespaceHelper, OrderedSet @@ -2563,7 +2563,7 @@ class BaseOperatorName: # as part of the base operator name, for __str__() to consume. # The canonical input (from the rest of the infra) will not contain namespace, but # we have a usecase in ExecuTorch where we want to support BaseOperatorName with namespace. - namespace: Optional[str] = None + namespace: str | None = None @staticmethod def parse(op: str) -> BaseOperatorName: diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index e35221c3f50eb..d6909bc4d7f67 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -3,7 +3,7 @@ import argparse import itertools import os -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar from libfb.py.log import set_simple_logging # type: ignore[import] @@ -23,7 +23,7 @@ NativeGroupT = TypeVar( "NativeGroupT", - bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], + bound=NativeFunctionsGroup | NativeFunctionsViewGroup, ) From 9f10cb067874255c56c263f3e96cd32bdff72554 Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Fri, 21 Nov 2025 01:44:30 +0000 Subject: [PATCH 138/230] [BugFix] Fix incorrect usage of const_data_ptr in memcpy (#168233) Inspired by #168165 1st argument of memcpy should be a mutable pointer, therefore replacing it with TensorBase::mutable_data_ptr Pull Request resolved: https://github.com/pytorch/pytorch/pull/168233 Approved by: https://github.com/cyyever, https://github.com/Skylion007 --- aten/src/ATen/native/cuda/CUDAScalar.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 0d34bd52f211a..169a2ab92615f 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -29,7 +29,7 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { std::nullopt /* memory format */ ); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - at::cuda::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); + at::cuda::memcpy_and_sync(value.mutable_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); return r; From 064f80dfa0482f6bd365a2f7db2e9c2f9f3ea88c Mon Sep 17 00:00:00 2001 From: atalman Date: Fri, 21 Nov 2025 01:51:31 +0000 Subject: [PATCH 139/230] Smoke test numpy coverage in nightlies (#168270) Make sure numpy failures are visible on cd Pull Request resolved: https://github.com/pytorch/pytorch/pull/168270 Approved by: https://github.com/albanD, https://github.com/malfet --- .circleci/scripts/binary_linux_test.sh | 30 +++++++++++--------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index c24a50b8b17ed..58d0af29e133b 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -31,23 +31,6 @@ if [[ "$PACKAGE_TYPE" != libtorch ]]; then export PATH="\${python_path}/bin:\$PATH" fi -EXTRA_CONDA_FLAGS="" -NUMPY_PIN="" -PROTOBUF_PACKAGE="defaults::protobuf" - -if [[ "\$python_nodot" = *310* ]]; then - # There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20 - # we set a lower boundary here just to be safe - NUMPY_PIN=">=1.21.2" - PROTOBUF_PACKAGE="protobuf>=3.19.0" -fi - -if [[ "\$python_nodot" = *39* ]]; then - # There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20 - # we set a lower boundary here just to be safe - NUMPY_PIN=">=1.20" -fi - # Move debug wheels out of the package dir so they don't get installed mkdir -p /tmp/debug_final_pkgs mv /final_pkgs/debug-*.zip /tmp/debug_final_pkgs || echo "no debug packages to move" @@ -66,12 +49,23 @@ fi if [[ "$PACKAGE_TYPE" != libtorch ]]; then if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" - retry pip install -q numpy protobuf typing-extensions + + # numpy tests: + # We test 1 version no numpy. 1 version with numpy 1.x and rest with numpy 2.x + if [[ "\$python_nodot" = *311* ]]; then + retry pip install -q protobuf typing-extensions + elif [[ "\$python_nodot" = *312* ]]; then + retry pip install -q numpy==1.21.2 protobuf typing-extensions + else + retry pip install -q numpy protobuf typing-extensions + fi + else pip install "\$pkg" retry pip install -q numpy protobuf typing-extensions fi fi + if [[ "$PACKAGE_TYPE" == libtorch ]]; then pkg="\$(ls /final_pkgs/*-latest.zip)" unzip "\$pkg" -d /tmp From 7ebca682975ad5283cce065d73958034e53e2b7f Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Fri, 21 Nov 2025 02:02:25 +0000 Subject: [PATCH 140/230] [ROCm][CI] Move periodic-rocm-mi300 and inductor-rocm-mi300 to Ubuntu noble images (#168230) This will ensure post-submit MI3xx default/distributed/inductor config jobs will consistently test on Ubuntu noble and py3.12 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168230 Approved by: https://github.com/jeffdaily --- .github/workflows/inductor-rocm-mi300.yml | 20 ++++++++++---------- .github/workflows/periodic-rocm-mi300.yml | 20 ++++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index dee10a0db3c16..57e5cb856729a 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -38,14 +38,14 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf - linux-jammy-rocm-py3_10-inductor-build: - name: rocm-py3.10-inductor-mi300 + linux-noble-rocm-py3_12-inductor-build: + name: rocm-py3.12-inductor-mi300 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, @@ -53,15 +53,15 @@ jobs: ]} secrets: inherit - linux-jammy-rocm-py3_10-inductor-test: + linux-noble-rocm-py3_12-inductor-test: permissions: id-token: write contents: read - name: rocm-py3.10-inductor-mi300 + name: rocm-py3.12-inductor-mi300 uses: ./.github/workflows/_rocm-test.yml - needs: linux-jammy-rocm-py3_10-inductor-build + needs: linux-noble-rocm-py3_12-inductor-build with: - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index 12a20a2993f8d..f3356cfa4fc77 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -50,14 +50,14 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10-mi300 + linux-noble-rocm-py3_12-build: + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, @@ -66,17 +66,17 @@ jobs: ]} secrets: inherit - linux-jammy-rocm-py3_10-test: + linux-noble-rocm-py3_12-test: permissions: id-token: write contents: read - name: linux-jammy-rocm-py3.10-mi300 + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-jammy-rocm-py3_10-build + - linux-noble-rocm-py3_12-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} secrets: inherit From a7f3b10866098c452d89cd7a30bc4ce5713b8319 Mon Sep 17 00:00:00 2001 From: "Andy (An) Wang" Date: Fri, 21 Nov 2025 02:08:46 +0000 Subject: [PATCH 141/230] [Full Inductor][Pytorch] Prevent decomposition and enable fallback of aten.native_layer_norm for MTIA (#168290) Summary: MTIA-Triton currently doesn't support aten.native_layer_norm and we need Inductor to fallback it to Aten. Currently `make_fallback` doesn't work for aten.native_layer_norm due to decomposition. This PR prevents the decomposition, following the PR [#151637](https://github.com/pytorch/pytorch/pull/151637) where XPU enabled fallback for embedding_dense_backward. Test Plan: Ran Full Inductor with MAI Differential Revision: D87566269 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168290 Approved by: https://github.com/jansel, https://github.com/blaine-rister, https://github.com/eellison --- torch/_inductor/decomposition.py | 16 ++++++++++++++++ torch/_inductor/lowering.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 3cedad185c3f2..db9c8f5f0333c 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -35,6 +35,7 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) +from torch._refs import native_layer_norm as decomp_native_layer_norm from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from . import config, inductor_prims @@ -118,6 +119,7 @@ aten.clamp_max, aten.clamp_min, aten.embedding_dense_backward, # we fall back on xpu + aten.native_layer_norm, # we fall back on mtia aten.index_add, # we conditionally call this decomp aten.glu, # inductor lowers this directly aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass @@ -159,6 +161,20 @@ def _embedding_dense_backward( ) +@register_decomposition(aten.native_layer_norm) +def _native_layer_norm( + input: torch.Tensor, + normalized_shape: utils.ShapeType, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if input.is_mtia: + return NotImplemented + # We can write a util function to update decomp table if we have more ops to fallback. + return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps) + + @register_decomposition([aten.sym_constrain_range_for_size.default]) def sym_constrain_range_for_size( symbol: torch.SymInt, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d374be59c9446..7eafc45036b10 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2902,6 +2902,11 @@ def is_aligned(x): aten.embedding_dense_backward, warn=False ) # (XPU-only and faster than decomp) +if torch.mtia.is_available(): + make_fallback( + aten.native_layer_norm, warn=False + ) # (MTIA-only and faster than decomp) + # 1.5) Easy or Impossible make_fallback(aten._cdist_forward) # p=2 should be feasible From 2ae4b850d8a7f5c93d26bcd703a7879bf7a2c5fe Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Thu, 20 Nov 2025 00:50:14 +0000 Subject: [PATCH 142/230] [3.14] Update profiler test (#168205) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168205 Approved by: https://github.com/guilhermeleobas, https://github.com/williamwen42 --- test/profiler/test_profiler_tree.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index e8d28d7eff032..3c5ef2aeeb83c 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -27,13 +27,16 @@ PRUNE_ALL = 1 KEEP_ELLIPSES = 2 KEEP_NAME_AND_ELLIPSES = 3 +IGNORE = 4 PRUNE_FUNCTIONS = { "torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES, "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES, + "": PRUNE_ALL, "": PRUNE_ALL, + "": IGNORE, "cudaStreamIsCapturing": PRUNE_ALL, # These show up only on CUDA, prune them so the CUDA and CPU expected results can be the same "cudaGetDeviceCount": PRUNE_ALL, @@ -117,6 +120,8 @@ def flatten(nodes, depth=0, out=None): if prune_level is None: out.append((depth, name)) flatten(node.children, depth + 1, out) + elif prune_level == IGNORE: + flatten(node.children, depth, out) elif prune_level == KEEP_NAME_AND_ELLIPSES: out.append((depth, name)) if node.children: @@ -720,10 +725,9 @@ def test_profiler_experimental_tree_with_stack_and_torch_function(self): test_profiler_tree.py(...): __torch_function__ torch/_tensor.py(...): __torch_function__ - - torch/_tensor.py(...): - - torch/_tensor.py(...): + torch/_tensor.py(...): + + torch/_tensor.py(...): aten::add torch/_tensor.py(...): _convert From a60eb2da06a334de1cd58e73d329daf89115dd0b Mon Sep 17 00:00:00 2001 From: dolpm <34420038+dolpm@users.noreply.github.com> Date: Fri, 21 Nov 2025 03:00:47 +0000 Subject: [PATCH 143/230] fix philoxstate bad cast (#168310) we were seeing some issues where a seed from one state was used in another state and it was being interpreted as a negative. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168310 Approved by: https://github.com/dzmitry-huba --- test/distributed/tensor/test_random_ops.py | 19 +++++++++++++++++++ torch/distributed/tensor/_random.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 4ff470511f2ad..4bcddc198836b 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -634,6 +634,25 @@ def blockwise_iter_if_localtensor(local_tensor, local_shard_offset): blockwise_iter_if_localtensor(local_tensor, local_shard_offset) + def test_philox_state_seed_roundtrip(self): + """ + Test that _PhiloxState seed can be read and re-set without error. + + This test addresses the issue where reading a seed value from the state + (which uses uint64 view) and then re-setting it would fail with: + OverflowError: can't convert negative int to unsigned + + The fix ensures the seed getter uses uint64 view, preventing negative + values from appearing when the high bit is set. + """ + from torch.distributed.tensor._random import _PhiloxState + + state = torch.zeros(16, dtype=torch.uint8, device="cpu") + philox = _PhiloxState(state) + test_seed = 2**63 + 42 # This has the sign bit set when viewed as int64 + philox.seed = test_seed + philox.seed = philox.seed + class DistTensorRandomOpsTest3D(DTensorTestBase): @property diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 42bf1ebeebf0e..d117df2d67e2e 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -135,7 +135,7 @@ def offset(self, offset: int) -> None: @property def seed(self) -> int: - return int(self._state[:8].view(dtype=torch.int64).item()) + return int(self._state[:8].view(dtype=torch.uint64).item()) @seed.setter def seed(self, seed: int) -> None: From 8ad78bbc2b9c3606f05b882d3e7563c3dd6c8b88 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 20 Nov 2025 17:48:51 -0500 Subject: [PATCH 144/230] Revert C++ fastpath dispatch path for DTensor (#168264) ``` git revert --no-commit 567dcdba757 200156e3850 3d801a4c01f 2034ca99ae7 480b4ff8828 f570e589da1 ``` And Revert "[DTensor] Document fast-path dispatch (#168192)" And Revert "[DTensor] Fix deadlock after fast cache clear (#168069)" Reverts: * #167860 * #167588 * #167475 * #166808 * #166372 * #168192 * #168069 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/168264 Approved by: https://github.com/seemethere, https://github.com/malfet --- test/cpp/jit/test_custom_operators.cpp | 13 +- test/custom_operator/test_custom_ops.cpp | 2 +- test/distributed/tensor/test_op_strategy.py | 31 +- test/distributed/tensor/test_tensor_ops.py | 40 +- torch/_C/__init__.pyi.in | 2 - torch/csrc/PyInterpreter.cpp | 2 - torch/csrc/autograd/python_variable.cpp | 1168 +---------------- torch/csrc/autograd/python_variable.h | 9 - torch/csrc/jit/frontend/schema_matching.cpp | 2 +- torch/csrc/jit/ir/alias_analysis.cpp | 2 +- torch/csrc/jit/ir/ir.cpp | 2 +- torch/csrc/jit/python/init.cpp | 6 +- torch/csrc/jit/runtime/operator.cpp | 70 +- torch/csrc/jit/runtime/operator.h | 5 +- .../jit/runtime/symbolic_shape_registry.cpp | 2 +- torch/csrc/utils/python_arg_parser.cpp | 82 +- torch/csrc/utils/python_arg_parser.h | 12 - torch/distributed/_tools/mem_tracker.py | 1 + torch/distributed/tensor/_api.py | 42 +- torch/distributed/tensor/_dispatch.py | 86 +- torch/distributed/tensor/debug/__init__.py | 27 +- 21 files changed, 186 insertions(+), 1420 deletions(-) diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 66295d0380629..58f87717844de 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -15,7 +15,7 @@ namespace jit { TEST(CustomOperatorTest, InferredSchema) { torch::RegisterOperators reg( "foo::bar", [](double a, at::Tensor b) { return a + b; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -43,7 +43,8 @@ TEST(CustomOperatorTest, ExplicitSchema) { "foo::bar_with_schema(float a, Tensor b) -> Tensor", [](double a, at::Tensor b) { return a + b; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); + auto& ops = + getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -76,7 +77,7 @@ TEST(CustomOperatorTest, ListParameters) { torch::List> complexdoubles, torch::List tensors) { return floats; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -122,7 +123,7 @@ TEST(CustomOperatorTest, ListParameters2) { "foo::lists2(Tensor[] tensors) -> Tensor[]", [](torch::List tensors) { return tensors; }); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -212,7 +213,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) { }, aliasAnalysisFromSchema())}); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); ASSERT_EQ(ops.size(), 0); } @@ -231,7 +232,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { }, aliasAnalysisFromSchema())}); - auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index 9791006d1498f..a526bebd26144 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -22,7 +22,7 @@ void check_all_parameters( template Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) { - auto ops = torch::jit::getAllOperatorsFor( + auto& ops = torch::jit::getAllOperatorsFor( torch::jit::Symbol::fromQualString(op_name)); TORCH_INTERNAL_ASSERT(ops.size() == 1); diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 72d95efcfa8c9..139f5fb61fac8 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -34,11 +34,7 @@ register_op_strategy, replicate_op_strategy, ) -from torch.distributed.tensor.debug import ( - _clear_fast_path_sharding_prop_cache, - _clear_python_sharding_prop_cache, - CommDebugMode, -) +from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -483,8 +479,7 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): del propagator.op_to_schema_info[op_overload] else: propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema - _clear_fast_path_sharding_prop_cache() - _clear_python_sharding_prop_cache() + propagator.propagate_op_sharding.cache.cache_clear() def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool: @@ -650,28 +645,6 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) -class TestStrategyOperation(DTensorTestBase): - @property - def world_size(self): - return 2 - - @with_comms - def test_cache_clean(self): - mesh = self.build_device_mesh() - test_op = torch.ops.mylib.numpy_sin - x = torch.randn(2, device=self.device_type) - y = torch.randn(2, device=self.device_type) - x_dt = distribute_tensor(x, mesh, [Shard(0)]) - y_dt = distribute_tensor(y, mesh, [Shard(0)]) - with op_strategy_context(test_op.default, replicate_op_strategy): - self._test_op_on_dtensor(test_op, x_dt, y_dt) - with self.assertRaisesRegex( - NotImplementedError, - f"Operator {test_op.default} does not have a sharding strategy registered", - ): - self._test_op_on_dtensor(test_op, x_dt, y_dt) - - DistTensorReplicateStrategyRegistrationTestWithLocalTensor = ( create_local_tensor_test_class( DistTensorReplicateStrategyRegistrationTest, diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 4748db4f7377b..fc0a2b16955ca 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -2,6 +2,7 @@ # Owner(s): ["oncall: distributed"] import itertools +import unittest import torch from torch.distributed.tensor import ( @@ -13,6 +14,7 @@ Replicate, Shard, ) +from torch.distributed.tensor._sharding_prop import ShardingPropagator from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests, skipIfRocm @@ -334,6 +336,34 @@ def test_stack(self): torch.stack([global_input, global_input], dim=1), ) + @with_comms + def test_stack_cache(self): + device_mesh = self.build_device_mesh() + + shape = (4, 8) + placements = [Replicate()] + dtensor_list = [] + for _ in range(3): + local_tensor = torch.randn(shape) + dt = DTensor.from_local(local_tensor, device_mesh, placements) + dtensor_list.append(dt) + + _ = torch.stack(dtensor_list) + + dtensor_list2 = [] + for _ in range(3): + local_tensor = torch.randn(shape) + dt = DTensor.from_local(local_tensor, device_mesh, placements) + dtensor_list2.append(dt) + + def error(*args, **kwargs): + raise AssertionError + + with unittest.mock.patch.object( + ShardingPropagator, "_propagate_tensor_meta_non_cached", error + ): + _ = torch.stack(dtensor_list2) + @with_comms def test_equal(self): device_mesh = self.build_device_mesh() @@ -706,11 +736,11 @@ def test_where_type_promotion(self): @with_comms def test_dtensor_dtype_conversion(self): from torch.distributed.tensor.debug import ( - _clear_fast_path_sharding_prop_cache, - _get_fast_path_sharding_prop_cache_stats, + _clear_sharding_prop_cache, + _get_sharding_prop_cache_info, ) - _clear_fast_path_sharding_prop_cache() + _clear_sharding_prop_cache() device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype @@ -730,13 +760,13 @@ def test_dtensor_dtype_conversion(self): self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16) # by this point we only have cache misses - hits, misses = _get_fast_path_sharding_prop_cache_stats() + hits, misses, _, _ = _get_sharding_prop_cache_info() self.assertEqual(hits, 0) self.assertEqual(misses, 2) # convert to fp32 again and see if there's cache hit bf16_sharded_dtensor1.float() - hits, misses = _get_fast_path_sharding_prop_cache_stats() + hits, misses, _, _ = _get_sharding_prop_cache_info() # by now we should have cache hit self.assertEqual(hits, 1) self.assertEqual(misses, 2) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index e9b58b9ce71eb..1af6df5e7664a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1967,8 +1967,6 @@ def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ... def _DTensor_compute_global_tensor_info( tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement] ) -> tuple[list[_int], list[_int]]: ... -def _get_DTensor_sharding_propagator_cache_stats() -> tuple[_int, _int]: ... -def _clear_DTensor_sharding_propagator_cache() -> None: ... # Defined in torch/csrc/multiprocessing/init.cpp def _multiprocessing_init() -> None: ... diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index 7f36d88bdaa32..8a2e0d533ff0c 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -338,8 +338,6 @@ void ConcretePyInterpreterVTable::dispatch( nullptr, torch_api_function_overload.ptr(), nullptr, - &op, - &arguments, TorchFunctionName::TorchDispatch); pushPyOutToStack( op, stack, py::reinterpret_steal(obj), "__torch_dispatch__"); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index de7f3dc53c323..8165fd910c2c1 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1,13 +1,8 @@ -#include #include -#include #include -#include #include #include #include -#include -#include #include #include #include @@ -45,6 +40,7 @@ #include +#include #include #include #include @@ -824,84 +820,29 @@ static PyObject* THPVariable_make_wrapper_subclass( END_HANDLE_TH_ERRORS } +static py::handle get_dtensor_spec_class() { #if IS_PYBIND_2_13_PLUS -#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr) \ - static py::handle name() { \ - PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store \ - storage; \ - return storage \ - .call_once_and_store_result( \ - []() -> py::object { return import_expr; }) \ - .get_stored(); \ - } + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + return storage + .call_once_and_store_result([]() -> py::object { + return py::module::import("torch") + .attr("distributed") + .attr("tensor") + .attr("_dtensor_spec") + .attr("DTensorSpec"); + }) + .get_stored(); #else -#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr) \ - static py::handle name() { \ - static py::handle storage = py::object(import_expr).release(); \ - return storage; \ - } + static py::handle dtensor_spec_class = py::object(py::module::import("torch") + .attr("distributed") + .attr("tensor") + .attr("_dtensor_spec") + .attr("DTensorSpec")) + .release(); + return dtensor_spec_class; #endif - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_class_impl, - py::module::import("torch.distributed.tensor").attr("DTensor")) - -py::handle get_dtensor_class() { - return get_dtensor_class_impl(); -} - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_spec_class, - py::module::import("torch.distributed.tensor") - .attr("_dtensor_spec") - .attr("DTensorSpec")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_replicate_class, - py::module::import("torch.distributed.tensor") - .attr("placement_types") - .attr("Replicate")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_tensor_meta_class, - py::module::import("torch.distributed.tensor") - .attr("_dtensor_spec") - .attr("TensorMeta")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_op_dispatcher, - py::module::import("torch.distributed.tensor") - .attr("DTensor") - .attr("_op_dispatcher")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_dispatch, - py::module::import("torch.distributed.tensor") - .attr("DTensor") - .attr("_op_dispatcher") - .attr("_dispatch_fast_path_python_tail")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_dispatcher_wrap, - py::module::import("torch.distributed.tensor") - .attr("DTensor") - .attr("_op_dispatcher") - .attr("wrap")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_dtensor_get_local_results_slow_path, - py::module::import("torch") - .attr("distributed") - .attr("tensor") - .attr("DTensor") - .attr("_op_dispatcher") - .attr("_dispatch_get_local_results_slow_path")) - -DEFINE_CACHING_PYTHON_IMPORT_GETTER( - get_output_sharding_class, - py::module::import("torch.distributed.tensor") - .attr("_op_schema") - .attr("OutputSharding")) +} static bool arg_type_tensor_or_tensor_list_like(py::handle arg) { const auto dtensor_spec_class = get_dtensor_spec_class(); @@ -929,26 +870,13 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) { #define FOR_EACH_DTENSOR_INTERNED_STRING(_) \ MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \ _(_comparison_key) \ - _(_custom_op_handlers) \ _(_local_tensor) \ _(_spec) \ - _(_unwrap_to_op_info_impl) \ _(args_schema) \ - _(compute_mesh) \ - _(device_mesh) \ - _(dtype) \ - _(get_coordinate) \ _(kwargs_schema) \ - _(ndim) \ - _(needs_pytree) \ - _(needs_redistribute) \ _(op) \ - _(op_to_schema_info) \ - _(output_sharding) \ - _(output_spec) \ _(schema_info) \ _(shape) \ - _(sharding_propagator) \ _(size) \ _(static_argnum) \ _(static_kwargkey) \ @@ -963,7 +891,6 @@ struct DTensorInternedStrings { static DTensorInternedStrings dtensor_interned_strings; -#ifdef USE_DISTRIBUTED static bool intern_dtensor_strings() { #define INTERN_DTENSOR_STRING(s) \ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dtensor_interned_strings.s == nullptr); \ @@ -976,7 +903,6 @@ static bool intern_dtensor_strings() { #undef INTERN_DTENSOR_STRING return true; } -#endif static bool checked_not(PyObject* obj) { int result = PyObject_Not(obj); @@ -986,36 +912,6 @@ static bool checked_not(PyObject* obj) { return result; } -static bool checked_istrue(PyObject* obj) { - int result = PyObject_IsTrue(obj); - if (result == -1) { - throw py::error_already_set(); - } - return result; -} - -// pybind11 does not not use PyObject_Vectorcall currently; it seems -// to materialize a tuple of args instead. -template -static py::object checked_vectorcall( - PyObject* obj, - std::array args) { - PyObject* result = PyObject_Vectorcall(obj, args.data(), N, nullptr); - if (!result) { - throw py::error_already_set(); - } - return py::reinterpret_steal(result); -} - -template -static py::object checked_vectorcall(PyObject* obj, Args... args) { - static_assert( - (std::is_same_v && ...), - "must pass PyObject* to checked_vectorcall!"); - std::array arr = {args...}; - return checked_vectorcall(obj, arr); -} - static c10::SymDimVector tuple_to_symintlist(PyObject* obj) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyTuple_Check(obj)); c10::SymDimVector res; @@ -1036,579 +932,6 @@ static c10::SymDimVector tuple_to_symintlist(PyObject* obj) { return res; } -// As a Python object, DTensorSpec can be stored directly within -// IValue, but doing so is inefficient -- it requires a -// heap-allocated, reference counted intermediate -// ivalue::PyObjectHolder. -// Representation options: -// 1) Add an IValue tag to represent a placeholder object. -// 2) Play representational tricks -- stuff information into an IValue -// payload, such as by creating impossible -// intrusive_ptr_target*. Problem: this would cause IValue copying and -// possibly destruction to crash and so would be horribly unsafe. -// 3) Represent DTensorSpec directly inside IValue despite the inefficiency. -// 4) Leave the actual DTensor in the list of IValues, but detect it efficiently -// and transparently replace. -// 5) Just use a 24-byte struct of IValue + extra py::object. -// -// Given the high blast radius of (1), the unsafety of (2), the likely -// poor performance of (3), and detection of (4) looking less -// efficient than (5), (5) seems like the best path forward. - -// We can't safely steal bits from IValue, so we just use 24 bytes of -// space. If dtensor_spec is non-null (truthy) then it's the active -// member, otherwise it's iv. -struct IValueOrDTensorSpec { - IValueOrDTensorSpec() = default; - explicit IValueOrDTensorSpec(c10::IValue v) : iv(std::move(v)) {} - explicit IValueOrDTensorSpec(py::object dts) : dtensor_spec(std::move(dts)) {} - c10::IValue iv; - py::object dtensor_spec; - - bool operator==(const IValueOrDTensorSpec& rhs) const { - return dtensor_spec - ? (rhs.dtensor_spec && dtensor_spec.equal(rhs.dtensor_spec)) - : (iv == rhs.iv); - } -}; - -// This corresponds to the Python OpSchema class in that it is the key -// for the (native version of the) sharding propagator cache. It is -// missing essentially everything else from the Python OpSchema -// though. -class NativeOpSchema { - public: - NativeOpSchema( - const c10::OperatorHandle& op, - c10::SmallVector comparison_key, - std::size_t comparison_key_hash, - std::size_t args_schema_len) - : op_(op), - hash_(hash_combine( - hash_combine( - std::hash()(op), - comparison_key_hash), - args_schema_len)), - args_schema_len_(args_schema_len), - comparison_key_(std::move(comparison_key)) {} - - bool operator==(const NativeOpSchema& rhs) const { - // If two NativeOpSchema are being compared, they are probably - // equal, because comparison is occurring during a hash table - // lookup and we know the hashes are already equal. Therefore, we - // don't bother checking hash_ first. - return op_ == rhs.op_ && args_schema_len_ == rhs.args_schema_len_ && - comparison_key_ == rhs.comparison_key_; - } - - std::size_t hash() const { - return hash_; - } - - private: - // It would *not* be correct to store this by reference, because we - // have no guarantees about its lifetime. This class is cheap anyway. - c10::OperatorHandle op_; - std::size_t hash_; - // Subtle point: consider clamp.Tensor(Tensor self, Tensor? - // min=None, Tensor? max=None). The invocations clamp(t1, None, t2) - // and clamp(t1, t2, None) have the same comparison key (t1, t2) - // because we drop non-static non-tensor args from comparison. The - // only way we happen to be able to tell them apart is that we omit - // trailing defaulted arguments from the args tuple passed to - // __torch_dispatch__ (and hence to DTensor dispatch as well), so - // they have different args_schema_len_. - // - // I am preserving this existing behavior, but I suspect we should - // make an algorithm change to be less brittle, such as including - // None defaults for Tensor arguments in the comparison. - std::size_t args_schema_len_; - // There is no particular justification for the choice of 8 - // here. Feel free to change it. - c10::SmallVector comparison_key_; -}; - -namespace std { -template <> -struct hash { - std::size_t operator()(const NativeOpSchema& schema) const { - return schema.hash(); - } -}; -} // namespace std - -// Map from OpSchema to pyobject sharding propagation config. -class NativeShardingPropagatorCache { - public: - // Returns an invalid (falsey) py::object if the lookup fails. - py::object find(const NativeOpSchema& op_schema) const { - if (auto it = repr_.find(op_schema); it != repr_.end()) { - hits_++; - return py::object(it->second); - } - misses_++; - return py::object(); - } - - void insert(NativeOpSchema&& op_schema, py::object output_sharding) { - auto [it, inserted] = - repr_.emplace(std::move(op_schema), std::move(output_sharding)); - TORCH_INTERNAL_ASSERT( - inserted, - "tried to insert already-present element in NativeShardingPropagatorCache!"); - } - - auto hits() const { - return hits_; - } - - auto misses() const { - return misses_; - } - - private: - c10::FastMap repr_; - // Cache is thread-local, so we don't take any further action for - // thread-safety of these. - mutable std::size_t hits_ = 0; - mutable std::size_t misses_ = 0; -}; - -static std::optional> -create_native_op_schema( - const c10::OperatorHandle& op, - py::handle py_op, - torch::jit::Stack* stack); - -static std::mutex native_sharding_propagator_cache_cleanup_mutex; -static c10:: - FastMap*> - all_thread_caches; -thread_local std::optional - native_sharding_propagator_cache_DO_NOT_USE; - -NativeShardingPropagatorCache& -get_thread_local_native_sharding_propagator_cache() { - if (!native_sharding_propagator_cache_DO_NOT_USE.has_value()) { - native_sharding_propagator_cache_DO_NOT_USE.emplace(); - std::lock_guard lock( - native_sharding_propagator_cache_cleanup_mutex); - const auto this_thread_id = std::this_thread::get_id(); - all_thread_caches[this_thread_id] = - &native_sharding_propagator_cache_DO_NOT_USE; - py::dict thread_dict = - py::reinterpret_borrow(PyThreadState_GetDict()); - // We need to clean up before Python detaches from the thread if - // the thread is being destroyed. - if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) { - thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = - py::capsule(new std::thread::id(this_thread_id), [](void* p) { - auto* ptid = reinterpret_cast(p); - { - std::lock_guard inner_lock( - native_sharding_propagator_cache_cleanup_mutex); - auto it = all_thread_caches.find(*ptid); - if (it != all_thread_caches.end()) { - // We need to both: - // 1) free python objects, and - it->second->reset(); - // 2) make sure we don't try to come back and mess with - // a destroyed thread-local at module unload (e.g., - // process exit) time. - all_thread_caches.erase(it); - } - } - delete ptid; - }); - } - } - return native_sharding_propagator_cache_DO_NOT_USE.value(); -} - -// We need to clean up all thread_locals if our module is getting -// unloaded. -void cleanup_thread_local_native_sharding_propagator_caches() { - std::lock_guard lock( - native_sharding_propagator_cache_cleanup_mutex); - for (auto& [_, popt_cache] : all_thread_caches) { - popt_cache->reset(); - } - all_thread_caches.clear(); -} - -static void replace_dtensors_with_local_tensor(torch::jit::Stack& stack); - -static bool is_default_overload(const std::string& overload_name) { - return overload_name.empty() || overload_name == "default"; -} - -static bool is_random_op(const c10::OperatorHandle& op) { - // NOTE: must stay in sync with _random_ops in - // torch/distributed/tensor/_dispatch.py - constexpr auto aten_namespace_prefix_len = 6; - const auto& op_name = op.operator_name(); - if (op_name.name.size() <= aten_namespace_prefix_len || - memcmp(op_name.name.data(), "aten::", aten_namespace_prefix_len) != 0) { - return false; - } - static constexpr std::array random_names = {{ - "native_dropout", - "normal_", - "rand_like", - "randn_like", - "uniform_", - "bernoulli", - }}; - std::string_view name_without_namespace( - op_name.name.c_str() + aten_namespace_prefix_len, - op_name.name.size() - aten_namespace_prefix_len); - if (name_without_namespace == "bernoulli_") { - return op_name.overload_name == "float"; - } - if (name_without_namespace == "randint_like") { - return is_default_overload(op_name.overload_name) || - op_name.overload_name == "low_dtype" || - op_name.overload_name == "low_dtype_out"; - } - const auto it = std::find( - random_names.begin(), random_names.end(), name_without_namespace); - if (it == random_names.end()) { - return false; - } - return is_default_overload(op_name.overload_name); -} - -// Puts local results on the stack. Return true for success, false for bailout -// to slow path. -static bool get_local_results( - const c10::OperatorHandle& op, - py::handle output_sharding, - py::handle compute_mesh, - bool participating, - torch::jit::Stack* stack) { - if (participating) { - // computation that happens in the current rank of the mesh, normal case - if (checked_istrue( - output_sharding.attr(dtensor_interned_strings.needs_redistribute) - .ptr()) || - is_random_op(op)) { - // Bail out to slow path. - return false; - } - // normal case, run local sharded op computation. - - // It is slightly inefficient that we take another pass over - // arguments here when we just did one in create_native_op_schema to - // create the comparison key. However, we have a crucial difference: - // in the NativeOpSchema, we don't want to waste time dealing with - // defaulted args. Here, we need to provide defaulted args because - // we are going to make a local op call. - replace_dtensors_with_local_tensor(*stack); - op.callBoxed(*stack); - } else { - // For a non-participating device (happens on rank that does not - // belong to the device mesh), we do: - // - // 1. if the return type is scalar, set the local result to - // None. - // 2. if the return type is Tensor or List[Tensor], return - // empty tensor(s) with correct dtype. - - stack->clear(); - - auto spec = output_sharding.attr(dtensor_interned_strings.output_spec); - if (spec.is_none()) { - // For a scalar return type, the non-participating device has - // None as its local result. - stack->emplace_back(); // Return None. - return true; - } - - const auto default_tensor = [](py::handle spec) -> Tensor { - auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta); - TORCH_CHECK( - !tensor_meta.is_none(), py::str(spec), " has no tensor metadata."); - const auto sizes = tensor_meta.attr(dtensor_interned_strings.shape); - TORCH_CHECK( - PyTuple_Check(sizes.ptr()), "spec.tensor_meta.shape must be a tuple"); - const auto dtype = tensor_meta.attr(dtensor_interned_strings.dtype); - TORCH_CHECK( - THPDtype_Check(dtype.ptr()), - "spec.tensor_meta.dtype must be a torch.dtype"); - const auto scalar_type = - reinterpret_cast(dtype.ptr())->scalar_type; - if (py::cast(sizes).empty()) { - // scalar tensor - return at::zeros({}, scalar_type); - } else { - // non-scalar tensor - return at::empty({0}, scalar_type); - } - }; - auto handle_sequence = [&default_tensor, &op, stack](auto sequence) { - c10::List result(op.schema().returns().at(0).type()); - for (const auto& item : sequence) { - TORCH_CHECK( - !item.is_none(), - "return type ", - op.schema().returns().at(0).type(), - " in DTensor op is not supported"); - result.push_back(default_tensor(item)); - } - stack->push_back(std::move(result)); - }; - - if (py::isinstance(spec, get_dtensor_spec_class())) { - stack->push_back(default_tensor(spec)); - } else if (PyList_Check(spec.ptr())) { - handle_sequence(py::reinterpret_borrow(spec)); - } else if (PyTuple_Check(spec.ptr())) { - handle_sequence(py::reinterpret_borrow(spec)); - } else if (PySequence_Check(spec.ptr())) { - handle_sequence(py::reinterpret_borrow(spec)); - } else { - // return None. - stack->emplace_back(); - } - } - return true; -} - -static void functionalize_unsafe_set(at::Tensor& dst, const at::Tensor& src) { - at::native::checkSetStorage( - dst, - src.storage(), - dst.sym_storage_offset(), - dst.sym_sizes(), - dst.sym_strides(), - /*check_offset_in_bounds=*/false); -} - -static bool sets_intersect( - const std::unordered_set& smaller, - const std::unordered_set& bigger) { - if (smaller.size() > bigger.size()) { - return sets_intersect(bigger, smaller); - } - for (const auto& item : smaller) { - if (bigger.find(item) != bigger.end()) { - return true; - } - } - return false; -} - -py::object dispatchDTensorOp( - const c10::OperatorHandle& op, - py::handle py_op, - py::handle args, - py::handle kwargs, - torch::jit::Stack* stack) { - py::object cached_sharding; - const auto op_dispatcher = get_dtensor_op_dispatcher(); - { - const auto custom_op_handlers = - op_dispatcher.attr(dtensor_interned_strings._custom_op_handlers); - TORCH_CHECK( - PyDict_Check(custom_op_handlers.ptr()), - "_custom_op_handlers must be a dict!"); - PyObject* custom_op_handler = - PyDict_GetItemWithError(custom_op_handlers.ptr(), py_op.ptr()); - if (custom_op_handler) { - auto result = checked_vectorcall( - custom_op_handler, py_op.ptr(), args.ptr(), kwargs.ptr()); - stack->clear(); - return result; - } else if (PyErr_Occurred()) { - throw py::error_already_set(); - } - } - - torch::jit::Stack saved_args = *stack; - NativeShardingPropagatorCache* native_sharding_propagator_cache = nullptr; - auto opt_native_op_schema = create_native_op_schema(op, py_op, stack); - if (opt_native_op_schema.has_value()) { - native_sharding_propagator_cache = - &get_thread_local_native_sharding_propagator_cache(); - cached_sharding = - native_sharding_propagator_cache->find(opt_native_op_schema->first); - } - py::object py_op_info; - if (!cached_sharding) { - py_op_info = checked_vectorcall( - op_dispatcher.attr("unwrap_to_op_info").ptr(), - py_op.ptr(), - args.ptr(), - kwargs.ptr()); - py::object sharding = checked_vectorcall( - op_dispatcher - .attr("_propagate_op_sharding_non_cached_dispatch_slow_path") - .ptr(), - py_op.ptr(), - args.ptr(), - kwargs.ptr(), - py_op_info.ptr()); - if (!py::isinstance(sharding, get_output_sharding_class())) { - stack->clear(); - return sharding; - } - cached_sharding = sharding; - if (opt_native_op_schema.has_value()) { - native_sharding_propagator_cache->insert( - std::move(opt_native_op_schema->first), std::move(sharding)); - } - py_op_info.attr(dtensor_interned_strings.output_sharding) = cached_sharding; - } - - const auto get_py_op_info_if_needed = [&, &args = args, &kwargs = kwargs]() { - if (!py_op_info) { - py_op_info = checked_vectorcall( - op_dispatcher.attr(dtensor_interned_strings._unwrap_to_op_info_impl) - .ptr(), - py_op.ptr(), - args.ptr(), - kwargs.ptr(), - Py_False); - py_op_info.attr(dtensor_interned_strings.output_sharding) = - cached_sharding; - } - }; - - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !kwargs.is_none(), - "Python op_dispatch implementation expects non-None kwargs"); - - py::object compute_mesh; - if (opt_native_op_schema.has_value()) { - compute_mesh = std::move(opt_native_op_schema->second); - } else { - get_py_op_info_if_needed(); - compute_mesh = py_op_info.attr(dtensor_interned_strings.compute_mesh); - } - - const bool participating = - !checked_vectorcall( - compute_mesh.attr(dtensor_interned_strings.get_coordinate).ptr()) - .is_none(); - const bool local_results_success = get_local_results( - op, cached_sharding, compute_mesh, participating, stack); - py::object py_local_results; - if (local_results_success) { - py_local_results = torch::jit::createPyObjectForStack(std::move(*stack)); - } else { - get_py_op_info_if_needed(); - py_local_results = checked_vectorcall( - get_dtensor_get_local_results_slow_path().ptr(), - py_op.ptr(), - args.ptr(), - py_op_info.ptr()); - } - - const auto& operator_name = op.operator_name(); - // Simple analysis of function schema to determine if this is an - // inplace variant. It might not be entirely correct, but it's good - // enough for now. - const bool is_inplace_op = - !operator_name.name.empty() && operator_name.name.back() == '_'; - // Simple analysis of function schema to determine if this is an - // ou variant. It might not be entirely correct, but it's good - // enough for now. - const bool is_out_variant_op = !is_inplace_op && - operator_name.overload_name.find("out") != std::string::npos; - - // Fast path for default or view ops. - const auto output_spec = - cached_sharding.attr(dtensor_interned_strings.output_spec); - if (!is_inplace_op && !is_out_variant_op && - !(output_spec.is_none() && - (op.operator_name().name == "aten::equal" && - is_default_overload(op.operator_name().overload_name)))) { - const auto wrap = get_dtensor_dispatcher_wrap(); - auto wrapped_result = checked_vectorcall( - wrap.ptr(), py_local_results.ptr(), output_spec.ptr()); - if (!participating) { - stack->clear(); - return wrapped_result; - } - - // Direct C++ implementation of return_and_correct_aliasing for view ops. - - // py::tuple's default constructor allocates a size-0 tuple, so we - // wrap in optional to get a detectable empty state. - std::optional wrapped_result_tuple; - if (PyTuple_Check(wrapped_result.ptr())) { - wrapped_result_tuple = py::reinterpret_borrow(wrapped_result); - } - const auto& returns = op.schema().returns(); - const auto num_arguments = op.schema().arguments().size(); - for (const auto arg_idx : c10::irange(num_arguments)) { - const auto& arg_schema = op.schema().arguments()[arg_idx]; - const auto* arg_alias_info = arg_schema.alias_info(); - if (!arg_alias_info || arg_alias_info->isWrite()) { - continue; - } - // If we ever get here, it's a view op. Therefore, it does not - // have mutable output aliases, so we skip that portion of - // return_and_correct_aliasing. Furthermore, we *only* want to - // return_and_correct_aliasing if it's a view op, so we do not - // need to port the mutable output aliases portion of - // return_and_correct_aliasing at all. - const c10::IValue& arg_iv = - saved_args.at(saved_args.size() - num_arguments + arg_idx); - if (!arg_iv.isTensor()) { - continue; - } - const auto& arg = arg_iv.toTensor(); - int ret_idx = 0; - for (const auto& ret_schema : returns) { - const auto* ret_alias_info = ret_schema.alias_info(); - if (!ret_alias_info) { - ret_idx++; - continue; - } - if (sets_intersect( - arg_alias_info->beforeSets(), ret_alias_info->beforeSets())) { - py::object ret; - if (wrapped_result_tuple.has_value()) { - ret = wrapped_result_tuple.value()[ret_idx]; - } else { - TORCH_INTERNAL_ASSERT(ret_idx == 0); - ret = wrapped_result; - } - if (PyList_Check(ret.ptr())) { - py::list ret_list = py::reinterpret_borrow(ret); - for (const auto& r : ret_list) { - auto tensor = py::cast(r); - functionalize_unsafe_set(tensor, arg); - } - } else { - auto tensor = py::cast(ret); - functionalize_unsafe_set(tensor, arg); - } - } - ret_idx++; - } - } - stack->clear(); - return wrapped_result; - } - - auto dispatch = get_dtensor_dispatch(); - auto result = checked_vectorcall( - dispatch.ptr(), - py_op.ptr(), - args.ptr(), - kwargs.ptr(), - compute_mesh.ptr(), - cached_sharding.ptr(), - py_local_results.ptr(), - participating ? Py_True : Py_False, - is_inplace_op ? Py_True : Py_False, - is_out_variant_op ? Py_True : Py_False); - stack->clear(); - return result; -} - // DTensor-specific variant of make_wrapper_subclass to minimize DTensor // overhead. static PyObject* THPVariable_dtensor_new( @@ -1685,44 +1008,27 @@ static PyObject* THPVariable_dtensor_new( END_HANDLE_TH_ERRORS } -struct NativeRuntimeSchemaInfo { - py::object static_kwargkey; - size_t static_argnum; -}; - -NativeRuntimeSchemaInfo unpack_runtime_schema_info( - py::handle runtime_schema_info, - size_t num_args) { - NativeRuntimeSchemaInfo result; - if (!runtime_schema_info) { - result.static_argnum = num_args; - } else { - result.static_argnum = py::cast( - runtime_schema_info.attr(dtensor_interned_strings.static_argnum)); - result.static_kwargkey = - runtime_schema_info.attr(dtensor_interned_strings.static_kwargkey); - TORCH_CHECK( - result.static_kwargkey.is_none() || - PyList_Check(result.static_kwargkey.ptr()), - "RuntimeSchemaInfo.static_kwargkey must be a list!"); - } - return result; -} - static bool DTensor_OpSchema_recompute_comparison_key_impl( PyObject* self, const py::tuple& args_schema) { + py::object static_kwargkey; + size_t static_argnum = 0; const py::handle self_handle = py::handle(self); - const auto schema_info = + const py::handle schema_info = self_handle.attr(dtensor_interned_strings.schema_info); - NativeRuntimeSchemaInfo native_info = unpack_runtime_schema_info( - checked_not(schema_info.ptr()) ? py::handle() : py::handle(schema_info), - args_schema.size()); + if (checked_not(schema_info.ptr())) { + static_argnum = args_schema.size(); + static_kwargkey = py::none(); + } else { + static_argnum = py::cast( + schema_info.attr(dtensor_interned_strings.static_argnum)); + static_kwargkey = + schema_info.attr(dtensor_interned_strings.static_kwargkey); + } c10::SmallVector args_to_hash; size_t idx = 0; for (const auto& e : args_schema) { - if (idx >= native_info.static_argnum || - arg_type_tensor_or_tensor_list_like(e)) { + if (idx >= static_argnum || arg_type_tensor_or_tensor_list_like(e)) { if (PyList_Check(e.ptr())) { args_to_hash.push_back( py::reinterpret_steal(PyList_AsTuple(e.ptr()))); @@ -1737,19 +1043,24 @@ static bool DTensor_OpSchema_recompute_comparison_key_impl( args_to_hash_tup[idx] = std::move(args_to_hash[idx]); } PyObject* comparison_key = nullptr; - if (native_info.static_kwargkey && !native_info.static_kwargkey.is_none()) { - py::list static_kwargkey = - py::reinterpret_borrow(native_info.static_kwargkey); + if (!static_kwargkey.is_none()) { + if (!PyList_Check(static_kwargkey.ptr())) { + PyErr_SetString( + PyExc_TypeError, "self.schema_info.static_kwargkey must be a list!"); + return false; + } + py::list static_kwargkey_list = + py::reinterpret_borrow(static_kwargkey); auto raw_kwargs_schema = self_handle.attr(dtensor_interned_strings.kwargs_schema); if (!PyDict_Check(raw_kwargs_schema.ptr())) { PyErr_SetString(PyExc_TypeError, "self.kwargs_schema must be a dict!"); return false; } - py::tuple kwargs_to_hash(static_kwargkey.size()); + py::tuple kwargs_to_hash(static_kwargkey_list.size()); int idx = 0; auto kwargs_schema = py::reinterpret_borrow(raw_kwargs_schema); - for (const auto& k : static_kwargkey) { + for (const auto& k : static_kwargkey_list) { PyObject* item = PyDict_GetItemWithError(kwargs_schema.ptr(), k.ptr()); if (item) { kwargs_to_hash[idx++] = py::reinterpret_borrow(item); @@ -1944,370 +1255,6 @@ static PyObject* DTensor_compute_global_tensor_info( END_HANDLE_TH_ERRORS } -enum class TensorFlavor { - NON_TENSOR, - EXACTLY_DTENSOR, - EXACTLY_TENSOR, - DTENSOR_SUBCLASS, - NON_DTENSOR_TENSOR_SUBCLASS, -}; - -static std::pair check_for_dtensor_or_tensor( - const at::Tensor& tensor) { - if (!tensor.defined()) { - return {TensorFlavor::NON_TENSOR, py::object()}; - } - - // I don't think we need to check for wrapped_number() tensors here; - // the try_replicate_spec_for_scalar_tensor stuff in our caller - // specifically handles 1-element tensors. - - torch::jit::guardAgainstNamedTensor(tensor); - auto py_tensor = py::cast(tensor); - - const auto dtensor = get_dtensor_class(); - auto* const obj_type = Py_TYPE(py_tensor.ptr()); - if (obj_type == (PyTypeObject*)dtensor.ptr()) { - return {TensorFlavor::EXACTLY_DTENSOR, std::move(py_tensor)}; - } - // Fast path for plain old Tensors. - if (THPVariable_CheckTypeExact(obj_type)) { - return {TensorFlavor::EXACTLY_TENSOR, std::move(py_tensor)}; - } - if (py::isinstance(py_tensor, dtensor)) { - return {TensorFlavor::DTENSOR_SUBCLASS, std::move(py_tensor)}; - } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - THPVariableClass && py::isinstance(py_tensor, THPVariableClass)); - return {TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS, std::move(py_tensor)}; -} - -static std::pair check_for_dtensor_or_tensor( - const c10::IValue& iv) { - if (!iv.isTensor()) { - return {TensorFlavor::NON_TENSOR, py::object()}; - } - - return check_for_dtensor_or_tensor(iv.toTensor()); -} - -static c10::List replace_dtensors_with_local_tensor( - const c10::List& tl) { - c10::List local_list(tl.elementType()); - local_list.reserve(tl.size()); - for (const auto& elt : tl) { - const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(elt); - if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || - tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { - local_list.push_back(THPVariable_Unpack( - py_tensor.attr(dtensor_interned_strings._local_tensor).ptr())); - } else { - local_list.push_back(elt); - } - } - return local_list; -} - -static void replace_dtensors_with_local_tensor(torch::jit::Stack& stack) { - for (auto& arg : stack) { - if (arg.isList()) { - arg = replace_dtensors_with_local_tensor(arg.toList()); - continue; - } - const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(arg); - if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || - tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { - arg = THPVariable_Unpack( - py_tensor.attr(dtensor_interned_strings._local_tensor).ptr()); - } - } -} - -static py::object try_find_mesh_from_args( - const c10::OperatorHandle& op, - const OperatorArgsKwargsView& args_kwargs) { - for (auto argument_it = args_kwargs.args_begin(); - argument_it != args_kwargs.args_end(); - ++argument_it) { - const auto [tensor_flavor, py_tensor] = - check_for_dtensor_or_tensor(*argument_it); - if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || - tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { - return py::reinterpret_borrow( - py_tensor.attr(dtensor_interned_strings.device_mesh)); - } - } - TORCH_CHECK_VALUE( - false, "Cannot find device mesh from args for op : ", op.operator_name()); -} - -static /*DTensorSpec*/ py::object try_replicate_spec_for_scalar_tensor( - bool allow_implicit_replication, - py::handle op_call, - py::handle py_tensor, - py::handle compute_mesh) { - const Tensor& tensor_arg = THPVariable_Unpack(py_tensor.ptr()); - const bool numel_is_one = tensor_arg.numel() == 1; - if (numel_is_one && tensor_arg.dim() == 1) { - TORCH_WARN( - "Found a non-scalar tensor with numel=1 and ndim!=0, " - "we are implicitly creating a replicated DTensor for it. " - "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed environment."); - } - - TORCH_CHECK( - numel_is_one || allow_implicit_replication, - py::str(op_call), - " got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!"); - - // scalar tensor can be safely treated as replicated. - const auto num_placements = - py::cast(compute_mesh.attr(dtensor_interned_strings.ndim)); - py::tuple placements_tuple(num_placements); - py::object replicate = get_replicate_class()(); - for (const auto idx : c10::irange(num_placements)) { - PyTuple_SET_ITEM( - placements_tuple.ptr(), - idx, - py::reinterpret_borrow(replicate).release().ptr()); - } - - return checked_vectorcall( - get_dtensor_spec_class().ptr(), - compute_mesh.ptr(), - placements_tuple.ptr(), - checked_vectorcall( - get_tensor_meta_class().ptr(), - py_tensor.attr(dtensor_interned_strings.shape).ptr(), - py_tensor.attr(dtensor_interned_strings.stride)().ptr(), - py_tensor.attr(dtensor_interned_strings.dtype).ptr()) - .ptr()); -} - -// May return unset object, in which case there was no runtime schema -// info. -static py::object get_runtime_schema_info_for_op(py::handle py_op) { - const auto op_dispatcher = get_dtensor_op_dispatcher(); - const auto sharding_propagator = - op_dispatcher.attr(dtensor_interned_strings.sharding_propagator); - const py::dict op_to_schema_info = py::reinterpret_borrow( - sharding_propagator.attr(dtensor_interned_strings.op_to_schema_info)); - - PyObject* runtime_schema_info = - PyDict_GetItemWithError(op_to_schema_info.ptr(), py_op.ptr()); - if (!runtime_schema_info && PyErr_Occurred()) { - throw py::error_already_set(); - } - return py::reinterpret_borrow(runtime_schema_info); -} - -static bool contains_any_symint(const py::tuple& tup) { - for (const auto& s : tup) { - if (THPUtils_checkLong(s.ptr())) { - continue; - } - if (torch::is_symint(s)) { - return true; - } - } - return false; -} - -static bool dtensor_spec_has_symints(py::handle spec) { - const auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta); - if (tensor_meta.is_none()) { - return false; - } - py::object raw_shape = tensor_meta.attr(dtensor_interned_strings.shape); - if (!PyTuple_Check(raw_shape.ptr())) { - PyErr_SetString(PyExc_TypeError, "TensorMeta.shape must be a tuple!"); - throw py::error_already_set(); - } - const auto shape = py::reinterpret_steal(raw_shape.release()); - return contains_any_symint(shape); -} - -static std::optional> -create_native_op_schema( - const c10::OperatorHandle& op, - py::handle py_op, - torch::jit::Stack* stack) { - // fused schema part of unwrap_to_op_info + recompute_comparison_key, - // operating on IValues instead of Python stuff. - - py::object runtime_schema_info = get_runtime_schema_info_for_op(py_op); - if (runtime_schema_info && - checked_istrue(py::handle(runtime_schema_info) - .attr(dtensor_interned_strings.needs_pytree) - .ptr())) { - // Punting on pytree flattening in the fast path on IValues for - // now since only a minority of ops need it. - return std::nullopt; - } - - OperatorArgsKwargsView args_kwargs(op, *stack); - auto native_info = unpack_runtime_schema_info( - py::handle(runtime_schema_info), args_kwargs.num_positional_args()); - - c10::SmallVector comparison_key; - std::size_t comparison_key_hash = 0; - - py::object compute_mesh = py::none(); - - const auto handle_non_dtensor_arg = - [&comparison_key, &comparison_key_hash, &native_info]( - size_t idx, c10::IValue arg) { - if (idx >= native_info.static_argnum) { - if (arg.isList()) { - const auto& list = arg.toList(); - if (list.empty()) { - arg = c10::ivalue::Tuple::create({}); - } else { - // WARNING: here we rely on c10::List being represented - // by a contiguous array of IValue for efficiency! - arg = c10::ivalue::Tuple::create(c10::ArrayRef( - &(*list.begin()).get(), list.size())); - } - } else if (arg.isTensor() && !arg.toTensor().defined()) { - // Coerce undefined Tensor to None, just as we do when - // converting IValues to PyObject. Otherwise comparison - // doesn't work. (undefined Tensors can get here because - // check_for_dtensor_or_tensor calls them non-Tensors, but - // doesn't have a way to do the coercion for us.) - arg = c10::IValue(); - } - comparison_key_hash = - c10::hash_combine(comparison_key_hash, c10::IValue::hash(arg)); - comparison_key.emplace_back(std::move(arg)); - } - }; - const auto handle_dtensor_arg = [&comparison_key, - &comparison_key_hash](py::object arg) { - comparison_key_hash = c10::hash_combine( - comparison_key_hash, static_cast(py::hash(arg))); - comparison_key.emplace_back(std::move(arg)); - }; - - Py_ssize_t idx = 0; - const bool allow_implicit_replication = - at::get_dtensor_allow_implicit_replication(); - for (auto argument_it = args_kwargs.args_begin(); - argument_it != args_kwargs.args_end(); - ++argument_it) { - const auto& arg = *argument_it; - const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(arg); - switch (tensor_flavor) { - case TensorFlavor::EXACTLY_DTENSOR: - case TensorFlavor::DTENSOR_SUBCLASS: { - py::object spec = py_tensor.attr(dtensor_interned_strings._spec); - if (dtensor_spec_has_symints(spec)) { - // Symints are unhashable, so we can't use the cache for - // sharding propagation. bail out to slow path. - return std::nullopt; - } - handle_dtensor_arg(std::move(spec)); - if (compute_mesh.is_none()) { - compute_mesh = py::reinterpret_borrow( - py_tensor.attr(dtensor_interned_strings.device_mesh)); - } - break; - } - case TensorFlavor::EXACTLY_TENSOR: - case TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS: { - if (compute_mesh.is_none()) { - compute_mesh = try_find_mesh_from_args(op, args_kwargs); - } - handle_dtensor_arg(try_replicate_spec_for_scalar_tensor( - allow_implicit_replication, py_op, py_tensor, compute_mesh)); - break; - } - case TensorFlavor::NON_TENSOR: { - // non DTensor/Tensor args (i.e. int/float/bool), just add to - // local_args - handle_non_dtensor_arg(idx, arg); - break; - } - default: - TORCH_INTERNAL_ASSERT(false, "can't happen"); - break; - } - idx++; - } - - TORCH_CHECK( - !compute_mesh.is_none(), - "found no DeviceMesh from dtensor args for ", - op.operator_name()); - - if (native_info.static_kwargkey && !native_info.static_kwargkey.is_none()) { - // Separator to disambiguate kwargs from args in comparison and hashing. - static constexpr int64_t kwargs_separator = 0x0011223344556677LL; - comparison_key.emplace_back(static_cast(kwargs_separator)); - comparison_key_hash = hash_combine(comparison_key_hash, kwargs_separator); - - for (auto argument_it = args_kwargs.kwargs_begin(); - argument_it != args_kwargs.kwargs_end(); - ++argument_it) { - // Rather than hash/compare the string key, we can just use the - // index of the kwarg in the schema! - const auto underlying_index = argument_it.underlying_index(); - comparison_key.emplace_back(c10::IValue(underlying_index)); - comparison_key_hash = hash_combine( - comparison_key_hash, c10::IValue::hash(comparison_key.back().iv)); - const auto [tensor_flavor, py_tensor] = - check_for_dtensor_or_tensor(*argument_it); - switch (tensor_flavor) { - case TensorFlavor::EXACTLY_DTENSOR: - case TensorFlavor::DTENSOR_SUBCLASS: { - handle_dtensor_arg(py_tensor.attr(dtensor_interned_strings._spec)); - break; - } - case TensorFlavor::EXACTLY_TENSOR: - case TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS: { - handle_dtensor_arg(try_replicate_spec_for_scalar_tensor( - allow_implicit_replication, py_op, py_tensor, compute_mesh)); - break; - } - case TensorFlavor::NON_TENSOR: { - handle_non_dtensor_arg(native_info.static_argnum, *argument_it); - break; - } - default: - TORCH_INTERNAL_ASSERT(false, "can't happen"); - break; - } - } - } - - return std::make_pair( - NativeOpSchema( - op, - std::move(comparison_key), - comparison_key_hash, - args_kwargs.num_positional_args()), - std::move(compute_mesh)); -} - -static PyObject* get_DTensor_sharding_propagator_cache_stats( - PyObject* self, - PyObject* noargs) { - HANDLE_TH_ERRORS - auto& cache = get_thread_local_native_sharding_propagator_cache(); - py::tuple result(2); - result[0] = cache.hits(); - result[1] = cache.misses(); - return result.release().ptr(); - END_HANDLE_TH_ERRORS -} - -static PyObject* clear_DTensor_sharding_propagator_cache( - PyObject* self, - PyObject* noargs) { - native_sharding_propagator_cache_DO_NOT_USE.reset(); - Py_RETURN_NONE; -} - using getter = PyObject* (*)(PyObject*, void*); using setter = int (*)(PyObject*, PyObject*, void*); @@ -3281,7 +2228,7 @@ static PyMethodDef extra_methods[] = { {nullptr}}; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) -static PyMethodDef extra_dtensor_functions[] = { +static PyMethodDef extra_functions[] = { {"_DTensor_OpSchema_post_init", DTensor_OpSchema_post_init, METH_O, @@ -3294,14 +2241,6 @@ static PyMethodDef extra_dtensor_functions[] = { castPyCFunctionFast(DTensor_compute_global_tensor_info), METH_FASTCALL, compute_global_tensor_info_doc}, - {"_get_DTensor_sharding_propagator_cache_stats", - get_DTensor_sharding_propagator_cache_stats, - METH_NOARGS, - nullptr}, - {"_clear_DTensor_sharding_propagator_cache", - clear_DTensor_sharding_propagator_cache, - METH_NOARGS, - nullptr}, {nullptr}}; struct THPVariableMeta { @@ -3662,22 +2601,13 @@ bool THPVariable_initModule(PyObject* module) { PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType); Py_INCREF(&THPVariableType); PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType); -#ifdef USE_DISTRIBUTED - PyModule_AddObject( - module, - "__DTensor_fastpath_cache_cleanup", - py::capsule( - []() { cleanup_thread_local_native_sharding_propagator_caches(); }) - .release() - .ptr()); - if (!intern_dtensor_strings()) { - return false; - } - PyModule_AddFunctions(module, extra_dtensor_functions); -#endif torch::autograd::initTorchFunctions(module); torch::autograd::initTensorImplConversion(module); torch::utils::validate_numpy_for_dlpack_deleter_bug(); + if (!intern_dtensor_strings()) { + return false; + } + PyModule_AddFunctions(module, extra_functions); return true; } diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 5b6f089990693..af733f2ad1769 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -90,15 +90,6 @@ void pushPyOutToStack( py::object out, const char* msg); -py::handle get_dtensor_class(); - -py::object dispatchDTensorOp( - const c10::OperatorHandle& op, - py::handle py_op, - py::handle args, - py::handle kwargs, - torch::jit::Stack* stack); - inline PyObject* THPVariable_WrapList( const torch::autograd::variable_list& inputs) { PyObject* pyinput = PyList_New(static_cast(inputs.size())); diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index c3525ac9c8a20..d866e4f434448 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -679,7 +679,7 @@ Value* emitBuiltinCall( at::ArrayRef args, at::ArrayRef kwargs, const std::optional& self) { - auto variants = getAllOperatorsFor(name); + const auto& variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); // first let's set the graph's version diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index 513258236ac4b..ac99385401be4 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -617,7 +617,7 @@ void AliasDb::analyzeImpl(Node* node) { oss << input->type()->str() << ", "; } oss << "\n\nCandidates:"; - auto candidates = getAllOperatorsFor(node->kind()); + const auto& candidates = getAllOperatorsFor(node->kind()); for (const auto& candidate : candidates) { oss << "\n\t" << candidate->schema(); } diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 9b00a703e352e..08bfe47382952 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1088,7 +1088,7 @@ const FunctionSchema* Node::maybeSchema() const { const Operator* Node::maybeOperator() const { if (!op_) { - auto candidates = getAllOperatorsFor(kind()); + const auto& candidates = getAllOperatorsFor(kind()); for (const auto& candidate : candidates) { if (matches(candidate->schema())) { op_ = candidate.get(); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7f16a7dc5a04..8dc4cb7ac9349 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1693,7 +1693,7 @@ void initJITBindings(PyObject* module) { [](const std::string& op_name, const std::string& overload_name) { try { auto symbol = Symbol::fromQualString(op_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { return op->schema(); @@ -1714,7 +1714,7 @@ void initJITBindings(PyObject* module) { const std::string& overload_name) -> std::optional { try { auto symbol = Symbol::fromQualString(op_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { @@ -2138,7 +2138,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - auto operations = getAllOperatorsFor(symbol); + const auto& operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 6f9dec70cddc9..35dead2a395c9 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -53,16 +53,6 @@ struct OperatorRegistry { to_register.clear(); } - const std::vector>& getOperatorsWithLockHeld( - Symbol name) { - registerPendingOperators(); - static std::vector> empty; - auto it = operators.find(name); - if (it != operators.end()) - return it->second; - return empty; - } - public: void registerOperator(Operator&& op) { std::lock_guard guard(lock); @@ -153,35 +143,14 @@ struct OperatorRegistry { return it->second; } - // This function returns internal lock-protected state. We need to - // copy it to avoid race conditions. - std::vector> getOperators(Symbol name) { + const std::vector>& getOperators(Symbol name) { std::lock_guard guard(lock); - return getOperatorsWithLockHeld(name); - } - - std::vector> getSortedOperators(Symbol name) { - std::lock_guard guard(lock); - const auto& unsortedOps = getOperatorsWithLockHeld(name); - // Depending on the order of registration, aten or jit ops may be - // registered first. This sorting is helpful in cases where - // deterministic (i.e. not dependent on build config) behavior is - // desired; e.g. torch.ops.aten.* uses this function, and tries to - // find the "first" op that matches input args. Without the sorting, - // the "first" op may change depending on registration order. - std::vector> sortedOps; - sortedOps.reserve(unsortedOps.size()); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return op->isC10Op(); }); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return !op->isC10Op(); }); - return sortedOps; + registerPendingOperators(); + static std::vector> empty; + auto it = operators.find(name); + if (it != operators.end()) + return it->second; + return empty; } std::vector findSimilarOperators(Symbol input_op) { @@ -418,16 +387,35 @@ void deregisterOperator(const FunctionSchema& schema) { getRegistry().deregisterOperator(schema); } -std::vector> getAllOperators() { +const std::vector> getAllOperators() { return getRegistry().getAllOperators(); } -std::vector> getAllOperatorsFor(Symbol name) { +const std::vector>& getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } std::vector> getAllSortedOperatorsFor(Symbol name) { - return getRegistry().getSortedOperators(name); + const auto& unsortedOps = getAllOperatorsFor(name); + // Depending on the order of registration, aten or jit ops may be + // registered first. This sorting is helpful in cases where + // deterministic (i.e. not dependent on build config) behavior is + // desired; e.g. torch.ops.aten.* uses this function, and tries to + // find the "first" op that matches input args. Without the sorting, + // the "first" op may change depending on registration order. + std::vector> sortedOps; + sortedOps.reserve(unsortedOps.size()); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return op->isC10Op(); }); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return !op->isC10Op(); }); + return sortedOps; } std::shared_ptr findOperatorFor(const c10::OperatorName& full_name) { diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 6b6972deeebf0..bde3825f5ea38 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -260,9 +260,8 @@ struct TORCH_API Operator { TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); -TORCH_API std::vector> getAllOperators(); -// This function returns a copy for thread safety. -TORCH_API std::vector> getAllOperatorsFor( +TORCH_API const std::vector> getAllOperators(); +TORCH_API const std::vector>& getAllOperatorsFor( Symbol name); // Returns operators in the order which OpOverloadPacket resolves them. TORCH_API std::vector> getAllSortedOperatorsFor( diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index b1f0f410f14fe..74f87e46757ea 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -79,7 +79,7 @@ auto compilation_unit = std::make_shared(); const std::optional getInplaceVariant( const FunctionSchema& base_schema) { - auto inplace_variants = + auto& inplace_variants = getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_")); for (const auto& variant : inplace_variants) { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index e89f7887320a0..0a046523127d5 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -13,7 +12,6 @@ #include #include #include -#include #include #include @@ -303,16 +301,6 @@ static py::object maybe_get_registered_torch_dispatch_rule( return result; } -static bool is_dtensor(PyObject* obj) { -#ifdef USE_DISTRIBUTED - const py::handle dtensor = get_dtensor_class(); - return (PyObject*)Py_TYPE(obj) == dtensor.ptr() || - py::isinstance(py::handle(obj), dtensor); -#else - return false; -#endif -} - // NB: Invariant: if you run this function, you MUST test if the returned // py::object is nullptr, as this will occur WITHOUT error condition being set. // And if an error happens, this function is responsible for throwing a C++ @@ -325,8 +313,8 @@ static py::object dispatch_on_subclass( PyObject* torch_api_function, bool is_torch_function, const char* torch_function_name_str, - const c10::OperatorHandle* opt_op, - torch::jit::Stack* opt_stack) { + std::optional maybe_mode_key = + std::nullopt) { py::object ret; for (auto& arg : overloaded_args) { py::object torch_function = @@ -379,39 +367,13 @@ static py::object dispatch_on_subclass( } } - if (!is_torch_function && is_dtensor(arg)) { - if (opt_op && opt_stack) { - ret = dispatchDTensorOp( - *opt_op, torch_api_function, args, kwargs, opt_stack); - } else { - // Slow path -- reconstruct C++ data structures since they were not - // provided. - auto schema = py::cast( - py::handle(torch_api_function).attr("_schema")); - auto opt_op_handle = - c10::Dispatcher::singleton().findOp(schema.operator_name()); - TORCH_CHECK( - opt_op_handle.has_value(), - "could not look up op for ", - schema.operator_name()); - const auto& op_handle = *opt_op_handle; - auto stack = torch::jit::createStackForSchema( - op_handle.schema(), - py::reinterpret_borrow(args), - py::reinterpret_borrow(kwargs), - std::nullopt); - ret = dispatchDTensorOp( - op_handle, torch_api_function, args, kwargs, &stack); - } - } else { - ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( - torch_function.ptr(), - torch_api_function, - py_types.ptr(), - args, - kwargs, - NULL)); - } + ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( + torch_function.ptr(), + torch_api_function, + py_types.ptr(), + args, + kwargs, + NULL)); if (ret.ptr() == nullptr) { throw python_error(); } @@ -518,28 +480,6 @@ auto handle_torch_function_no_python_arg_parser( PyObject* torch_api_function, const char* module_name, TorchFunctionName torch_function_name) -> PyObject* { - return handle_torch_function_no_python_arg_parser( - overloaded_args, - args, - kwargs, - func_name, - torch_api_function, - module_name, - nullptr, - nullptr, - torch_function_name); -} - -auto handle_torch_function_no_python_arg_parser( - at::ArrayRef overloaded_args, - PyObject* args, - PyObject* kwargs, - const char* func_name, - PyObject* torch_api_function, - const char* module_name, - const c10::OperatorHandle* opt_op, - torch::jit::Stack* opt_stack, - TorchFunctionName torch_function_name) -> PyObject* { const char* torch_function_name_str = nullptr; switch (torch_function_name) { case TorchFunctionName::TorchFunction: @@ -639,9 +579,7 @@ auto handle_torch_function_no_python_arg_parser( py_types, torch_api_function, is_torch_function, - torch_function_name_str, - opt_op, - opt_stack); + torch_function_name_str); if (curr_ret.ptr() != nullptr) { ret = curr_ret; } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 4a73a21916776..3ee12f14528e2 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -1287,18 +1287,6 @@ auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser( TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) -> PyObject*; -auto handle_torch_function_no_python_arg_parser( - at::ArrayRef overloaded_args, - PyObject* args, - PyObject* kwargs, - const char* func_name, - PyObject* torch_api_function, - const char* module_name, - const c10::OperatorHandle* opt_op, - torch::jit::Stack* opt_stack, - TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) - -> PyObject*; - // Used for getters of Tensor properties auto handle_torch_function_getter( THPVariable* self, diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 819e16ca99698..59692d9237b66 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -391,6 +391,7 @@ def __init__(self) -> None: # Weak references to the topmost AC module currently active self._ac_mod: Optional[weakref.ref] = None self._orig_resize = torch.UntypedStorage.resize_ + self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch self._depth = 0 def _update_snap( diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index fb072d8dce629..a6b6e39511974 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates -import copy import inspect import warnings from collections.abc import Callable, Sequence @@ -97,23 +96,16 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) tensor_stride = tuple(tensor_stride) grad_placements = grad_placements or dtensor_spec.placements - if ( - tensor_stride == dtensor_meta.stride - and grad_placements == dtensor_spec.placements - ): - # Avoid actual sharing of specs in case they're modified during (e.g.) - # sharding propagation. - grad_spec = copy.copy(dtensor_spec) - else: - grad_spec = DTensorSpec( - mesh, - grad_placements, - tensor_meta=TensorMeta( - shape=dtensor_meta.shape, - stride=tensor_stride, - dtype=dtensor_meta.dtype, - ), - ) + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + return ( # pyrefly: ignore [bad-argument-type] DTensor( @@ -346,14 +338,14 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): ) @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] - # We just need to have an implementation here; the __torch_dispatch__ machinery - # calls into a specific C++ fast path that doesn't call here. - # See #167051 for details - # python_arg_parser.cpp: dispatch_on_subclass() - # -> python_variable.cpp: dispatchDTensorOp() - raise NotImplementedError( - "DTensor.__torch_dispatch__ should not actually get called" + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, ) @staticmethod diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index f52538c0cf368..b883c954de3b6 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -12,12 +12,7 @@ from torch._library.utils import fill_defaults from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import ( - OpInfo, - OpSchema, - OutputSharding, - OutputSpecType, -) +from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType from torch.distributed.tensor._random import is_rng_supported_mesh from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor._sharding_prop import ShardingPropagator @@ -130,8 +125,6 @@ class OpDispatcher: def __init__(self) -> None: self.sharding_propagator = ShardingPropagator() - # NOTE: must stay in sync with is_random_op in - # torch/csrc/autograd/python_variable.cpp self._random_ops = { aten.native_dropout.default, aten.normal_.default, @@ -154,17 +147,6 @@ def __init__(self) -> None: aten.as_strided.default: as_strided_handler, } - # ******************************************************************************************** - # def dispatch(...) - # - # NOTE: this class no longer contains the top-level dispatch entrypoint! - # See #167051 for details - # - # The entrypoint has been moved to C++, and it handles common cases and then calls back into - # OpDispatcher python to handle corner cases. - # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp - # ******************************************************************************************** - # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave @@ -177,17 +159,26 @@ def _allow_implicit_replication(self) -> bool: def _allow_implicit_replication(self, value: bool) -> None: return torch._C._set_dtensor_allow_implicit_replication(value) - def _propagate_op_sharding_non_cached_dispatch_slow_path( + def dispatch( self, op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object], - op_info: OpInfo, ) -> object: + """ + Main dispatching logic. Follows precedence order: + (1) custom_op_handler + (2) registered sharding strategy, then rule + (3) composite implicit autograd decomposition + """ + if op_call in self._custom_op_handlers: + return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] + + # extract local tensor and sharding infos to a OpInfo + op_info = self.unwrap_to_op_info(op_call, args, kwargs) + try: - return self.sharding_propagator.propagate_op_sharding_non_cached( - op_info.schema - ) + self.sharding_propagator.propagate(op_info) except NotImplementedError: if torch._C._dispatch_has_kernel_for_dispatch_key( op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd @@ -204,12 +195,6 @@ def _propagate_op_sharding_non_cached_dispatch_slow_path( f"{e}\n\nSharding propagation failed for {op_info.schema}" ) from e - def _dispatch_get_local_results_slow_path( - self, - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - op_info: OpInfo, - ) -> object: output_sharding = op_info.output_sharding assert output_sharding is not None, "output sharding should not be None" @@ -281,7 +266,7 @@ def _dispatch_get_local_results_slow_path( # 2. if the return type is Tensor or List[Tensor], return empty # tensor(s) with correct dtype. spec = output_sharding.output_spec - ret_list = op_call._schema.returns + ret_list = op_info.schema.op._schema.returns if spec is None: # For a scalar return type, the non-participating device has None @@ -316,23 +301,6 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: raise NotImplementedError( f"return type {ret_type} in DTensor op is not supported" ) - return local_results - - def _dispatch_fast_path_python_tail( - self, - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - kwargs: dict[str, object], - compute_mesh: DeviceMesh, - output_sharding: OutputSharding, - local_results: object, - participating: bool, - is_inplace_op: bool, - is_out_variant_op: bool, - ) -> object: - """ - Tail of main dispatching logic, called from C++ fast path. - """ if output_sharding.output_spec is None: if op_call == aten.equal.default: @@ -342,12 +310,12 @@ def _dispatch_fast_path_python_tail( assert local_results is None or isinstance(local_results, bool) r = torch.tensor( int(local_results) if local_results is not None else 1, - device=compute_mesh.device_type, + device=mesh.device_type, ) dist.all_reduce(r, op=dist.ReduceOp.MIN) local_results = bool(r.item()) - if is_inplace_op: + if op_info.schema.is_inplace_op(): # inplace op should return self instead of re-wrapping if output_sharding.output_spec is not None: output_spec = output_sharding.output_spec @@ -381,7 +349,7 @@ def _dispatch_fast_path_python_tail( return args[0] else: return None - elif is_out_variant_op: + elif op_info.schema.is_out_variant_op(): # out variant could possibly have multiple out args (i.e. lu_unpack.out) output_specs = ( (output_sharding.output_spec,) @@ -400,9 +368,8 @@ def _dispatch_fast_path_python_tail( assert len(out_dts) >= 1, "out variant should have at least one out arg" return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] else: - assert op_call == aten.equal.default, op_call ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] - if participating and op_call._schema._is_view_op(): + if participating and op_info.schema.is_view_op(): return return_and_correct_aliasing(op_call, args, kwargs, ret) else: return ret @@ -469,15 +436,6 @@ def unwrap_to_op_info( op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object], - ) -> OpInfo: - return self._unwrap_to_op_info_impl(op_call, args, kwargs, True) - - def _unwrap_to_op_info_impl( - self, - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - kwargs: dict[str, object], - create_schema: bool, ) -> OpInfo: # get runtime schema info to determine whether to use pytree to flatten inputs runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( @@ -554,9 +512,7 @@ def _unwrap_to_op_info_impl( ), kwargs_schema, schema_info=runtime_schema_info, - ) - if create_schema - else None, # type: ignore[arg-type] + ), args_schema, tuple(local_args), local_kwargs, diff --git a/torch/distributed/tensor/debug/__init__.py b/torch/distributed/tensor/debug/__init__.py index e6aeca3b93a12..a74f1449ad125 100644 --- a/torch/distributed/tensor/debug/__init__.py +++ b/torch/distributed/tensor/debug/__init__.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -import torch._C from torch.distributed.tensor.debug._comm_mode import CommDebugMode from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding @@ -7,12 +6,11 @@ __all__ = ["CommDebugMode", "visualize_sharding"] -def _get_python_sharding_prop_cache_info(): +def _get_sharding_prop_cache_info(): """ - Get the cache info for the Python sharding propagation cache, used for debugging purpose only. + Get the cache info for the sharding propagation cache, used for debugging purpose only. This would return a named tuple showing hits, misses, maxsize and cursize of the sharding - propagator cache. Note that directly calling into the sharding propagator does not share cache - state with the DTensor dispatch fast path! + propagator cache. """ from torch.distributed.tensor._api import DTensor @@ -21,17 +19,9 @@ def _get_python_sharding_prop_cache_info(): ) -def _get_fast_path_sharding_prop_cache_stats(): +def _clear_sharding_prop_cache(): """ - Get a tuple (hits, misses) for the fast path sharding propagation cache, used for debugging - only. - """ - return torch._C._get_DTensor_sharding_propagator_cache_stats() - - -def _clear_python_sharding_prop_cache(): - """ - Clears the cache for the Python sharding propagation cache, used for debugging purpose only. + Clears the cache for the sharding propagation cache, used for debugging purpose only. """ from torch.distributed.tensor._api import DTensor @@ -40,13 +30,6 @@ def _clear_python_sharding_prop_cache(): ) -def _clear_fast_path_sharding_prop_cache(): - """ - Clears the cache for the fast path sharding propagation cache, used for debugging purpose only. - """ - torch._C._clear_DTensor_sharding_propagator_cache() - - # Set namespace for exposed private names CommDebugMode.__module__ = "torch.distributed.tensor.debug" visualize_sharding.__module__ = "torch.distributed.tensor.debug" From d3ccb8f3d0b21f8cc33299c2b7441a4d1fbb83f6 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 21 Nov 2025 03:17:06 +0000 Subject: [PATCH 145/230] Remove c10::is_pod (#166383) `c10::is_pod` is not used in OSS. New code should instead use `std::is_trivial`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166383 Approved by: https://github.com/albanD --- c10/util/C++17.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 5dafb245f92e8..8e88a1ec50cc4 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -34,15 +34,6 @@ namespace c10 { -// std::is_pod is deprecated in C++20, std::is_standard_layout and -// std::is_trivial are introduced in C++11, std::conjunction has been introduced -// in C++17. -template -using is_pod = std::conjunction, std::is_trivial>; - -template -constexpr bool is_pod_v = is_pod::value; - namespace guts { #if defined(__HIP__) From 056d2635c696b536c8a3d4092a7983f53100291a Mon Sep 17 00:00:00 2001 From: Rob Timpe Date: Thu, 20 Nov 2025 22:33:00 +0000 Subject: [PATCH 146/230] Update numpy tests for python 3.11/3.12 (#168299) This fixes the dynamo-unittest workflow. This test was fixed in https://github.com/pytorch/pytorch/pull/167619 but wasn't caught because these tests are skipped for numpy versions > 2.0 (meaning our CI for 3.13 doesn't run them). As a separate task, we should update all of these tests for recent numpy versions so we can start running them again. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168299 Approved by: https://github.com/williamwen42 --- .../TestArrayCreationCopyArgument.test_striding_not_ok | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok diff --git a/test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok b/test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 65b9892f5c1284bf2130a48b709efb8ad40a35ed Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 21 Nov 2025 04:22:14 +0000 Subject: [PATCH 147/230] Replace string with char for output (#168215) This PR replaces `"\n"` with `'\n'` in assertion macros for constructing error messages. The possible violating instances were searched by ``` grep '"[\]n"' -r aten c10 torch | grep -v py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168215 Approved by: https://github.com/Skylion007 --- aten/src/ATen/code_template.h | 2 +- aten/src/ATen/miopen/Descriptors.cpp | 4 ++-- aten/src/ATen/native/cudnn/Conv_v7.cpp | 2 +- aten/src/ATen/test/vec_test_all_types.h | 20 ++++++++++---------- torch/csrc/distributed/c10d/reducer.cpp | 2 +- torch/csrc/jit/codegen/onednn/kernel.cpp | 4 ++-- torch/csrc/jit/ir/subgraph_matcher.cpp | 4 ++-- torch/csrc/jit/jit_log.h | 4 ++-- torch/csrc/jit/tensorexpr/block_codegen.cpp | 2 +- torch/csrc/jit/tensorexpr/codegen.cpp | 2 +- torch/csrc/jit/tensorexpr/kernel.cpp | 4 ++-- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 2 +- 12 files changed, 26 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/code_template.h b/aten/src/ATen/code_template.h index 2cde802dac172..edc1124240251 100644 --- a/aten/src/ATen/code_template.h +++ b/aten/src/ATen/code_template.h @@ -232,7 +232,7 @@ struct CodeTemplate { emitIndent(out, indent); emitStringWithIndents(out, indent, strings[i]); if (i + 1 != strings.size()) - out << "\n"; + out << '\n'; } } std::string template_text; diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp index 3fe27c7a0825b..7b3d790f067b5 100644 --- a/aten/src/ATen/miopen/Descriptors.cpp +++ b/aten/src/ATen/miopen/Descriptors.cpp @@ -114,8 +114,8 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo // that is the common case, so we can catch most client errors with this test. TORCH_CHECK(t.is_contiguous(memory_format), "MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n", - "Weight sizes: ", t.sizes(), "\n", - "Weight strides: ", t.strides(), "\n", + "Weight sizes: ", t.sizes(), '\n', + "Weight strides: ", t.strides(), '\n', "cuDNN suggested memory_format: ", memory_format); int size[MIOPEN_DIM_MAX]; diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index d5102910c6471..74a3f0afb9c11 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -774,7 +774,7 @@ void raw_cudnn_convolution_forward_out_32bit( args, "Forward algorithm: ", static_cast(fwdAlgPerf.algo), - "\n"); + '\n'); }); } diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7206cc340973..7ee77f53d5377 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -843,22 +843,22 @@ class AssertVectorized std::stringstream stream; stream.precision(std::numeric_limits::max_digits10); stream << "Failure Details:\n"; - stream << additionalInfo << "\n"; + stream << additionalInfo << '\n'; if (hasSeed) { - stream << "Test Seed to reproduce: " << testSeed << "\n"; + stream << "Test Seed to reproduce: " << testSeed << '\n'; } if (argSize > 0) { stream << "Arguments:\n"; - stream << "#\t " << arg0 << "\n"; + stream << "#\t " << arg0 << '\n'; if (argSize == 2) { - stream << "#\t " << arg1 << "\n"; + stream << "#\t " << arg1 << '\n'; } if (argSize == 3) { - stream << "#\t " << arg2 << "\n"; + stream << "#\t " << arg2 << '\n'; } } stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act; @@ -890,7 +890,7 @@ class AssertVectorized else if (checkWithTolerance) { for (const auto i : c10::irange(sizeX)) { - EXPECT_EQ(nearlyEqual(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << "\n" << getDetail(i / unitStorageCount); + EXPECT_EQ(nearlyEqual(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << '\n' << getDetail(i / unitStorageCount); if (::testing::Test::HasFailure()) return true; } @@ -1116,11 +1116,11 @@ void test_binary_fp8( if (is_bit_wise) { EXPECT_EQ(static_cast(ref_res_scalar), static_cast(res_scalar)) << "Test failed for input0: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << '\n'; } else { EXPECT_EQ(ref_res_scalar, res_scalar) << "Test failed for input0: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << '\n'; } } else { at::vec::cvtfp8e5m2_fp32(_mm512_castsi512_si128(res), res_fp32_512); @@ -1128,11 +1128,11 @@ void test_binary_fp8( if (is_bit_wise) { EXPECT_EQ(static_cast(ref_res_scalar), static_cast(res_scalar)) << "Test failed for input0: " << c10::detail::fp8e5m2_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << '\n'; } else { EXPECT_EQ(ref_res_scalar, res_scalar) << "Test failed for input0: " << c10::detail::fp8e5m2_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << '\n'; } } } diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index a1c9b4a3039d5..c4af19ef44209 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -341,7 +341,7 @@ void Reducer::check_grad_layout( grad.sizes(), ", strides() = ", grad.strides(), - "\n", + '\n', "bucket_view.sizes() = ", bucket_view.sizes(), ", strides() = ", diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index c5421643e8c43..85afc5fa8dc7b 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -28,7 +28,7 @@ LlgaKernel::LlgaKernel(const Node* fusionNode) partition_ = partitions[0]; nPartitionInputs_ = partition_.get_input_ports().size(); #ifdef GRAPH_DEBUG_ENABLED - GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString()); + GRAPH_DEBUG("Initialized ", debugName(), '\n', graph_->toString()); #endif } @@ -243,7 +243,7 @@ compiled_partition LlgaKernel::compile(const partition& partition) { void LlgaKernel::run(Stack& stack) { #ifdef GRAPH_DEBUG_ENABLED - GRAPH_DEBUG("In ", debugName(), "\n"); + GRAPH_DEBUG("In ", debugName(), '\n'); #endif // Grab input values from stack diff --git a/torch/csrc/jit/ir/subgraph_matcher.cpp b/torch/csrc/jit/ir/subgraph_matcher.cpp index 17a82dc4ac6c3..37dd8e3280de9 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.cpp +++ b/torch/csrc/jit/ir/subgraph_matcher.cpp @@ -272,8 +272,8 @@ bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) { if (!endsWith(real_typename, pattern_typename)) { GRAPH_DEBUG( "Nodes did not match because expected module type is different:\n"); - GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); - GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); + GRAPH_DEBUG(" actualtype: ", real_typename, '\n'); + GRAPH_DEBUG(" expected type: ", pattern_typename, '\n'); GRAPH_DEBUG("Nodes:", *n1, *n2); return false; } diff --git a/torch/csrc/jit/jit_log.h b/torch/csrc/jit/jit_log.h index 49851e81082a7..333c3edcdd01f 100644 --- a/torch/csrc/jit/jit_log.h +++ b/torch/csrc/jit/jit_log.h @@ -95,12 +95,12 @@ TORCH_API std::ostream& operator<<( JIT_LOG( \ ::torch::jit::JitLoggingLevels::GRAPH_DUMP, \ MSG, \ - "\n", \ + '\n', \ ::torch::jit::log_function(G)); // use GRAPH_DUMP for dumping graphs after optimization passes #define GRAPH_DUMP(MSG, G) \ JIT_LOG( \ - ::torch::jit::JitLoggingLevels::GRAPH_DUMP, MSG, "\n", (G)->toString()); + ::torch::jit::JitLoggingLevels::GRAPH_DUMP, MSG, '\n', (G)->toString()); // use GRAPH_UPDATE for reporting graph transformations (i.e. node deletion, // constant folding, CSE) #define GRAPH_UPDATE(...) \ diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 6ec55f998cce0..99dd289fb0964 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -351,7 +351,7 @@ void BlockCodeGen::Initialize() { stmt_v->accept(printer_.get()); - GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n"); + GRAPH_DEBUG("Generated Block code: ", oss_.str(), '\n'); } void BlockCodeGen::call(const std::vector& args) { diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index b19a8b8964ad5..04034554e25ed 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -320,7 +320,7 @@ void CodeGen::allocIntermediateBufs() { set_stmt(stmt_new); } - GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), "\n"); + GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), '\n'); } } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index cc15663720383..d696d29bf733e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -795,7 +795,7 @@ static void parallelizeOuterLoops(LoopNest& l, const Bufs& bufs) { StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { torch::jit::tensorexpr::LoopNest l(std::move(st), bufOutputs_); LoopNest::sanitizeNames(l.root_stmt()); - GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); + GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), '\n'); int64_t random_tr_seed = randomTransformsRequested(); if (random_tr_seed) { if (random_tr_seed == -1) @@ -939,7 +939,7 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { StmtPtr stmt = l.root_stmt(); // Arithmetic Simplification. stmt = IRSimplifier::simplify(stmt); - GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n"); + GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), '\n'); return stmt; } diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 918d82579444f..17db0872eb78f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -777,7 +777,7 @@ void LLVMCodeGenImpl::emitKernel( PM.run(*module_); asmCode_ = asmStream.str().str(); - GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n"); + GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, '\n'); } // TODO: The binary ops are copypaste. From 6707dc8e444de405401f2e36e2868a19bb7ba43a Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Fri, 21 Nov 2025 05:06:40 +0000 Subject: [PATCH 148/230] Revert #154859 (#168297) We suspect it's causing intermittent segfaults Pull Request resolved: https://github.com/pytorch/pytorch/pull/168297 Approved by: https://github.com/malfet --- test/profiler/test_execution_trace.py | 7 +++ .../standalone/execution_trace_observer.cpp | 59 +------------------ 2 files changed, 10 insertions(+), 56 deletions(-) diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index dbd5d89ad6a61..26c0ab42905de 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -2,6 +2,7 @@ import json import os +import sys import tempfile import unittest from typing import Any @@ -364,6 +365,9 @@ def test_execution_trace_env_disabled(self, device): self.assertTrue(p.execution_trace_observer is None) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", @@ -419,6 +423,9 @@ def fn(a, b, c): assert found_call_compiled_fx_graph @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 29b2b94af4472..b46e1d19bcd0e 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -112,59 +112,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT std::map> opStack; // Uses the underlying TensorImpl object pointer as the key and map to its // unique id. - std::map objectId; - - using weak_storage_ptr = c10::weak_intrusive_ptr; - std::unordered_map data_ptr_to_storage_id; - std::unordered_map - data_ptr_to_weak_storage_ptr; - - ID get_tensor_storage_ID(const c10::Storage& t_storage) { - const std::lock_guard lock(gMutex); - - const void* raw_data_ptr = nullptr; - bool should_track_liveness = false; - // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use - // a nullptr data pointer. Treat both cases as a shared cache key but avoid - // touching the weak-ref table so they can reuse the same ID without - // tripping the liveness check. - if (t_storage.unsafeGetStorageImpl()) { - raw_data_ptr = t_storage.data(); - should_track_liveness = raw_data_ptr != nullptr; - } - - auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - if (!should_track_liveness) { - if (id_iter != data_ptr_to_storage_id.end()) { - return id_iter->second; - } - ID id = storage_id_++; - data_ptr_to_storage_id.emplace(raw_data_ptr, id); - return id; - } - - auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); - if (weak_iter == data_ptr_to_weak_storage_ptr.end()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.emplace( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - if (weak_iter->second.expired()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.insert_or_assign( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end()); - return id_iter->second; - } + std::map objectId{}; // Observer run state. enum class RunState { uninitialized, disabled, enabled }; @@ -227,8 +176,6 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // 1 -> root ID // 2 ... -> regular node ID std::atomic id_{2}; - - std::atomic storage_id_{1}; }; // Using a singleton manager here to allow init and delete the observer object. @@ -499,8 +446,8 @@ convertIValue( // symbolic sizes/strides implies t->storage_offset() will fail if (tensor_impl->has_storage() && !tensor_impl->has_symbolic_sizes_strides()) { - const c10::Storage& t_storage = tensor_impl->storage(); - storage_id = ob.get_tensor_storage_ID(t_storage); + auto& t_storage = tensor_impl->storage(); + storage_id = getObjectID(ob, t_storage.data()); offset = tensor_impl->storage_offset(); numel = tensor_impl->numel(); itemsize = tensor_impl->itemsize(); From b026eb96cac17524c221d9967fe6727bec5db073 Mon Sep 17 00:00:00 2001 From: Bartlomiej Stemborowski Date: Fri, 21 Nov 2025 05:27:34 +0000 Subject: [PATCH 149/230] Fix EmbeddingBag when input is 2D and include_last_offset is True (#168159) Before this change, EmbeddingBag generated incorrect offsets when the input indices were 2D and include_last_offset was set to True, resulting in one fewer bag than expected. This PR also aligns the documentation for the include_last_offset parameter between nn.EmbeddingBag and nn.functional.embedding_bag. Fixes #167974 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168159 Approved by: https://github.com/ezyang, https://github.com/malfet --- test/nn/test_embedding.py | 8 ++++++++ torch/nn/functional.py | 6 ++++-- torch/nn/modules/sparse.py | 6 ++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index e1411f3101e22..ae0f137e152ce 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -285,6 +285,14 @@ def test_embeddingbag_include_last_offset(self): self.assertEqual(ref_out, out) self.assertEqual(ref_out, out2) + def test_embeddingbag_2d_include_last_offset(self): + # Test case from https://github.com/pytorch/pytorch/issues/167974 + embedding_sum = torch.nn.EmbeddingBag(10, 3, include_last_offset=True) + input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long) + res = embedding_sum(input) + # Check if number of bags matches + self.assertTrue(res.shape[0] == input.shape[0]) + class TestEmbeddingNNDeviceType(NNTestCase): def test_embedding_dense_grad(self, device): diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 0ee7b0f964fee..d31b99a59de21 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2613,7 +2613,9 @@ def embedding_bag( :attr:`offsets`, if those are not None. include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. - The last element is the size of the input, or the ending index position of the last bag (sequence). + The last element is the size of the input, or the ending index position + of the last bag (sequence). This matches the CSR format. Ignored when + input is 2D. Default ``False``. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated @@ -2724,7 +2726,7 @@ def embedding_bag( offsets = torch.arange( 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device ) - + include_last_offset = False input = input.reshape(-1) if per_sample_weights is not None: per_sample_weights = per_sample_weights.reshape(-1) diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index e3b8fafa6a274..83a8d6ef334bb 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -304,8 +304,10 @@ class EmbeddingBag(Module): sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when ``mode="max"``. - include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element - is equivalent to the size of `indices`. This matches the CSR format. + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position + of the last bag (sequence). This matches the CSR format. Ignored when + input is 2D. Default ``False``. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". For a newly constructed From 61cdf872183659c72c2466158c15e9bfafc20678 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 21 Nov 2025 05:31:33 +0000 Subject: [PATCH 150/230] dist: add list_keys to Store API (#167883) This adds a `list` Store API and implements it for all backends. This is intended to be used for debugging and will allow inspecting all keys in a store locally as well as remotely in the case of TCPStore. Test plan: ``` pytest test/distributed/test_store.py ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167883 Approved by: https://github.com/fduwjj --- test/distributed/test_store.py | 8 ++++++ torch/_C/_distributed_c10d.pyi | 1 + torch/csrc/distributed/c10d/FileStore.cpp | 13 +++++++++ torch/csrc/distributed/c10d/FileStore.hpp | 2 ++ torch/csrc/distributed/c10d/HashStore.cpp | 10 +++++++ torch/csrc/distributed/c10d/HashStore.hpp | 2 ++ torch/csrc/distributed/c10d/PrefixStore.cpp | 14 ++++++++++ torch/csrc/distributed/c10d/PrefixStore.hpp | 2 ++ torch/csrc/distributed/c10d/Store.hpp | 5 ++++ torch/csrc/distributed/c10d/TCPStore.cpp | 24 ++++++++++++++++ torch/csrc/distributed/c10d/TCPStore.hpp | 2 ++ .../csrc/distributed/c10d/TCPStoreBackend.cpp | 10 +++++++ .../csrc/distributed/c10d/TCPStoreBackend.hpp | 1 + .../distributed/c10d/TCPStoreLibUvBackend.cpp | 28 +++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 6 ++++ 15 files changed, 128 insertions(+) diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 5e063d373ffb5..e1412701807b6 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -253,6 +253,14 @@ def test_clone(self): a.set("foo", "bar") self.assertEqual(b.get("foo"), b"bar") + def test_list_keys(self): + a = self._create_store() + a.set("foo", "bar") + a.set("baz", "qux") + keys = a.list_keys() + self.assertIn("foo", keys) + self.assertIn("baz", keys) + # This is the number of keys used in test_set_get. Adding this as a class # property instead of hardcoding in the test since some Store # implementations will have differing number of keys. In the base case, diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index a80efc696e17d..477b35b1811e4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -218,6 +218,7 @@ class Store: def queue_pop(self, key: str, block: bool = True) -> bytes: ... def queue_push(self, key: str, value: Union[bytes, str]) -> None: ... def queue_len(self, key: str) -> int: ... + def list_keys(self) -> list[str]: ... class FileStore(Store): def __init__(self, path: str, numWorkers: int = ...) -> None: ... diff --git a/torch/csrc/distributed/c10d/FileStore.cpp b/torch/csrc/distributed/c10d/FileStore.cpp index 7e22aa6fd0bd5..969379e739438 100644 --- a/torch/csrc/distributed/c10d/FileStore.cpp +++ b/torch/csrc/distributed/c10d/FileStore.cpp @@ -492,4 +492,17 @@ void FileStore::wait( } } +std::vector FileStore::listKeys() { + std::unique_lock l(activeFileOpLock_); + File file(path_, O_RDONLY, timeout_); + auto lock = file.lockShared(); + pos_ = refresh(file, pos_, cache_, deletePrefix_); + std::vector keys; + keys.reserve(cache_.size()); + for (const auto& kv : cache_) { + keys.push_back(kv.first.substr(regularPrefix_.size())); + } + return keys; +} + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/FileStore.hpp b/torch/csrc/distributed/c10d/FileStore.hpp index 563ac76e03bf5..11ded19d8125a 100644 --- a/torch/csrc/distributed/c10d/FileStore.hpp +++ b/torch/csrc/distributed/c10d/FileStore.hpp @@ -45,6 +45,8 @@ class TORCH_API FileStore : public Store { return path_; } + std::vector listKeys() override; + protected: int64_t addHelper(const std::string& key, int64_t i); diff --git a/torch/csrc/distributed/c10d/HashStore.cpp b/torch/csrc/distributed/c10d/HashStore.cpp index 15befd9ec34e2..9073333fb9a48 100644 --- a/torch/csrc/distributed/c10d/HashStore.cpp +++ b/torch/csrc/distributed/c10d/HashStore.cpp @@ -217,4 +217,14 @@ int64_t HashStore::queueLen(const std::string& key) { return static_cast(it->second.size()); } +std::vector HashStore::listKeys() { + std::unique_lock lock(m_); + std::vector keys; + keys.reserve(map_.size()); + for (const auto& kv : map_) { + keys.push_back(kv.first); + } + return keys; +} + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/HashStore.hpp b/torch/csrc/distributed/c10d/HashStore.hpp index 4007d543a9371..f7aca03de8b22 100644 --- a/torch/csrc/distributed/c10d/HashStore.hpp +++ b/torch/csrc/distributed/c10d/HashStore.hpp @@ -59,6 +59,8 @@ class TORCH_API HashStore : public Store { int64_t queueLen(const std::string& key) override; + std::vector listKeys() override; + protected: bool checkLocked( const std::unique_lock& lock, diff --git a/torch/csrc/distributed/c10d/PrefixStore.cpp b/torch/csrc/distributed/c10d/PrefixStore.cpp index 057d198f93c2d..fa228c4467f01 100644 --- a/torch/csrc/distributed/c10d/PrefixStore.cpp +++ b/torch/csrc/distributed/c10d/PrefixStore.cpp @@ -146,4 +146,18 @@ c10::intrusive_ptr PrefixStore::getUnderlyingNonPrefixStore() { return store; } +std::vector PrefixStore::listKeys() { + auto keys = store_->listKeys(); + std::vector filteredKeys; + filteredKeys.reserve(keys.size()); + + for (auto& key : keys) { + if (key.find(prefix_) == 0) { + key = key.substr(prefix_.size() + 1); + filteredKeys.push_back(std::move(key)); + } + } + return filteredKeys; +} + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/PrefixStore.hpp b/torch/csrc/distributed/c10d/PrefixStore.hpp index 627d2153bb22b..f950ff96590a3 100644 --- a/torch/csrc/distributed/c10d/PrefixStore.hpp +++ b/torch/csrc/distributed/c10d/PrefixStore.hpp @@ -64,6 +64,8 @@ class TORCH_API PrefixStore : public Store { // Recursively to fetch the store before layers of wrapping with PrefixStore. c10::intrusive_ptr getUnderlyingNonPrefixStore(); + std::vector listKeys() override; + protected: std::string prefix_; c10::intrusive_ptr store_; diff --git a/torch/csrc/distributed/c10d/Store.hpp b/torch/csrc/distributed/c10d/Store.hpp index 8260d33597d9c..9a037c65ee7c2 100644 --- a/torch/csrc/distributed/c10d/Store.hpp +++ b/torch/csrc/distributed/c10d/Store.hpp @@ -114,6 +114,11 @@ class TORCH_API Store : public torch::CustomClassHolder { C10_THROW_ERROR(NotImplementedError, "queue support is not implemented."); } + virtual std::vector listKeys() { + C10_THROW_ERROR( + NotImplementedError, "listKeys support is not implemented."); + } + protected: std::chrono::milliseconds timeout_; }; diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index b664c5d3bb963..9f566032b5b3c 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -723,6 +723,30 @@ int64_t TCPStore::queueLen(const std::string& key) { return client_->receiveValue(); } +std::vector TCPStore::listKeys() { + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__list); + + const std::lock_guard lock(activeOpLock_); + + detail::SendBuffer buffer(*client_, detail::QueryType::LIST_KEYS); + buffer.flush(); + + auto numKeys = client_->receiveValue(); + std::vector keys; + keys.reserve(numKeys); + for (auto i = 0; i < numKeys; ++i) { + auto bits = client_->receiveBits(); + std::string str(bits.begin(), bits.end()); + if (str.find(keyPrefix_) == 0) { + str = str.substr(keyPrefix_.size()); + } else { + continue; + } + keys.emplace_back(str); + } + return keys; +} + bool TCPStore::hasExtendedApi() const { return true; } diff --git a/torch/csrc/distributed/c10d/TCPStore.hpp b/torch/csrc/distributed/c10d/TCPStore.hpp index 2caab088a609a..09d7ae111c57a 100644 --- a/torch/csrc/distributed/c10d/TCPStore.hpp +++ b/torch/csrc/distributed/c10d/TCPStore.hpp @@ -121,6 +121,8 @@ class TORCH_API TCPStore : public Store { int64_t queueLen(const std::string& key) override; + std::vector listKeys() override; + // Waits for all workers to join. void waitForWorkers(); diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp index 22455a22a4610..dd25729a6ee13 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp @@ -78,6 +78,7 @@ class TCPStoreMasterDaemon : public BackgroundThread { void multiGetHandler(int socket); void multiSetHandler(int socket); void cancelWaitHandler(int socket); + void listKeysHandler(int socket); void addMiscellaneousSocket(int socket); void removeMiscellaneousSocket(int socket); bool isMiscellaneousSocket(int socket); @@ -295,6 +296,8 @@ void TCPStoreMasterDaemon::query(int socket) { multiSetHandler(socket); } else if (qt == QueryType::CANCEL_WAIT) { cancelWaitHandler(socket); + } else if (qt == QueryType::LIST_KEYS) { + listKeysHandler(socket); } else { TORCH_CHECK(false, "Unexpected query type"); } @@ -482,6 +485,13 @@ void TCPStoreMasterDaemon::cancelWaitHandler(int socket) { socket, detail::WaitResponseType::WAIT_CANCELED); } +void TCPStoreMasterDaemon::listKeysHandler(int socket) { + tcputil::sendValue(socket, tcpStore_.size()); + for (const auto& kv : tcpStore_) { + tcputil::sendString(socket, kv.first); + } +} + bool TCPStoreMasterDaemon::checkKeys( const std::vector& keys) const { return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) { diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.hpp b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp index d5f7f0248bba5..d176ccb702838 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.hpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.hpp @@ -36,6 +36,7 @@ enum class QueryType : uint8_t { QUEUE_PUSH, QUEUE_POP, QUEUE_LEN, + LIST_KEYS, }; enum class CheckResponseType : uint8_t { READY, NOT_READY }; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index edb640785a170..7427848b8445b 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -683,6 +683,7 @@ class LibUVStoreDaemon : public BackgroundThread { const std::string& queueName, const c10::intrusive_ptr& client); int64_t queueLen(const std::string& queueName); + std::vector listKeys(); void registerClient(const c10::intrusive_ptr& client); void unregisterClient(const c10::intrusive_ptr& client); @@ -822,6 +823,10 @@ class UvClient : public UvTcpSocket { if (!parse_queue_len_command()) return; break; + case QueryType::LIST_KEYS: + if (!parse_list_keys_command()) + return; + break; default: C10D_DEBUG( "Client sent invalid command. client:{} command:{}", @@ -1164,6 +1169,19 @@ class UvClient : public UvTcpSocket { return true; } + bool parse_list_keys_command() { + C10D_TRACE("list_keys address:{}", this->address()); + + auto keys = store->listKeys(); + StreamWriter sw(iptr()); + sw.write_value(static_cast(keys.size())); + for (const auto& key : keys) { + sw.write_string(key); + } + sw.send(); + return true; + } + public: explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store) : UvTcpSocket(loop), store(store) {} @@ -1542,6 +1560,16 @@ int64_t LibUVStoreDaemon::queueLen(const std::string& key) { } return static_cast(it->second.size()); } + +std::vector LibUVStoreDaemon::listKeys() { + std::vector keys; + keys.reserve(tcpStore_.size()); + for (const auto& kv : tcpStore_) { + keys.push_back(kv.first); + } + return keys; +} + #endif std::unique_ptr create_libuv_tcpstore_backend( diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6f38cd9cd2c6f..255e793eaa4df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1657,6 +1657,12 @@ See queue_push for more details. Arguments: key (str): The key of the queue to get the length. +)") + .def( + "list_keys", + &::c10d::Store::listKeys, + R"( +Returns a list of all keys in the store. )") .def( "has_extended_api", From 6038c592be04c2a7b61623819c0efe3c76e164c5 Mon Sep 17 00:00:00 2001 From: ruisizhang123 Date: Fri, 21 Nov 2025 05:36:26 +0000 Subject: [PATCH 151/230] [simplefsdp] fix simplefsdp llama3 run (#168311) Fixes for https://github.com/pytorch/torchtitan/issues/2066 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168311 Approved by: https://github.com/eellison --- torch/_inductor/fx_passes/bucketing.py | 3 ++- .../_inductor/fx_passes/overlap_preserving_bucketer.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 00737a3b6e3b7..aba2c5182264a 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -66,7 +66,8 @@ def _schedulable_wait_node(node: torch.fx.Node) -> bool: if not is_wait_tensor(node): return False assert isinstance(node.args[0], torch.fx.Node) - assert isinstance(node.args[0].target.name(), str) + if not isinstance(node.args[0].target, Callable): + return False is_callable: bool = node.args[0].op == "call_function" coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name()) is_collective: bool = coll != NCCL_COLL.UNSUPPORTED diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index ed37c0902c325..eb239a3a219a6 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -9,12 +9,12 @@ from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( + _schedulable_wait_node, bucket_key, BucketMode, has_mergeable_all_gather_convert_dtype, is_all_gather_into_tensor as is_all_gather, is_reduce_scatter_tensor as is_reduce_scatter, - is_wait_tensor, ) from torch._inductor.fx_passes.overlap_scheduling import ( CollBucket, @@ -53,12 +53,12 @@ def __call__(self, reason: str, *args: Any) -> None: def is_collective_or_wait(n: fx.Node) -> bool: """Check if node is a collective start or wait.""" - if is_wait_tensor(n): + if _schedulable_wait_node(n): return True # Collective starts have exactly one use: the wait_tensor if len(n.users) == 1: user = next(iter(n.users.keys())) - if is_wait_tensor(user): + if _schedulable_wait_node(user): return True return False @@ -214,7 +214,7 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: if node in self.collective_info and get_group_name(node) == pg: node_type = "starts" hiding_nodes |= self.collective_info[node].hiding_nodes - elif is_wait_tensor(node): + elif _schedulable_wait_node(node): wait_input = node.args[0] if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: node_type = "waits" @@ -844,7 +844,7 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: ) # Get new nodes - new_waits = [n for n in new_nodes if is_wait_tensor(n)] + new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] assert len(new_waits) == 1 new_wait = new_waits[0] From 28656727473a6361ee0ebfa5f412746807343651 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 21 Nov 2025 06:14:41 +0000 Subject: [PATCH 152/230] Revert "Revert #154859 (#168297)" This reverts commit 6707dc8e444de405401f2e36e2868a19bb7ba43a. Reverted https://github.com/pytorch/pytorch/pull/168297 on behalf of https://github.com/yangw-dev due to this seems breaks the trunk ##[error]Process completed with exit code 2. ([comment](https://github.com/pytorch/pytorch/pull/168297#issuecomment-3561558790)) --- test/profiler/test_execution_trace.py | 7 --- .../standalone/execution_trace_observer.cpp | 59 ++++++++++++++++++- 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 26c0ab42905de..dbd5d89ad6a61 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -2,7 +2,6 @@ import json import os -import sys import tempfile import unittest from typing import Any @@ -365,9 +364,6 @@ def test_execution_trace_env_disabled(self, device): self.assertTrue(p.execution_trace_observer is None) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", @@ -423,9 +419,6 @@ def fn(a, b, c): assert found_call_compiled_fx_graph @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") - @unittest.skipIf( - sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" - ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index b46e1d19bcd0e..29b2b94af4472 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -112,8 +112,59 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT std::map> opStack; // Uses the underlying TensorImpl object pointer as the key and map to its // unique id. + std::map objectId; + + using weak_storage_ptr = c10::weak_intrusive_ptr; + std::unordered_map data_ptr_to_storage_id; + std::unordered_map + data_ptr_to_weak_storage_ptr; + + ID get_tensor_storage_ID(const c10::Storage& t_storage) { + const std::lock_guard lock(gMutex); + + const void* raw_data_ptr = nullptr; + bool should_track_liveness = false; + // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use + // a nullptr data pointer. Treat both cases as a shared cache key but avoid + // touching the weak-ref table so they can reuse the same ID without + // tripping the liveness check. + if (t_storage.unsafeGetStorageImpl()) { + raw_data_ptr = t_storage.data(); + should_track_liveness = raw_data_ptr != nullptr; + } + + auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr); + if (!should_track_liveness) { + if (id_iter != data_ptr_to_storage_id.end()) { + return id_iter->second; + } + ID id = storage_id_++; + data_ptr_to_storage_id.emplace(raw_data_ptr, id); + return id; + } + + auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); + if (weak_iter == data_ptr_to_weak_storage_ptr.end()) { + ID id = storage_id_++; + data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); + data_ptr_to_weak_storage_ptr.emplace( + raw_data_ptr, t_storage.getWeakStorageImpl()); + return id; + } + + if (weak_iter->second.expired()) { + ID id = storage_id_++; + data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); + data_ptr_to_weak_storage_ptr.insert_or_assign( + raw_data_ptr, t_storage.getWeakStorageImpl()); + return id; + } + + id_iter = data_ptr_to_storage_id.find(raw_data_ptr); + TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end()); + return id_iter->second; + } - std::map objectId{}; // Observer run state. enum class RunState { uninitialized, disabled, enabled }; @@ -176,6 +227,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // 1 -> root ID // 2 ... -> regular node ID std::atomic id_{2}; + + std::atomic storage_id_{1}; }; // Using a singleton manager here to allow init and delete the observer object. @@ -446,8 +499,8 @@ convertIValue( // symbolic sizes/strides implies t->storage_offset() will fail if (tensor_impl->has_storage() && !tensor_impl->has_symbolic_sizes_strides()) { - auto& t_storage = tensor_impl->storage(); - storage_id = getObjectID(ob, t_storage.data()); + const c10::Storage& t_storage = tensor_impl->storage(); + storage_id = ob.get_tensor_storage_ID(t_storage); offset = tensor_impl->storage_offset(); numel = tensor_impl->numel(); itemsize = tensor_impl->itemsize(); From 4ee6b3d60c85d847212901248fc7d99ee81a5899 Mon Sep 17 00:00:00 2001 From: kundaMwiza Date: Fri, 21 Nov 2025 07:08:02 +0000 Subject: [PATCH 153/230] [inductor] Use custom triton kernel subclass when available (#167456) This refactor replaces direct uses of TritonKernel in cases where a subclass type is available since out of tree / custom backends can: - have their own configs that they would like to place in `inductor_meta` via a `TritonKernel` subclass for the autotuner to handle - have their own triton heuristics for the different types of operations (pointwise, reduction e.t.c). These heuristics can currently only be reached by patching. This change allows custom backends to inject their own imports directly via a subclass Example out of tree backends with their own heuristic modules: Ascend NPU: https://github.com/Ascend/pytorch/blob/045a034dbcec287a5997aa13fd129a1cd6b1e215/torch_npu/_inductor/npu_triton_heuristics.py#L4 Intel XPU: https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/xpu/triton_ops/autotune.py It also adds a `triton_meta_common` method that is analogous to `inductor_meta_common` that is overridable, so that compile options can be directly provided. Test plan: Added unit tests to test_triton_extension_backend.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/167456 Approved by: https://github.com/jansel --- .../triton/extension_triton_heuristics.py | 41 +++++ .../inductor/test_triton_extension_backend.py | 170 +++++++++++++++--- torch/_inductor/codegen/simd.py | 6 + torch/_inductor/codegen/triton.py | 62 ++++--- .../_inductor/codegen/triton_combo_kernel.py | 16 +- torch/_inductor/codegen/wrapper.py | 18 +- torch/_inductor/runtime/triton_heuristics.py | 6 +- torch/_inductor/select_algorithm.py | 3 +- 8 files changed, 260 insertions(+), 62 deletions(-) create mode 100644 test/inductor/extension_backends/triton/extension_triton_heuristics.py diff --git a/test/inductor/extension_backends/triton/extension_triton_heuristics.py b/test/inductor/extension_backends/triton/extension_triton_heuristics.py new file mode 100644 index 0000000000000..bfe558ae1708a --- /dev/null +++ b/test/inductor/extension_backends/triton/extension_triton_heuristics.py @@ -0,0 +1,41 @@ +from typing import Any + +from torch._inductor.runtime import triton_heuristics +from torch._inductor.runtime.triton_heuristics import user_autotune # noqa: F401 + + +EXTENSION_TRITON_META_FIELD = "extension_custom_field" + + +class ExtensionCachingAutotuner(triton_heuristics.CachingAutotuner): + def _create_compile_meta( + self, + cfg: triton_heuristics.Config, + ) -> dict[str, Any]: + assert EXTENSION_TRITON_META_FIELD in self.triton_meta + compile_meta = super()._create_compile_meta(cfg) + assert EXTENSION_TRITON_META_FIELD in compile_meta + return compile_meta + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. + """ + configs = [triton_heuristics.Config({"XBLOCK": 32})] + return triton_heuristics.cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=triton_heuristics.HeuristicType.POINTWISE, + filename=filename, + caching_autotuner_cls=ExtensionCachingAutotuner, + ) diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index 37b32404508bb..ae9afeec0637e 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -1,12 +1,15 @@ # Owner(s): ["module: inductor"] +import functools import random import string -import sys import unittest +from pathlib import Path +from typing import Any, Optional import torch import torch._dynamo import torch.utils.cpp_extension +from torch._inductor import config try: @@ -18,6 +21,9 @@ ExtensionScheduling, ExtensionWrapperCodegen, ) + from extension_backends.triton.extension_triton_heuristics import ( + EXTENSION_TRITON_META_FIELD, + ) except ImportError: from .extension_backends.triton.device_interface import DeviceInterface from .extension_backends.triton.extension_codegen_backend import ( @@ -25,18 +31,27 @@ ExtensionScheduling, ExtensionWrapperCodegen, ) + from .extension_backends.triton.extension_triton_heuristics import ( + EXTENSION_TRITON_META_FIELD, + ) +import torch._inductor.lowering as inductor_lowering from torch._C import FileCheck from torch._dynamo import device_interface -from torch._inductor import metrics +from torch._inductor import codegen, ir, metrics +from torch._inductor.codegen import common from torch._inductor.codegen.common import ( get_scheduling_for_device, get_wrapper_codegen_for_device, + IndentedBuffer, register_backend_for_device, register_device_op_overrides, ) -from torch._inductor.utils import get_triton_code +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.utils import get_triton_code, run_and_get_triton_code from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU_AND_TRITON +from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -44,18 +59,9 @@ except ImportError: from test_extension_backend import BaseExtensionBackendTests -try: - try: - from . import test_torchinductor - except ImportError: - import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library -except unittest.SkipTest: - if __name__ == "__main__": - sys.exit(0) - raise - - -TestCase = test_torchinductor.TestCase +if HAS_GPU_AND_TRITON: + import triton + import triton.language as tl def mock_triton_hash_with_backend(*args, **kwargs): @@ -65,14 +71,33 @@ def mock_triton_hash_with_backend(*args, **kwargs): @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") -@test_torchinductor.skip_if_cpp_wrapper( - "Not possible to fix until CppWrapperCpu supports triton for CPU" -) class TritonExtensionBackendTests(BaseExtensionBackendTests): """ Test creating a backend for inductor with Triton scheduling. """ + @classmethod + def setUpClass(cls): + super().setUpClass() + if config.cpp_wrapper: + raise unittest.SkipTest( + "Not possible to fix until CppWrapperCpu supports triton for CPU" + ) + + # Store the default backends and reset later + common.init_backend_registration() + + default_backend_patch = unittest.mock.patch.dict(inductor_lowering.lowerings) + default_backend_patch.start() + cls._default_backend_patch = default_backend_patch + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Restore the default backend. + cls._default_backend_patch.stop() + def test_open_device_registration(self): torch._register_device_module("privateuseone", self.module) register_backend_for_device( @@ -112,10 +137,115 @@ def foo(x): "tl_math.sin" ).check("device_str='privateuseone'").run(code) + def _register_custom_backend_with_heuristics(self, device): + class ExtensionTritonKernel(codegen.triton.TritonKernel): + @classmethod + @functools.lru_cache(None) + def gen_common_triton_imports(cls) -> str: + default_imports = super().gen_common_triton_imports() + custom_imports = IndentedBuffer() + custom_imports.splice(default_imports) + path_to_ext_heuristics = ( + Path(__file__).parent / "extension_backends" / "triton" + ) + + custom_imports.splice(f""" + import sys + sys.path.append("{path_to_ext_heuristics}") + import extension_triton_heuristics as triton_heuristics + """) + return custom_imports + + @classmethod + def triton_meta_common(cls) -> dict[str, Any]: + triton_meta = super().triton_meta_common() + triton_meta[EXTENSION_TRITON_META_FIELD] = True + return triton_meta + + class ExtensionTritonScheduling(codegen.triton.TritonScheduling): + kernel_type = ExtensionTritonKernel + + class ExtensionPythonWrapperCodegen(PythonWrapperCodegen): + @classmethod + def _get_triton_info_kernel_cls(cls) -> type[codegen.triton.TritonKernel]: + return ExtensionTritonKernel + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None + return PythonWrapperCodegen.create( + subgraph_name, parent_wrapper, partition_signatures + ) + return ExtensionPythonWrapperCodegen() + + register_backend_for_device( + device, ExtensionTritonScheduling, ExtensionPythonWrapperCodegen + ) + + @requires_cuda_and_triton + def test_codegen_with_custom_heuristics_module(self): + self._register_custom_backend_with_heuristics(GPU_TYPE) + + def add(x, y): + return x + y + + x = torch.zeros((32,), device=GPU_TYPE) + y = x + compiled_add = torch.compile(add) + + code = run_and_get_triton_code(compiled_add, x, y) + FileCheck().check("import extension_triton_heuristics").check( + f"{EXTENSION_TRITON_META_FIELD}" + ).check("@triton.jit").run(code) + + @requires_cuda_and_triton + def test_codegen_with_custom_heuristics_module_udtk(self): + self._register_custom_backend_with_heuristics(GPU_TYPE) + + @triton.jit + def add_kernel( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x + y + tl.store(out_ptr + offsets, output, mask=mask) + + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = output.numel() + + def grid(meta): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) + return output + + args = [torch.randn(32, device=GPU_TYPE) for _ in range(2)] + code = run_and_get_triton_code(torch.compile(add), *args) + + FileCheck().check("import extension_triton_heuristics").check( + "@triton.jit" + ).run(code) + if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CPU - if HAS_CPU and not IS_MACOS: + if (HAS_CPU or HAS_GPU_AND_TRITON) and not IS_MACOS: run_tests() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 1706f53cb2927..6bfe27cdc6f99 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2274,8 +2274,12 @@ def generate_combo_kernel_code( mixed_sizes: bool, only_gen_src_code: bool = False, ) -> list[tuple[str, Any, Any]]: + from .triton import TritonKernel from .triton_combo_kernel import ComboKernel + # This is currently the only type supported by this method + assert issubclass(self.kernel_type, TritonKernel) + fused_node_lists = [node.get_nodes() for node in subkernel_nodes] subkernel_map, node_schedule_map = {}, {} for pn, nodes in zip(subkernel_nodes, fused_node_lists): @@ -2287,6 +2291,7 @@ def generate_combo_kernel_code( tiling, features=SIMDKernelFeatures(node_schedule, numel, rnumel), optimize_mask=not mixed_sizes, + triton_kernel_cls=self.kernel_type, ) partitions = ComboKernel.horizontal_partition( @@ -2306,6 +2311,7 @@ def generate_combo_kernel_code( if len(node_group) == 0: continue kernel = ComboKernel( + triton_kernel_cls=self.kernel_type, enable_autotune=enable_autotune, mixed_sizes=mixed_sizes, ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 9b718f0c780c1..f590fe57de609 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -178,28 +178,6 @@ class defined. return "" -@lru_cache(None) -def gen_common_triton_imports() -> str: - imports = IndentedBuffer() - imports.splice( - """ - import triton - import triton.language as tl - """ - ) - if attr_desc := gen_attr_descriptor_import(): - imports.writeline(attr_desc) - - imports.splice( - """ - from torch._inductor.runtime import triton_helpers, triton_heuristics - from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math - from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties - """ - ) - return imports.getvalue() - - class TritonSymbols: """ Stores sympy.Symbol instances and constants associated with triton codegen. @@ -4871,8 +4849,37 @@ def _get_heuristic(self): return "reduction" return "pointwise" - @staticmethod - def inductor_meta_common(): + @classmethod + @lru_cache(None) + def gen_common_triton_imports(cls) -> str: + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + """ + ) + return imports.getvalue() + + @classmethod + def triton_meta_common(cls): + triton_meta = {"enable_fp_fusion": not config.emulate_precision_casts} + if enable_pdl_codegen(): + triton_meta["launch_pdl"] = True + return triton_meta + + @classmethod + def inductor_meta_common(cls): inductor_meta = { "backend_hash": torch.utils._triton.triton_hash_with_backend(), "assert_indirect_indexing": config.assert_indirect_indexing, @@ -4950,7 +4957,7 @@ def codegen_kernel(self, name=None) -> str: size_hints[prefix] = size_hint if name is None: - code.splice(gen_common_triton_imports()) + code.splice(self.gen_common_triton_imports()) device_type = V.graph.get_current_device_or_throw().type if device_type == "cpu": code.splice("triton_helpers.set_driver_to_cpu()") @@ -5053,6 +5060,7 @@ def add_constexpr_arg(arg_name): torch._inductor.config.triton.native_matmul and ("tl.dot" in str(self.body) or "tl.dot" in str(self.compute)) ), + **self.triton_meta_common(), } # Skip memory optimization for forward of the training loop where we expect @@ -5159,9 +5167,6 @@ def add_constexpr_arg(arg_name): triton_meta["configs"] = [config_of(signature)] - if enable_pdl_codegen(): - triton_meta["launch_pdl"] = True - # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. @@ -5169,7 +5174,6 @@ def add_constexpr_arg(arg_name): # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 for arg_num in equal_1_arg_indices(signature): # type: ignore[index] triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] - triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts self.triton_meta = triton_meta diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 41b12d05cd32e..010de72f1606a 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -35,7 +35,7 @@ ) from .simd import prefix_is_reduction, SIMDScheduling from .simd_kernel_features import SIMDKernelFeatures -from .triton import gen_common_triton_imports, TritonKernel +from .triton import TritonKernel from .triton_utils import config_of, signature_to_meta @@ -355,9 +355,13 @@ def codegen_pid_range( code.splice(f"pid_offset = pid // {num_kernels}") def __init__( - self, enable_autotune: bool = False, mixed_sizes: bool = False + self, + triton_kernel_cls: type[TritonKernel], + enable_autotune: bool = False, + mixed_sizes: bool = False, ) -> None: super().__init__() + self.triton_kernel_cls = triton_kernel_cls self.sub_kernels: list[TritonKernel] = [] self.iter_vars_count = itertools.count() self.grids: list[list[int]] = [] @@ -391,12 +395,13 @@ def create_triton_kernel( tiling: dict[str, sympy.Expr], features: SIMDKernelFeatures, optimize_mask: bool, + triton_kernel_cls: type[TritonKernel], ) -> TritonKernel: """ Only allow optimize_mask=True when 1) sequential dispatch is used, 2) numels except x dimension are the same for each sub kernel. """ - return TritonKernel( + return triton_kernel_cls( tiling, features=features, pid_cache={"tl.program_id(0)": "pid_offset"}, @@ -615,12 +620,13 @@ def jit_line( mutated_args = self.get_mutated_args_sub_kernels() dispatch = self.dispatch_class assert dispatch is not None + inductor_meta = { "grid_type": dispatch.grid_expr.__name__, "combo_grid_meta": self.combo_grid_meta(), "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, - **TritonKernel.inductor_meta_common(), + **self.triton_kernel_cls.inductor_meta_common(), } sub_kernel = selected_kernel @@ -768,7 +774,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: ) code = IndentedBuffer() - code.splice(gen_common_triton_imports()) + code.splice(self.triton_kernel_cls.gen_common_triton_imports()) if config.benchmark_combo_kernel: code.splice(self.imports_for_benchmark_kernel()) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index c6c56c86b2c24..0498fca739bfc 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1025,7 +1025,7 @@ class PythonWrapperCodegen(CodeGen): Generate outer wrapper in Python that calls the kernels. """ - supports_caching = True # Whether the output code is cacheable. + supports_caching: bool = True # Whether the output code is cacheable. def __init__(self): super().__init__() @@ -2280,6 +2280,16 @@ def _define_kernel_helper( def define_subgraph_launcher_fn(self, name: str, subgraph_code): self.subgraph_definitions.splice(subgraph_code.value) + @classmethod + def _get_triton_info_kernel_cls(cls): + # Other inductor triton backends may subclass from + # the `TritonKernel` class. An override of this method + # allows them to set which subclass to use to get information + # such as common triton imports or inductor metadata + from .triton import TritonKernel + + return TritonKernel + def define_user_defined_triton_kernel( self, kernel, @@ -2301,7 +2311,6 @@ def define_user_defined_triton_kernel( TensorArg, TMADescriptorArg, ) - from .triton import gen_common_triton_imports, TritonKernel original_name = kernel.__name__ signature: list[KernelArgType] = [] @@ -2514,9 +2523,10 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") inductor_meta["kernel_name"] = name - inductor_meta.update(TritonKernel.inductor_meta_common()) + triton_info_kernel_cls = self._get_triton_info_kernel_cls() + inductor_meta.update(triton_info_kernel_cls.inductor_meta_common()) - compile_wrapper.splice(gen_common_triton_imports()) + compile_wrapper.splice(triton_info_kernel_cls.gen_common_triton_imports()) compile_wrapper.splice( f""" @triton_heuristics.user_autotune( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d59d9bbf3fee4..8a84d2432ceca 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2132,6 +2132,8 @@ def cached_autotune( filename=None, inductor_meta=None, custom_kernel=False, + caching_autotuner_cls: type[CachingAutotuner] = CachingAutotuner, + debug_autotuner_cls: type[DebugAutotuner] = DebugAutotuner, ): """ A copy of triton.autotune that calls our subclass. Our subclass @@ -2168,7 +2170,7 @@ def decorator(fn): tconfig.kwargs.pop("XBLOCK") if inductor_meta.get("profile_bandwidth"): - return DebugAutotuner( + return debug_autotuner_cls( fn, triton_meta=triton_meta, inductor_meta=inductor_meta, @@ -2187,7 +2189,7 @@ def decorator(fn): filename=filename, with_bandwidth_info=True, ) - return CachingAutotuner( + return caching_autotuner_cls( fn, triton_meta=triton_meta, inductor_meta=inductor_meta, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 493ca1179fad8..28cdfbf0cc7ea 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -61,7 +61,6 @@ from .codegen.simd_kernel_features import SIMDKernelFeatures from .codegen.subgraph import SubgraphChoiceCaller from .codegen.triton import ( - gen_common_triton_imports, texpr, TMACompatibilityChecker, TritonKernel, @@ -742,7 +741,7 @@ def hook(): # python_argdefs() cannot be run until after the rest of the template lazily adds more args arg_defs, *_ = self.args.python_argdefs() code = IndentedBuffer() - code.splice(gen_common_triton_imports()) + code.splice(self.gen_common_triton_imports()) code.splice(self.jit_lines()) code.writeline( f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):" From 8b0314d1a7c4ea438d2e46db81ec2027b4287adf Mon Sep 17 00:00:00 2001 From: Frank Lin Date: Fri, 21 Nov 2025 07:16:02 +0000 Subject: [PATCH 154/230] Fix edge-data handling in cudaGraphNodeGetDependencies for CUDA 13 in graph_capture_record_stream_reuse (#168305) CUDA 13 introduced stricter behavior for querying graph edges with edge data. According to the CUDA documentation for [cudaGraphNodeGetDependencies](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g94ee7ba53ade560483e9c5d06e8ef50d) > If an edge has non-zero (non-default) edge data and edgeData is NULL, this API returns cudaErrorLossyQuery. If edgeData is non-NULL, then pDependencies must also be non-NULL. When a graph contains edge data, we must provide a non-NULL edgeData buffer during dependency queries. Otherwise CUDA 13 will raise a cudaErrorLossyQuery. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168305 Approved by: https://github.com/eqy, https://github.com/ezyang --- c10/cuda/CUDACachingAllocator.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 9e7823a394302..1d70edde5a4ca 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1765,7 +1765,12 @@ class DeviceCachingAllocator { auto node_get_dependencies = [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void { #if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count)); + if (deps == nullptr) { + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count)); + } else { + cudaGraphEdgeData edgeData; + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, &edgeData, count)); + } #else C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count)); #endif From 265a8bc90b7ce30e128d4920b4f4882814835fde Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Fri, 21 Nov 2025 07:52:39 +0000 Subject: [PATCH 155/230] adding kwarg inputs handling in register sharding (#168249) Fixes #167977 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168249 Approved by: https://github.com/ezyang --- .../experimental/test_register_sharding.py | 28 +++++++++++++++++++ torch/distributed/tensor/_op_schema.py | 10 +++++++ torch/distributed/tensor/_ops/utils.py | 18 ++++++++---- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/test/distributed/tensor/experimental/test_register_sharding.py b/test/distributed/tensor/experimental/test_register_sharding.py index 1cfd7af243b58..0f2f208689608 100644 --- a/test/distributed/tensor/experimental/test_register_sharding.py +++ b/test/distributed/tensor/experimental/test_register_sharding.py @@ -116,6 +116,34 @@ def custom_argmax_sharding(x, dim, keepdim): self.assertTrue(dist_y.placements[0].is_shard(dim=0)) self.assertEqual(dist_y.full_tensor(), local_y) + @with_comms + def test_register_sharding_for_tensor_kwargs(self): + mesh = self.build_device_mesh() + x = torch.randn(4, 4, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Replicate()]) + + @register_sharding(aten.min.dim_min) + def min_dim_strategy(x, dim, keepdim, min, min_indices): + all_replicate = ( + [Replicate(), Replicate()], + [Replicate(), None, None, Replicate(), Replicate()], + ) + return [all_replicate] + + value = torch.randn(4, 1, device=self.device_type) + indices = torch.randn(4, 1, device=self.device_type).long() + value_dt = distribute_tensor(value, mesh, [Replicate()]) + indices_dt = distribute_tensor(indices, mesh, [Replicate()]) + + result = torch.min(x_dt, dim=1, keepdim=True, out=(value_dt, indices_dt)) + + self.assertIsInstance(result[0], DTensor) + self.assertIsInstance(result[1], DTensor) + + expected_values, expected_indices = torch.min(x, dim=1, keepdim=True) + self.assertEqual(result[0].full_tensor(), expected_values) + self.assertEqual(result[1].full_tensor(), expected_indices) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 95e9509cdbcd6..283eaf4a06db8 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -361,6 +361,16 @@ def args_strategy(self) -> tuple[OpStrategy, ...]: ) return tuple(item for item in args if isinstance(item, OpStrategy)) + @property + def kwargs_strategy(self) -> tuple[OpStrategy, ...]: + # returns OpStrategy items from kwargs_schema. + kwargs_vals = ( + tree_leaves(self.kwargs_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.kwargs_schema.values() + ) + return tuple(item for item in kwargs_vals if isinstance(item, OpStrategy)) + def __repr__(self) -> str: args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) return ( diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index a19ce091e3748..a9f42b53fca6e 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -104,9 +104,10 @@ def replicate_op_strategy(op_schema: OpSchema) -> StrategyType: """ Fallback strategy all use Replication() """ - inputs_strategy = op_schema.args_strategy - # TODO(zpcore): handle kwarg_inputs_strategy - # kwarg_inputs_strategy = op_schema.kwargs_schema + args_strategy = op_schema.args_strategy + kwargs_strategy = op_schema.kwargs_strategy + inputs_strategy = args_strategy + kwargs_strategy + output_type = [str(ret.type) for ret in op_schema.op._schema.returns] output_len = output_type.count("Tensor") # TODO(zpcore): Confirm if view op can be handle properly or not. Prevent @@ -355,8 +356,15 @@ def expand_to_full_mesh_op_strategy( s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) ] - input_args_strategy = op_schema.args_strategy - assert len(input_specs) == len(input_args_strategy) + args_strategy = op_schema.args_strategy + kwargs_strategy = op_schema.kwargs_strategy + input_args_strategy = args_strategy + kwargs_strategy + + if len(input_specs) != len(input_args_strategy): + raise AssertionError( + f"input_specs({len(input_specs)}) != strategies({len(input_args_strategy)}: " + f"{len(args_strategy)} args + {len(kwargs_strategy)} kwargs)" + ) self_spec = input_args_strategy[0].strategies[0].output_spec if inplace_op and self_spec.placements != input_specs[0].placements: From 3b19eca8f3639d7f137c986c76405d20b13f9c38 Mon Sep 17 00:00:00 2001 From: Grayson Derossi Date: Fri, 21 Nov 2025 08:15:40 +0000 Subject: [PATCH 156/230] Fix cublasLtMatmul failure (#167873) This PR fixes a cublasLtMatmul failure seen in `python test/test_ops.py TestCommonCUDA.test_out_warning_torch__scaled_mm_cuda` caused by the use of `result_ptr` for matrix C. Using `nullptr` fixes this issue when `beta` is on host, which currently seems to be the only possible case, as the optional `alpha` parameter is not exposed to the user. Using `result_ptr` for matrix C should be valid if Cdesc == DDesc or Cdesc is `nullptr`, but both of these cases cause `cublasLtMatmulAlgoGetHeuristic` to fail with `CUBLAS_STATUS_NOT_SUPPORTED`. The use of `nullptr` in `cublasLtMatmul` is a workaround until this failure can be resolved. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167873 Approved by: https://github.com/slayton58, https://github.com/eqy --- aten/src/ATen/cuda/CUDABlas.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 9a55b058001da..bc7607f232011 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1997,6 +1997,10 @@ void scaled_gemm( // Note: alpha_val may change later depending on user-passed argument float alpha_val = 1.0; float beta_val = 0.0; +#ifndef USE_ROCM + // Note: unused, but cublasLtMatmul requires a C pointer that is not result_ptr or nullptr + const void* dummy_C_ptr = mat1_ptr; +#endif // ifndef USE_ROCM CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); @@ -2180,8 +2184,11 @@ void scaled_gemm( mat2_ptr, Bdesc.descriptor(), beta_ptr, - // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either +#ifdef USE_ROCM result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr +#else + dummy_C_ptr, // also unused, but cuBLAS can't use nullptr or result_ptr +#endif // ifdef USE_ROCM Cdesc.descriptor(), result_ptr, Ddesc.descriptor(), From 29eca3015ecaf86ca4ad5d1f21ad958dd2a55d0f Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 21 Nov 2025 09:08:45 +0000 Subject: [PATCH 157/230] Remove useless super() delegation (#168235) This PR removes useless super() delegations detected by pylint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168235 Approved by: https://github.com/albanD, https://github.com/zou3519 --- torch/_dynamo/exc.py | 15 +-- torch/_dynamo/variables/base.py | 3 - torch/_dynamo/variables/dicts.py | 14 --- torch/_higher_order_ops/_invoke_quant.py | 3 - .../learnedheuristic_interface.py | 6 - torch/_inductor/codegen/cpp.py | 3 - torch/_inductor/codegen/cuda/gemm_template.py | 13 -- .../ao/nn/intrinsic/qat/modules/conv_fused.py | 117 ------------------ .../data_sparsifier/benchmarks/dlrm_utils.py | 3 - .../quantizer/embedding_quantizer.py | 3 - .../quantizer/xpu_inductor_quantizer.py | 3 - torch/backends/__init__.py | 3 - torch/backends/cudnn/__init__.py | 3 - torch/backends/miopen/__init__.py | 3 - torch/backends/mkldnn/__init__.py | 3 - torch/backends/opt_einsum/__init__.py | 3 - .../_checkpoint/checkpoint_wrapper.py | 3 - torch/jit/_monkeytype_config.py | 3 - torch/jit/_recursive.py | 3 +- torch/testing/_internal/common_utils.py | 2 - 20 files changed, 5 insertions(+), 204 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index f11c78bdaa49e..5b0e8a402dd96 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -198,24 +198,20 @@ class RecompileError(TorchDynamoException): class ArgsMismatchError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class AttributeMutationError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class InfiniteGeneratorError(Unsupported): # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class SideEffectsError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class CondOpArgsMismatchError(ArgsMismatchError): @@ -223,9 +219,6 @@ class CondOpArgsMismatchError(ArgsMismatchError): Internal error from cond() due to arguments mismatch. """ - def __init__(self, msg: str) -> None: - super().__init__(msg) - class UserErrorType(Enum): DYNAMIC_CONTROL_FLOW = auto() diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 4e248320e60b6..2d11a27bafac0 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -151,9 +151,6 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. """ - def __init__(self, typ: SourceType) -> None: - super().__init__(typ) - class AttributeMutationExisting(AttributeMutation): """ diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 24cd5007da37d..636875d85e54a 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1296,13 +1296,6 @@ def install_dict_contains_guard( class FrozensetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "frozenset()" @@ -1360,13 +1353,6 @@ def call_method( class DictKeySetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "dict_keys([])" diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py index 1fc1e1114a036..b7a9fb94b93e2 100644 --- a/torch/_higher_order_ops/_invoke_quant.py +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -26,9 +26,6 @@ class InvokeQuantUnpacked(BaseHOP): def __init__(self) -> None: super().__init__("invoke_quant") - def __call__(self, subgraph, *operands, scheme=None): - return super().__call__(subgraph, *operands, scheme=scheme) - invoke_quant = InvokeQuantUnpacked() diff --git a/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/torch/_inductor/autoheuristic/learnedheuristic_interface.py index cb2568d8a6801..84a941b076c31 100644 --- a/torch/_inductor/autoheuristic/learnedheuristic_interface.py +++ b/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -39,9 +39,6 @@ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: class LearnedHeuristicRegression(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_feedback(self, context: AHContext, choice: Choice) -> float: return 1.0 @@ -64,9 +61,6 @@ def get_decision( class LearnedHeuristicDecision(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_choice(self, idx: int) -> Optional[str]: return None diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 88f203421cc1c..18b209de94cb3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3786,9 +3786,6 @@ class TilingSelect: In the future, we can implement advanced heuristic in a subclass. """ - def __init__(self): - super().__init__() - def select_tiling( self, fn_list, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 22d0981febecd..c4b7188bd9e62 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1330,19 +1330,6 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): including those which allow flexible fusions with epilogues. """ - def __init__( - self, - input_nodes: list[Buffer], - layout: Layout, - alpha: float, - beta: float, - input_reorder: Optional[list[int]] = None, - use_fast_accum: Optional[bool] = None, - ): - super().__init__( - input_nodes, layout, alpha, beta, input_reorder, use_fast_accum - ) - @staticmethod def add_cutlass_gemm_choices( choices: list[ChoiceCaller], diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 0054e996e33ce..1e49a274e129c 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -112,9 +112,6 @@ def reset_bn_parameters(self): bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, -bound, bound) - def reset_parameters(self): - super().reset_parameters() - def update_bn_stats(self): self.freeze_bn = False self.bn.training = True @@ -534,44 +531,6 @@ class ConvBnReLU1d(ConvBn1d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d - def __init__( - self, - # Conv1d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm1d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -735,44 +694,6 @@ class ConvBnReLU2d(ConvBn2d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d - def __init__( - self, - # Conv2d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm2d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -935,44 +856,6 @@ class ConvBnReLU3d(ConvBn3d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d - def __init__( - self, - # Conv3d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm3d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 3c146c55947a0..e2b31e0e563bf 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -19,9 +19,6 @@ class SparseDLRM(DLRM_Net): layer of the top layer. """ - def __init__(self, **args): - super().__init__(**args) - def forward(self, dense_x, lS_o, lS_i): # pyrefly: ignore [missing-attribute] x = self.apply_mlp(dense_x, self.bot_l) # dense features diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index b0f1b823b7fdb..3b8ef1030bfdc 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -41,9 +41,6 @@ def get_embedding_operators_config() -> OperatorConfig: class EmbeddingQuantizer(Quantizer): - def __init__(self) -> None: - super().__init__() - @classmethod def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: op_configs: set[QuantizationConfig] = { diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index d19968c2787f4..1c0fc48fd54fa 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -75,9 +75,6 @@ class XPUInductorQuantizer(X86InductorQuantizer): of the optimized kernels in oneDNN library. """ - def __init__(self) -> None: - super().__init__() - """ Following annotate_xx overrides the impls in base class, as no XPU implementation for these operators currently. We would diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index c02a8c36fd08b..f54a3fd6820c7 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -113,9 +113,6 @@ def inner(precision): class GenericModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - fp32_precision = ContextProp( _get_fp32_precision_getter("generic", "all"), _set_fp32_precision_setter("generic", "all"), diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 697783c01cb64..267594531db3d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -198,9 +198,6 @@ def flags( class CudnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp( torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py index 93453cc11592d..1b270b658e31a 100644 --- a/torch/backends/miopen/__init__.py +++ b/torch/backends/miopen/__init__.py @@ -37,9 +37,6 @@ def flags( class MiopenModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - immediate = ContextProp( torch._C._get_miopen_immediate, torch._C._set_miopen_immediate ) diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 2d1ce8f3bb997..58e6b2c595e98 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -110,9 +110,6 @@ def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="n class MkldnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - def is_available(self): return is_available() diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 797d847e31e5c..264be78aa9a1c 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -101,9 +101,6 @@ def flags(enabled=None, strategy=None): class OptEinsumModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - global enabled enabled = ContextProp(_get_enabled, _set_enabled) global strategy diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 3ce067f6cddc0..eae76e8cc72af 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -103,9 +103,6 @@ def _pre_load_state_dict_hook( class OffloadWrapper(ActivationWrapper): - def __init__(self, mod): - super().__init__(mod) - def forward(self, *args, **kwargs): with save_on_cpu(pin_memory=True): return self._checkpoint_wrapped_module(*args, **kwargs) diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 0f348590ea397..e5ddc1e443a29 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -85,9 +85,6 @@ def get_qualified_name(func): class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" - def __init__(self, store: CallTraceStore) -> None: - super().__init__(store) - def log(self, trace: CallTrace) -> None: # pyrefly: ignore [missing-attribute] self.traces.append(trace) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 75355cbd4b8e0..ec4bbd125119d 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -152,8 +152,7 @@ def _get_valid_constant(attr, v, owner_type): class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None: - super().__init__(source, filename, file_lineno, leading_whitespace_len) + pass def get_annotations(obj): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index d5afc413daed8..815cc8859080f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1370,8 +1370,6 @@ class XMLTestResultVerbose(_XMLTestResult): This works with unittest_xml_reporting<=3.2.0,>=2.0.0 (3.2.0 is latest at the moment) """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) def addSkip(self, test, reason): super().addSkip(test, reason) From 7c57ee33dad4e4948529c08619248a0cbab864b3 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 21 Nov 2025 09:15:15 +0000 Subject: [PATCH 158/230] Add Pylint checks to linterrunner (#167421) This PR adds pylint to CI and enables `W0143` which checks comparisons of callables and is not implemented in `ruff`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167421 Approved by: https://github.com/albanD --- .lintrunner.toml | 47 +++++ .spin/cmds.py | 1 + functorch/benchmarks/pointwise_scorecard.py | 4 +- pylintrc | 5 + .../_composable/test_replicate_training.py | 4 +- .../fsdp/test_checkpoint_wrapper.py | 2 +- test/distributed/test_c10d_common.py | 6 +- test/dynamo/test_graph_deduplication.py | 2 +- test/export/test_export.py | 2 +- test/functorch/discover_coverage.py | 2 +- test/functorch/test_eager_transforms.py | 16 +- test/functorch/test_vmap.py | 4 +- test/fx/test_fx_split.py | 2 +- test/fx/test_subgraph_rewriter.py | 4 +- test/quantization/fx/test_numeric_suite_fx.py | 12 +- test/test_fx.py | 6 +- test/test_fx_experimental.py | 8 +- test/test_tensorexpr.py | 2 +- test/test_torch.py | 2 +- .../torch_np/numpy_tests/core/test_numeric.py | 2 +- tools/linter/adapters/pylint_linter.py | 192 ++++++++++++++++++ 21 files changed, 285 insertions(+), 40 deletions(-) create mode 100644 pylintrc create mode 100644 tools/linter/adapters/pylint_linter.py diff --git a/.lintrunner.toml b/.lintrunner.toml index 7a6e241f90c8d..0f46b398ca501 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1751,3 +1751,50 @@ command = [ "python3", "tools/linter/adapters/gb_registry_linter.py", ] + +[[linter]] +code = 'PYLINT' +include_patterns = ['**/*.py'] +exclude_patterns = [ + '.git/**', + 'build_test_custom_build/**', + 'build/**', + 'caffe2/**', + 'docs/caffe2/**', + 'docs/cpp/src/**', + 'docs/src/**', + 'fb/**', + '**/fb/**', + 'functorch/docs/**', + 'functorch/examples/**', + 'functorch/docs/source/tutorials/**', + 'torch/_inductor/fx_passes/serialized_patterns/**', + 'torch/_inductor/autoheuristic/artifacts/**', + 'scripts/**', + 'test/generated_type_hints_smoketest.py', + 'test/test_torchfuzz_repros.py', + # CPython tests + 'test/dynamo/cpython/**', + # Tests from the NumPy test suite + 'test/torch_np/numpy_test/**/*.py', + 'third_party/**', + 'torch/include/**', + 'torch/lib/**', + 'venv/**', + '**/*.pyi', + "tools/experimental/torchfuzz/**", + 'tools/test/test_selective_build.py', +] +command = [ + 'python3', + 'tools/linter/adapters/pylint_linter.py', + '--config=pylintrc', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'pylint==4.0.2', +] diff --git a/.spin/cmds.py b/.spin/cmds.py index 9ed9f4a796b45..374d7699ef7f8 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -191,6 +191,7 @@ def regenerate_clangtidy_files(): "FLAKE8", "GB_REGISTRY", "PYFMT", + "PYLINT", "PYREFLY", "TEST_DEVICE_BIAS", "TEST_HAS_MAIN", diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py index 5f46c0a74fc5d..9b910735f7d1e 100644 --- a/functorch/benchmarks/pointwise_scorecard.py +++ b/functorch/benchmarks/pointwise_scorecard.py @@ -233,7 +233,7 @@ def micros(s): args = shape()[:nargs] try: - if shape == medium_transpose: + if shape is medium_transpose: raise RuntimeError("pointwise_operator hangs on medium_transpose") pw_op = pointwise_operator(operator) torch.testing.assert_close(operator(*args), pw_op(*args)) @@ -264,7 +264,7 @@ def micros(s): ) ) try: - if shape == medium_transpose: + if shape is medium_transpose: raise RuntimeError("pointwise_operator hangs on medium_transpose") if (operator, shape) in nope: raise RuntimeError("pointwise_operator fails on medium_transpose") diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000000000..3cb94acaa33f4 --- /dev/null +++ b/pylintrc @@ -0,0 +1,5 @@ +[MESSAGES CONTROL] + +# Disable the message, report, category or checker with the given id(s). +disable=all +enable=W0143 diff --git a/test/distributed/_composable/test_replicate_training.py b/test/distributed/_composable/test_replicate_training.py index 076a5e3760ff5..3dc908a8b1afe 100644 --- a/test/distributed/_composable/test_replicate_training.py +++ b/test/distributed/_composable/test_replicate_training.py @@ -678,7 +678,9 @@ def _test_train_parity_with_activation_checkpointing( test_device_type: str, ): assert checkpoint_impl in ("composable", "utils", "wrapper") - testing_compile = replicate != torch.distributed._composable.replicate_with_fsdp + testing_compile = ( + replicate is not torch.distributed._composable.replicate_with_fsdp + ) if testing_compile and checkpoint_impl == "composable": return torch.manual_seed(42) diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py index 0acb530f441fc..8cc5698cd19aa 100644 --- a/test/distributed/fsdp/test_checkpoint_wrapper.py +++ b/test/distributed/fsdp/test_checkpoint_wrapper.py @@ -73,7 +73,7 @@ def forward(self, a, b, c=None, d=None, **kwargs): ]: with self.subTest(wrapper=wrapper): model = wrapper(MyModel()) - if wrapper == offload_wrapper: + if wrapper is offload_wrapper: self.assertTrue(isinstance(model, OffloadWrapper)) else: self.assertTrue(isinstance(model, CheckpointWrapper)) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 2a1cb2b5580cb..0d11725829d26 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2073,16 +2073,16 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args): # ensure supported devices (cpu, cuda) succeeds during dispatch call tensor = torch.zeros(2, 2, device=torch.device(device)) # multi tensor collectives - if collective == dist.barrier: + if collective is dist.barrier: collective() elif collective in (dist.all_gather, dist.gather): collective([tensor], tensor, *args) - elif collective == dist.scatter: + elif collective is dist.scatter: collective(tensor, [tensor], *args) elif collective in (dist.reduce_scatter, dist.all_to_all): # gloo does not support reduce_scatter or all_to_all if backend != "gloo": - if collective == dist.reduce_scatter: + if collective is dist.reduce_scatter: collective(tensor, [tensor], *args) else: collective([tensor], [tensor], *args) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index fc9284a3c9542..d0e712ffaa6cf 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -1154,7 +1154,7 @@ def install_subgraph(self, name, subgraph): splits = [ n for n in graph.nodes - if n.op == "call_function" and n.target == torch.split + if n.op == "call_function" and n.target is torch.split ] for split in splits: tracker.node_to_duplicates.pop(split) diff --git a/test/export/test_export.py b/test/export/test_export.py index 8545c210e1b8d..b3bb0b48b569d 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16483,7 +16483,7 @@ def forward(self, x): # Expect builtin round in the export graph round_nodes = [ - n for n in ep.graph.nodes if n.op == "call_function" and n.target == round + n for n in ep.graph.nodes if n.op == "call_function" and n.target is round ] self.assertEqual(len(round_nodes), 1) diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py index 2ffdfec1e8633..2ac21e56c5c9c 100644 --- a/test/functorch/discover_coverage.py +++ b/test/functorch/discover_coverage.py @@ -356,7 +356,7 @@ def is_decorateinfo_skip_or_xfail(decorateinfo): actual_decorator = decorateinfo.decorators[0] if isinstance(actual_decorator, toleranceOverride): return False - if actual_decorator == unittest.expectedFailure: + if actual_decorator is unittest.expectedFailure: return True # Assume the rest are skips return True diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 0a5d03f9dd1f0..37bb013e5df82 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -978,23 +978,21 @@ def foo(t): fn = foo bdim = 0 for op in reversed(op_list): - if op == vmap: + if op is vmap: fn = op(fn, in_dims=bdim) bdim += 1 else: fn = op(fn) expected = f"{repr(x)}" - level = 0 - for op in op_list: - level += 1 # noqa: SIM113 - if op == grad: - expected = f"GradTrackingTensor(lvl={level}, value={expected})" - elif op == vmap: - bdim -= 1 + for level, op in enumerate(op_list): + if op is grad: expected = ( - f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})" + f"GradTrackingTensor(lvl={level + 1}, value={expected})" ) + elif op is vmap: + bdim -= 1 + expected = f"BatchedTensor(lvl={level + 1}, bdim={bdim}, value={expected})" fn(x) buf = buf.replace("\n", "").replace(" ", "") diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 0f893201733d3..ac58b81350cf4 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -6139,9 +6139,9 @@ def f(x): else: input = torch.randn(5) - if transform == vjp: + if transform is vjp: transform = functools.partial(transform, f) - elif transform == jvp: + elif transform is jvp: input = (input,) transform = functools.partial(transform, f, input) else: diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index 8d2b120e534ae..ae6880ab70e27 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -84,7 +84,7 @@ def forward(self, y): # Create custom operator support to mark wrapped_add as supported class CustomOpSupport(op_support.OperatorSupportBase): def is_node_supported(self, submodules, node) -> bool: - return node.target == wrapped_add + return node.target is wrapped_add # Create a simple splitter to test the edge case class TestSplitter(splitter_base._SplitterBase): diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 0ee60f978127d..e887f90dba227 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -782,7 +782,7 @@ def replacement(a, b, bias): found_repalcement_node = False for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul: + if node.target is wrapped_gemm_bias_mul: found_repalcement_node = True break @@ -847,7 +847,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c): repalcement_node_found = 0 for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul_with_c: + if node.target is wrapped_gemm_bias_mul_with_c: repalcement_node_found += 1 self.assertEqual(repalcement_node_found, 2) diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 75e4ebffbdf42..ed3eabd702690 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -866,7 +866,7 @@ def _test_match_activations( ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() - if prepare_fn == prepare_fx: + if prepare_fn is prepare_fx: m.eval() else: m.train() @@ -929,7 +929,7 @@ def _test_match_shadow_activations( ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() - if prepare_fn == prepare_fx: + if prepare_fn is prepare_fx: m.eval() else: m.train() @@ -1082,7 +1082,7 @@ def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { ns.call_module(OutputLogger): 2, @@ -1103,7 +1103,7 @@ def test_match_activations_mod_qat(self): def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional().eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { ns.call_module(OutputLogger): 2, @@ -1165,7 +1165,7 @@ def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(1, 1, 4, 4),), results_len=2, @@ -1182,7 +1182,7 @@ def test_add_shadow_loggers_mod_qat(self): def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn, diff --git a/test/test_fx.py b/test/test_fx.py index 71299ddb2400d..7fdd6552edc7b 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3001,7 +3001,7 @@ def forward(self, inp): for node in traced.graph.nodes: if node.op == "placeholder": ph = node - elif node.op == "call_function" and node.target == wrapped_named_tup: + elif node.op == "call_function" and node.target is wrapped_named_tup: node.update_arg(0, Pair(ph, 1.2)) node.update_kwarg("p2", Pair(3.4, ph)) call_func = node @@ -3164,7 +3164,7 @@ def forward(self, x, y): mod_false = symbolic_trace(mod, concrete_args={"y": False}) self.assertEqual(mod_true(3, True), 6) print(mod_true.code) - assert any(i.target == torch._assert for i in mod_true.graph.nodes) + assert any(i.target is torch._assert for i in mod_true.graph.nodes) with self.assertRaises(AssertionError): mod_true(3, False) self.assertEqual(mod_false(3, False), 3) @@ -4783,7 +4783,7 @@ def forward(self, x): self.assertEqual(len(gm.graph.nodes), 3) found = False for node in gm.graph.nodes: - if node.op == "call_function" and node.target == side_effect_func: + if node.op == "call_function" and node.target is side_effect_func: found = True self.assertTrue(found) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6fe3fe2355a1e..6ed8d9f2fac51 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -601,7 +601,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -660,7 +660,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -688,7 +688,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -720,7 +720,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index cf2c836486c80..c8ad8276a116b 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -891,7 +891,7 @@ def test_threshold(x, y): torch.manual_seed(0) for torch_fn, dev, data_type in fn_dev_dtype: - if torch_fn == test_lgamma and dev == "cuda": + if torch_fn is test_lgamma and dev == "cuda": # lgamma_cuda does not support BF16 continue rand_a = torch.rand(1024, dtype=data_type, device=dev) diff --git a/test/test_torch.py b/test/test_torch.py index 01c6fb39a5a2a..9b9cc2cfc58f9 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2987,7 +2987,7 @@ def filter_shape(shape, dim): t_np = t.cpu().numpy() actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order) - if space_fn == create_coordinate_tensors and spacing[0].device != 'cpu': + if space_fn is create_coordinate_tensors and spacing[0].device != 'cpu': spacing = [space.cpu().detach().numpy() for space in spacing] expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order) actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected)) diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py index c6b2d14aef6dc..209119ee8f012 100644 --- a/test/torch_np/numpy_tests/core/test_numeric.py +++ b/test/torch_np/numpy_tests/core/test_numeric.py @@ -2382,7 +2382,7 @@ def test_dtype_str_bytes(self, likefunc, dtype): # Regression test for gh-19860 a = np.arange(16).reshape(2, 8) b = a[:, ::2] # Ensure b is not contiguous. - kwargs = {"fill_value": ""} if likefunc == np.full_like else {} + kwargs = {"fill_value": ""} if likefunc is np.full_like else {} result = likefunc(b, dtype=dtype, **kwargs) if dtype is str: assert result.strides == (16, 4) diff --git a/tools/linter/adapters/pylint_linter.py b/tools/linter/adapters/pylint_linter.py new file mode 100644 index 0000000000000..c4051272b1df6 --- /dev/null +++ b/tools/linter/adapters/pylint_linter.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import subprocess +import sys +import time +from enum import Enum +from pathlib import Path +from typing import NamedTuple + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + +def run_command( + args: list[str], +) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run( + args, + capture_output=True, + check=False, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def check_pylint_installed(code: str) -> list[LintMessage]: + cmd = [sys.executable, "-mpylint", "--version"] + try: + subprocess.run(cmd, check=True, capture_output=True) + return [] + except subprocess.CalledProcessError as e: + msg = e.stderr.decode(errors="replace") + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=f"Could not run '{' '.join(cmd)}': {msg}", + ) + ] + + +def in_github_actions() -> bool: + return bool(os.getenv("GITHUB_ACTIONS")) + + +def check_files( + filenames: list[str], + config: str, + code: str, +) -> list[LintMessage]: + try: + proc = run_command( + ["pylint", f"--rcfile={config}", "-f", "json"] + filenames, + ) + except OSError as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=(f"Failed due to {err.__class__.__name__}:\n{err}"), + ) + ] + if proc.returncode == 32: + stderr = str(proc.stderr, "utf-8").strip() + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=stderr, + ) + ] + stdout = str(proc.stdout, "utf-8").strip() + errors = json.loads(stdout) + + return [ + LintMessage( + path=error["path"], + name=error["message-id"], + description=error["message"], + line=int(error["line"]), + char=int(error["column"]), + code=code, + severity=LintSeverity.ERROR, + original=None, + replacement=None, + ) + for error in errors + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="pylint wrapper linter.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--config", + required=True, + help="path to a pylintrc config file", + ) + parser.add_argument( + "--code", + default="PYLINT", + help="the code this lint should report as", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + filenames: set[str] = set() + + # If a stub file exists, have pylint check it instead of the original file, in + # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files) + for filename in args.filenames: + if filename.endswith(".pyi"): + filenames.add(filename) + continue + + stub_filename = filename.replace(".py", ".pyi") + if Path(stub_filename).exists(): + filenames.add(stub_filename) + else: + filenames.add(filename) + + lint_messages = check_pylint_installed(args.code) + check_files( + list(filenames), args.config, args.code + ) + for lint_message in lint_messages: + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + main() From cf6d08983290c0d76744b9d58df2f823f04ea39e Mon Sep 17 00:00:00 2001 From: Arsh Zahed Date: Fri, 21 Nov 2025 10:10:29 +0000 Subject: [PATCH 159/230] [3.14] Add check for __module__ to _SysImporter.whichmodule (#168189) To be merged after https://github.com/pytorch/pytorch/pull/168152 (sorry, I should've made this a stack before) This fixes an issue with Python 3.14 where pickle.whichmodule fails to import package names like . To avoid this, I added a check for `__module__` first. **Test Plan:** Locally ran with Python 3.14 ``` python3 test/test_package.py TestPackageFX.test_package_fx_package ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168189 Approved by: https://github.com/williamwen42, https://github.com/fxdawnn --- torch/package/importer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/package/importer.py b/torch/package/importer.py index fc0e735890634..83a896c69a629 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -181,6 +181,12 @@ def import_module(self, module_name: str): return importlib.import_module(module_name) def whichmodule(self, obj: Any, name: str) -> str: + # In Python 3.14+, pickle.whichmodule tries to import the module, + # which fails for mangled package names like ''. + # Check __module__ first before calling pickle.whichmodule. + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return module_name return _pickle_whichmodule(obj, name) From c23a90041e451b7347d1e587b45188927ee66b89 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 21 Nov 2025 14:10:36 +0000 Subject: [PATCH 160/230] Revert "[Full Inductor][Pytorch] Prevent decomposition and enable fallback of aten.native_layer_norm for MTIA (#168290)" This reverts commit a7f3b10866098c452d89cd7a30bc4ce5713b8319. Reverted https://github.com/pytorch/pytorch/pull/168290 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/168290#issuecomment-3563174857)) --- torch/_inductor/decomposition.py | 16 ---------------- torch/_inductor/lowering.py | 5 ----- 2 files changed, 21 deletions(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index db9c8f5f0333c..3cedad185c3f2 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -35,7 +35,6 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype, ) -from torch._refs import native_layer_norm as decomp_native_layer_norm from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from . import config, inductor_prims @@ -119,7 +118,6 @@ aten.clamp_max, aten.clamp_min, aten.embedding_dense_backward, # we fall back on xpu - aten.native_layer_norm, # we fall back on mtia aten.index_add, # we conditionally call this decomp aten.glu, # inductor lowers this directly aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass @@ -161,20 +159,6 @@ def _embedding_dense_backward( ) -@register_decomposition(aten.native_layer_norm) -def _native_layer_norm( - input: torch.Tensor, - normalized_shape: utils.ShapeType, - weight: Optional[torch.Tensor], - bias: Optional[torch.Tensor], - eps: float, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if input.is_mtia: - return NotImplemented - # We can write a util function to update decomp table if we have more ops to fallback. - return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps) - - @register_decomposition([aten.sym_constrain_range_for_size.default]) def sym_constrain_range_for_size( symbol: torch.SymInt, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 7eafc45036b10..d374be59c9446 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2902,11 +2902,6 @@ def is_aligned(x): aten.embedding_dense_backward, warn=False ) # (XPU-only and faster than decomp) -if torch.mtia.is_available(): - make_fallback( - aten.native_layer_norm, warn=False - ) # (MTIA-only and faster than decomp) - # 1.5) Easy or Impossible make_fallback(aten._cdist_forward) # p=2 should be feasible From d4de871adfac825e12bae9068e1c8433bd58455d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 21 Nov 2025 06:24:36 -0800 Subject: [PATCH 161/230] Revert #168264 + Python-side LRU cache when native op schema is not supported (#168269) This reverts #168264 but with a bugfix for the reason why it was reverted. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/168269 Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/zpcore, https://github.com/malfet --- test/cpp/jit/test_custom_operators.cpp | 13 +- test/custom_operator/test_custom_ops.cpp | 2 +- test/distributed/tensor/test_op_strategy.py | 31 +- test/distributed/tensor/test_tensor_ops.py | 10 +- torch/_C/__init__.pyi.in | 2 + torch/csrc/PyInterpreter.cpp | 2 + torch/csrc/autograd/python_variable.cpp | 1188 ++++++++++++++++- torch/csrc/autograd/python_variable.h | 9 + torch/csrc/jit/frontend/schema_matching.cpp | 2 +- torch/csrc/jit/ir/alias_analysis.cpp | 2 +- torch/csrc/jit/ir/ir.cpp | 2 +- torch/csrc/jit/python/init.cpp | 6 +- torch/csrc/jit/runtime/operator.cpp | 70 +- torch/csrc/jit/runtime/operator.h | 5 +- .../jit/runtime/symbolic_shape_registry.cpp | 2 +- torch/csrc/utils/python_arg_parser.cpp | 82 +- torch/csrc/utils/python_arg_parser.h | 12 + torch/distributed/_tools/mem_tracker.py | 1 - torch/distributed/tensor/_api.py | 42 +- torch/distributed/tensor/_dispatch.py | 109 +- torch/distributed/tensor/_sharding_prop.py | 4 + torch/distributed/tensor/debug/__init__.py | 27 +- 22 files changed, 1467 insertions(+), 156 deletions(-) diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 58f87717844de..66295d0380629 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -15,7 +15,7 @@ namespace jit { TEST(CustomOperatorTest, InferredSchema) { torch::RegisterOperators reg( "foo::bar", [](double a, at::Tensor b) { return a + b; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -43,8 +43,7 @@ TEST(CustomOperatorTest, ExplicitSchema) { "foo::bar_with_schema(float a, Tensor b) -> Tensor", [](double a, at::Tensor b) { return a + b; }); - auto& ops = - getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -77,7 +76,7 @@ TEST(CustomOperatorTest, ListParameters) { torch::List> complexdoubles, torch::List tensors) { return floats; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -123,7 +122,7 @@ TEST(CustomOperatorTest, ListParameters2) { "foo::lists2(Tensor[] tensors) -> Tensor[]", [](torch::List tensors) { return tensors; }); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); @@ -213,7 +212,7 @@ TEST(TestCustomOperator, OperatorGeneratorUndeclared) { }, aliasAnalysisFromSchema())}); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); ASSERT_EQ(ops.size(), 0); } @@ -232,7 +231,7 @@ TEST(TestCustomOperator, OperatorGeneratorBasic) { }, aliasAnalysisFromSchema())}); - auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + auto ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); ASSERT_EQ(ops.size(), 1); auto& op = ops.front(); diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index a526bebd26144..9791006d1498f 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -22,7 +22,7 @@ void check_all_parameters( template Result get_operator_from_registry_and_execute(const char* op_name, Args&&... args) { - auto& ops = torch::jit::getAllOperatorsFor( + auto ops = torch::jit::getAllOperatorsFor( torch::jit::Symbol::fromQualString(op_name)); TORCH_INTERNAL_ASSERT(ops.size() == 1); diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 139f5fb61fac8..72d95efcfa8c9 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -34,7 +34,11 @@ register_op_strategy, replicate_op_strategy, ) -from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, + CommDebugMode, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -479,7 +483,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): del propagator.op_to_schema_info[op_overload] else: propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema - propagator.propagate_op_sharding.cache.cache_clear() + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool: @@ -645,6 +650,28 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) +class TestStrategyOperation(DTensorTestBase): + @property + def world_size(self): + return 2 + + @with_comms + def test_cache_clean(self): + mesh = self.build_device_mesh() + test_op = torch.ops.mylib.numpy_sin + x = torch.randn(2, device=self.device_type) + y = torch.randn(2, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Shard(0)]) + y_dt = distribute_tensor(y, mesh, [Shard(0)]) + with op_strategy_context(test_op.default, replicate_op_strategy): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + with self.assertRaisesRegex( + NotImplementedError, + f"Operator {test_op.default} does not have a sharding strategy registered", + ): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + + DistTensorReplicateStrategyRegistrationTestWithLocalTensor = ( create_local_tensor_test_class( DistTensorReplicateStrategyRegistrationTest, diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index fc0a2b16955ca..a6cdc64a2e1dd 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -736,11 +736,11 @@ def test_where_type_promotion(self): @with_comms def test_dtensor_dtype_conversion(self): from torch.distributed.tensor.debug import ( - _clear_sharding_prop_cache, - _get_sharding_prop_cache_info, + _clear_fast_path_sharding_prop_cache, + _get_fast_path_sharding_prop_cache_stats, ) - _clear_sharding_prop_cache() + _clear_fast_path_sharding_prop_cache() device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype @@ -760,13 +760,13 @@ def test_dtensor_dtype_conversion(self): self.assertEqual(bf16_sharded_dtensor1.to_local().dtype, torch.bfloat16) # by this point we only have cache misses - hits, misses, _, _ = _get_sharding_prop_cache_info() + hits, misses = _get_fast_path_sharding_prop_cache_stats() self.assertEqual(hits, 0) self.assertEqual(misses, 2) # convert to fp32 again and see if there's cache hit bf16_sharded_dtensor1.float() - hits, misses, _, _ = _get_sharding_prop_cache_info() + hits, misses = _get_fast_path_sharding_prop_cache_stats() # by now we should have cache hit self.assertEqual(hits, 1) self.assertEqual(misses, 2) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1af6df5e7664a..e9b58b9ce71eb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1967,6 +1967,8 @@ def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ... def _DTensor_compute_global_tensor_info( tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement] ) -> tuple[list[_int], list[_int]]: ... +def _get_DTensor_sharding_propagator_cache_stats() -> tuple[_int, _int]: ... +def _clear_DTensor_sharding_propagator_cache() -> None: ... # Defined in torch/csrc/multiprocessing/init.cpp def _multiprocessing_init() -> None: ... diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index 8a2e0d533ff0c..7f36d88bdaa32 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -338,6 +338,8 @@ void ConcretePyInterpreterVTable::dispatch( nullptr, torch_api_function_overload.ptr(), nullptr, + &op, + &arguments, TorchFunctionName::TorchDispatch); pushPyOutToStack( op, stack, py::reinterpret_steal(obj), "__torch_dispatch__"); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 8165fd910c2c1..150512c972684 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1,8 +1,13 @@ +#include #include +#include #include +#include #include #include #include +#include +#include #include #include #include @@ -40,7 +45,6 @@ #include -#include #include #include #include @@ -820,29 +824,84 @@ static PyObject* THPVariable_make_wrapper_subclass( END_HANDLE_TH_ERRORS } -static py::handle get_dtensor_spec_class() { #if IS_PYBIND_2_13_PLUS - PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store - storage; - return storage - .call_once_and_store_result([]() -> py::object { - return py::module::import("torch") - .attr("distributed") - .attr("tensor") - .attr("_dtensor_spec") - .attr("DTensorSpec"); - }) - .get_stored(); +#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr) \ + static py::handle name() { \ + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store \ + storage; \ + return storage \ + .call_once_and_store_result( \ + []() -> py::object { return import_expr; }) \ + .get_stored(); \ + } #else - static py::handle dtensor_spec_class = py::object(py::module::import("torch") - .attr("distributed") - .attr("tensor") - .attr("_dtensor_spec") - .attr("DTensorSpec")) - .release(); - return dtensor_spec_class; +#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr) \ + static py::handle name() { \ + static py::handle storage = py::object(import_expr).release(); \ + return storage; \ + } #endif -} + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_class_impl, + py::module::import("torch.distributed.tensor").attr("DTensor")) + +py::handle get_dtensor_class() { + return get_dtensor_class_impl(); +} + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_spec_class, + py::module::import("torch.distributed.tensor") + .attr("_dtensor_spec") + .attr("DTensorSpec")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_replicate_class, + py::module::import("torch.distributed.tensor") + .attr("placement_types") + .attr("Replicate")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_tensor_meta_class, + py::module::import("torch.distributed.tensor") + .attr("_dtensor_spec") + .attr("TensorMeta")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_op_dispatcher, + py::module::import("torch.distributed.tensor") + .attr("DTensor") + .attr("_op_dispatcher")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_dispatch, + py::module::import("torch.distributed.tensor") + .attr("DTensor") + .attr("_op_dispatcher") + .attr("_dispatch_fast_path_python_tail")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_dispatcher_wrap, + py::module::import("torch.distributed.tensor") + .attr("DTensor") + .attr("_op_dispatcher") + .attr("wrap")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_dtensor_get_local_results_slow_path, + py::module::import("torch") + .attr("distributed") + .attr("tensor") + .attr("DTensor") + .attr("_op_dispatcher") + .attr("_dispatch_get_local_results_slow_path")) + +DEFINE_CACHING_PYTHON_IMPORT_GETTER( + get_output_sharding_class, + py::module::import("torch.distributed.tensor") + .attr("_op_schema") + .attr("OutputSharding")) static bool arg_type_tensor_or_tensor_list_like(py::handle arg) { const auto dtensor_spec_class = get_dtensor_spec_class(); @@ -870,13 +929,26 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) { #define FOR_EACH_DTENSOR_INTERNED_STRING(_) \ MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \ _(_comparison_key) \ + _(_custom_op_handlers) \ _(_local_tensor) \ _(_spec) \ + _(_unwrap_to_op_info_impl) \ _(args_schema) \ + _(compute_mesh) \ + _(device_mesh) \ + _(dtype) \ + _(get_coordinate) \ _(kwargs_schema) \ + _(ndim) \ + _(needs_pytree) \ + _(needs_redistribute) \ _(op) \ + _(op_to_schema_info) \ + _(output_sharding) \ + _(output_spec) \ _(schema_info) \ _(shape) \ + _(sharding_propagator) \ _(size) \ _(static_argnum) \ _(static_kwargkey) \ @@ -891,6 +963,7 @@ struct DTensorInternedStrings { static DTensorInternedStrings dtensor_interned_strings; +#ifdef USE_DISTRIBUTED static bool intern_dtensor_strings() { #define INTERN_DTENSOR_STRING(s) \ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dtensor_interned_strings.s == nullptr); \ @@ -903,6 +976,7 @@ static bool intern_dtensor_strings() { #undef INTERN_DTENSOR_STRING return true; } +#endif static bool checked_not(PyObject* obj) { int result = PyObject_Not(obj); @@ -912,6 +986,36 @@ static bool checked_not(PyObject* obj) { return result; } +static bool checked_istrue(PyObject* obj) { + int result = PyObject_IsTrue(obj); + if (result == -1) { + throw py::error_already_set(); + } + return result; +} + +// pybind11 does not not use PyObject_Vectorcall currently; it seems +// to materialize a tuple of args instead. +template +static py::object checked_vectorcall( + PyObject* obj, + std::array args) { + PyObject* result = PyObject_Vectorcall(obj, args.data(), N, nullptr); + if (!result) { + throw py::error_already_set(); + } + return py::reinterpret_steal(result); +} + +template +static py::object checked_vectorcall(PyObject* obj, Args... args) { + static_assert( + (std::is_same_v && ...), + "must pass PyObject* to checked_vectorcall!"); + std::array arr = {args...}; + return checked_vectorcall(obj, arr); +} + static c10::SymDimVector tuple_to_symintlist(PyObject* obj) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyTuple_Check(obj)); c10::SymDimVector res; @@ -932,6 +1036,599 @@ static c10::SymDimVector tuple_to_symintlist(PyObject* obj) { return res; } +// As a Python object, DTensorSpec can be stored directly within +// IValue, but doing so is inefficient -- it requires a +// heap-allocated, reference counted intermediate +// ivalue::PyObjectHolder. +// Representation options: +// 1) Add an IValue tag to represent a placeholder object. +// 2) Play representational tricks -- stuff information into an IValue +// payload, such as by creating impossible +// intrusive_ptr_target*. Problem: this would cause IValue copying and +// possibly destruction to crash and so would be horribly unsafe. +// 3) Represent DTensorSpec directly inside IValue despite the inefficiency. +// 4) Leave the actual DTensor in the list of IValues, but detect it efficiently +// and transparently replace. +// 5) Just use a 24-byte struct of IValue + extra py::object. +// +// Given the high blast radius of (1), the unsafety of (2), the likely +// poor performance of (3), and detection of (4) looking less +// efficient than (5), (5) seems like the best path forward. + +// We can't safely steal bits from IValue, so we just use 24 bytes of +// space. If dtensor_spec is non-null (truthy) then it's the active +// member, otherwise it's iv. +struct IValueOrDTensorSpec { + IValueOrDTensorSpec() = default; + explicit IValueOrDTensorSpec(c10::IValue v) : iv(std::move(v)) {} + explicit IValueOrDTensorSpec(py::object dts) : dtensor_spec(std::move(dts)) {} + c10::IValue iv; + py::object dtensor_spec; + + bool operator==(const IValueOrDTensorSpec& rhs) const { + return dtensor_spec + ? (rhs.dtensor_spec && dtensor_spec.equal(rhs.dtensor_spec)) + : (iv == rhs.iv); + } +}; + +// This corresponds to the Python OpSchema class in that it is the key +// for the (native version of the) sharding propagator cache. It is +// missing essentially everything else from the Python OpSchema +// though. +class NativeOpSchema { + public: + NativeOpSchema( + const c10::OperatorHandle& op, + c10::SmallVector comparison_key, + std::size_t comparison_key_hash, + std::size_t args_schema_len) + : op_(op), + hash_(hash_combine( + hash_combine( + std::hash()(op), + comparison_key_hash), + args_schema_len)), + args_schema_len_(args_schema_len), + comparison_key_(std::move(comparison_key)) {} + + bool operator==(const NativeOpSchema& rhs) const { + // If two NativeOpSchema are being compared, they are probably + // equal, because comparison is occurring during a hash table + // lookup and we know the hashes are already equal. Therefore, we + // don't bother checking hash_ first. + return op_ == rhs.op_ && args_schema_len_ == rhs.args_schema_len_ && + comparison_key_ == rhs.comparison_key_; + } + + std::size_t hash() const { + return hash_; + } + + private: + // It would *not* be correct to store this by reference, because we + // have no guarantees about its lifetime. This class is cheap anyway. + c10::OperatorHandle op_; + std::size_t hash_; + // Subtle point: consider clamp.Tensor(Tensor self, Tensor? + // min=None, Tensor? max=None). The invocations clamp(t1, None, t2) + // and clamp(t1, t2, None) have the same comparison key (t1, t2) + // because we drop non-static non-tensor args from comparison. The + // only way we happen to be able to tell them apart is that we omit + // trailing defaulted arguments from the args tuple passed to + // __torch_dispatch__ (and hence to DTensor dispatch as well), so + // they have different args_schema_len_. + // + // I am preserving this existing behavior, but I suspect we should + // make an algorithm change to be less brittle, such as including + // None defaults for Tensor arguments in the comparison. + std::size_t args_schema_len_; + // There is no particular justification for the choice of 8 + // here. Feel free to change it. + c10::SmallVector comparison_key_; +}; + +namespace std { +template <> +struct hash { + std::size_t operator()(const NativeOpSchema& schema) const { + return schema.hash(); + } +}; +} // namespace std + +// Map from OpSchema to pyobject sharding propagation config. +class NativeShardingPropagatorCache { + public: + // Returns an invalid (falsey) py::object if the lookup fails. + py::object find(const NativeOpSchema& op_schema) const { + if (auto it = repr_.find(op_schema); it != repr_.end()) { + hits_++; + return py::object(it->second); + } + misses_++; + return py::object(); + } + + void insert(NativeOpSchema&& op_schema, py::object output_sharding) { + auto [it, inserted] = + repr_.emplace(std::move(op_schema), std::move(output_sharding)); + TORCH_INTERNAL_ASSERT( + inserted, + "tried to insert already-present element in NativeShardingPropagatorCache!"); + } + + auto hits() const { + return hits_; + } + + auto misses() const { + return misses_; + } + + private: + c10::FastMap repr_; + // Cache is thread-local, so we don't take any further action for + // thread-safety of these. + mutable std::size_t hits_ = 0; + mutable std::size_t misses_ = 0; +}; + +static std::optional> +create_native_op_schema( + const c10::OperatorHandle& op, + py::handle py_op, + torch::jit::Stack* stack); + +static std::mutex native_sharding_propagator_cache_cleanup_mutex; +static c10:: + FastMap*> + all_thread_caches; +thread_local std::optional + native_sharding_propagator_cache_DO_NOT_USE; + +NativeShardingPropagatorCache& +get_thread_local_native_sharding_propagator_cache() { + if (!native_sharding_propagator_cache_DO_NOT_USE.has_value()) { + native_sharding_propagator_cache_DO_NOT_USE.emplace(); + std::lock_guard lock( + native_sharding_propagator_cache_cleanup_mutex); + const auto this_thread_id = std::this_thread::get_id(); + all_thread_caches[this_thread_id] = + &native_sharding_propagator_cache_DO_NOT_USE; + py::dict thread_dict = + py::reinterpret_borrow(PyThreadState_GetDict()); + // We need to clean up before Python detaches from the thread if + // the thread is being destroyed. + if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) { + thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = + py::capsule(new std::thread::id(this_thread_id), [](void* p) { + auto* ptid = reinterpret_cast(p); + { + std::lock_guard inner_lock( + native_sharding_propagator_cache_cleanup_mutex); + auto it = all_thread_caches.find(*ptid); + if (it != all_thread_caches.end()) { + // We need to both: + // 1) free python objects, and + it->second->reset(); + // 2) make sure we don't try to come back and mess with + // a destroyed thread-local at module unload (e.g., + // process exit) time. + all_thread_caches.erase(it); + } + } + delete ptid; + }); + } + } + return native_sharding_propagator_cache_DO_NOT_USE.value(); +} + +// We need to clean up all thread_locals if our module is getting +// unloaded. +void cleanup_thread_local_native_sharding_propagator_caches() { + std::lock_guard lock( + native_sharding_propagator_cache_cleanup_mutex); + for (auto& [_, popt_cache] : all_thread_caches) { + popt_cache->reset(); + } + all_thread_caches.clear(); +} + +static void replace_dtensors_with_local_tensor(torch::jit::Stack& stack); + +static bool is_default_overload(const std::string& overload_name) { + return overload_name.empty() || overload_name == "default"; +} + +static bool is_random_op(const c10::OperatorHandle& op) { + // NOTE: must stay in sync with _random_ops in + // torch/distributed/tensor/_dispatch.py + constexpr auto aten_namespace_prefix_len = 6; + const auto& op_name = op.operator_name(); + if (op_name.name.size() <= aten_namespace_prefix_len || + memcmp(op_name.name.data(), "aten::", aten_namespace_prefix_len) != 0) { + return false; + } + static constexpr std::array random_names = {{ + "native_dropout", + "normal_", + "rand_like", + "randn_like", + "uniform_", + "bernoulli", + }}; + std::string_view name_without_namespace( + op_name.name.c_str() + aten_namespace_prefix_len, + op_name.name.size() - aten_namespace_prefix_len); + if (name_without_namespace == "bernoulli_") { + return op_name.overload_name == "float"; + } + if (name_without_namespace == "randint_like") { + return is_default_overload(op_name.overload_name) || + op_name.overload_name == "low_dtype" || + op_name.overload_name == "low_dtype_out"; + } + const auto it = std::find( + random_names.begin(), random_names.end(), name_without_namespace); + if (it == random_names.end()) { + return false; + } + return is_default_overload(op_name.overload_name); +} + +// Puts local results on the stack. Return true for success, false for bailout +// to slow path. +static bool get_local_results( + const c10::OperatorHandle& op, + py::handle output_sharding, + py::handle compute_mesh, + bool participating, + torch::jit::Stack* stack) { + if (participating) { + // computation that happens in the current rank of the mesh, normal case + if (checked_istrue( + output_sharding.attr(dtensor_interned_strings.needs_redistribute) + .ptr()) || + is_random_op(op)) { + // Bail out to slow path. + return false; + } + // normal case, run local sharded op computation. + + // It is slightly inefficient that we take another pass over + // arguments here when we just did one in create_native_op_schema to + // create the comparison key. However, we have a crucial difference: + // in the NativeOpSchema, we don't want to waste time dealing with + // defaulted args. Here, we need to provide defaulted args because + // we are going to make a local op call. + replace_dtensors_with_local_tensor(*stack); + op.callBoxed(*stack); + } else { + // For a non-participating device (happens on rank that does not + // belong to the device mesh), we do: + // + // 1. if the return type is scalar, set the local result to + // None. + // 2. if the return type is Tensor or List[Tensor], return + // empty tensor(s) with correct dtype. + + stack->clear(); + + auto spec = output_sharding.attr(dtensor_interned_strings.output_spec); + if (spec.is_none()) { + // For a scalar return type, the non-participating device has + // None as its local result. + stack->emplace_back(); // Return None. + return true; + } + + const auto default_tensor = [](py::handle spec) -> Tensor { + auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta); + TORCH_CHECK( + !tensor_meta.is_none(), py::str(spec), " has no tensor metadata."); + const auto sizes = tensor_meta.attr(dtensor_interned_strings.shape); + TORCH_CHECK( + PyTuple_Check(sizes.ptr()), "spec.tensor_meta.shape must be a tuple"); + const auto dtype = tensor_meta.attr(dtensor_interned_strings.dtype); + TORCH_CHECK( + THPDtype_Check(dtype.ptr()), + "spec.tensor_meta.dtype must be a torch.dtype"); + const auto scalar_type = + reinterpret_cast(dtype.ptr())->scalar_type; + if (py::cast(sizes).empty()) { + // scalar tensor + return at::zeros({}, scalar_type); + } else { + // non-scalar tensor + return at::empty({0}, scalar_type); + } + }; + auto handle_sequence = [&default_tensor, &op, stack](auto sequence) { + c10::List result(op.schema().returns().at(0).type()); + for (const auto& item : sequence) { + TORCH_CHECK( + !item.is_none(), + "return type ", + op.schema().returns().at(0).type(), + " in DTensor op is not supported"); + result.push_back(default_tensor(item)); + } + stack->push_back(std::move(result)); + }; + + if (py::isinstance(spec, get_dtensor_spec_class())) { + stack->push_back(default_tensor(spec)); + } else if (PyList_Check(spec.ptr())) { + handle_sequence(py::reinterpret_borrow(spec)); + } else if (PyTuple_Check(spec.ptr())) { + handle_sequence(py::reinterpret_borrow(spec)); + } else if (PySequence_Check(spec.ptr())) { + handle_sequence(py::reinterpret_borrow(spec)); + } else { + // return None. + stack->emplace_back(); + } + } + return true; +} + +static void functionalize_unsafe_set(at::Tensor& dst, const at::Tensor& src) { + at::native::checkSetStorage( + dst, + src.storage(), + dst.sym_storage_offset(), + dst.sym_sizes(), + dst.sym_strides(), + /*check_offset_in_bounds=*/false); +} + +static bool sets_intersect( + const std::unordered_set& smaller, + const std::unordered_set& bigger) { + if (smaller.size() > bigger.size()) { + return sets_intersect(bigger, smaller); + } + for (const auto& item : smaller) { + if (bigger.find(item) != bigger.end()) { + return true; + } + } + return false; +} + +py::object dispatchDTensorOp( + const c10::OperatorHandle& op, + py::handle py_op, + py::handle args, + py::handle kwargs, + torch::jit::Stack* stack) { + py::object cached_sharding; + const auto op_dispatcher = get_dtensor_op_dispatcher(); + { + const auto custom_op_handlers = + op_dispatcher.attr(dtensor_interned_strings._custom_op_handlers); + TORCH_CHECK( + PyDict_Check(custom_op_handlers.ptr()), + "_custom_op_handlers must be a dict!"); + PyObject* custom_op_handler = + PyDict_GetItemWithError(custom_op_handlers.ptr(), py_op.ptr()); + if (custom_op_handler) { + auto result = checked_vectorcall( + custom_op_handler, py_op.ptr(), args.ptr(), kwargs.ptr()); + stack->clear(); + return result; + } else if (PyErr_Occurred()) { + throw py::error_already_set(); + } + } + + torch::jit::Stack saved_args = *stack; + NativeShardingPropagatorCache* native_sharding_propagator_cache = nullptr; + // In the original Python implementation of DTensor dispatch, the creation + // of OpInfo (which includes the OpSchema computed here) never fails. However, + // C++ support for all the features of OpSchema are not supported; in this + // case opt_native_op_schema is nullopt. In this case, we need to fallback + // to the Python logic for doing so. If you are comparing against the old + // Python code, this is a bit tricky, since the Python 'dispatch' function + // has been completely deleted. + + // First, we will try to short-circuit Python entirely using the fast path. + // Here, we never materialize OpInfo, we generate a gimped NativeOpSchema + // object which has exactly the information you need to do a hash lookup. + auto opt_native_op_schema = create_native_op_schema(op, py_op, stack); + if (opt_native_op_schema.has_value()) { + native_sharding_propagator_cache = + &get_thread_local_native_sharding_propagator_cache(); + cached_sharding = + native_sharding_propagator_cache->find(opt_native_op_schema->first); + } + + py::object py_op_info; + if (!cached_sharding) { + // OK, the C++ fastpath failed. Let's use the Python path to generate the + // OpInfo (which is guaranteed to work), which we will need to either + // redo the cache lookup or compute the value for real. + py_op_info = checked_vectorcall( + op_dispatcher.attr("unwrap_to_op_info").ptr(), + py_op.ptr(), + args.ptr(), + kwargs.ptr()); + + py::object sharding = checked_vectorcall( + op_dispatcher.attr("_propagate_op_sharding_dispatch_slow_path").ptr(), + py_op.ptr(), + args.ptr(), + kwargs.ptr(), + py_op_info.ptr(), + /*try_cache*/ !opt_native_op_schema.has_value() ? Py_True : Py_False); + // This is a hack, because the dispatch slow path sometimes returns + // a sharding result (in which case we need to keep going) but it + // will sometimes just decompose and directly return a Tensor result, + // in which case we should return immediately. In this case, sharding + // is not a sharding at all; it's the real result! + if (!py::isinstance(sharding, get_output_sharding_class())) { + stack->clear(); + return sharding; + } + cached_sharding = sharding; + if (opt_native_op_schema.has_value()) { + native_sharding_propagator_cache->insert( + std::move(opt_native_op_schema->first), std::move(sharding)); + } + py_op_info.attr(dtensor_interned_strings.output_sharding) = cached_sharding; + } + + const auto get_py_op_info_if_needed = [&, &args = args, &kwargs = kwargs]() { + if (!py_op_info) { + py_op_info = checked_vectorcall( + op_dispatcher.attr(dtensor_interned_strings._unwrap_to_op_info_impl) + .ptr(), + py_op.ptr(), + args.ptr(), + kwargs.ptr(), + Py_False); + py_op_info.attr(dtensor_interned_strings.output_sharding) = + cached_sharding; + } + }; + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !kwargs.is_none(), + "Python op_dispatch implementation expects non-None kwargs"); + + py::object compute_mesh; + if (opt_native_op_schema.has_value()) { + compute_mesh = std::move(opt_native_op_schema->second); + } else { + get_py_op_info_if_needed(); + compute_mesh = py_op_info.attr(dtensor_interned_strings.compute_mesh); + } + + const bool participating = + !checked_vectorcall( + compute_mesh.attr(dtensor_interned_strings.get_coordinate).ptr()) + .is_none(); + const bool local_results_success = get_local_results( + op, cached_sharding, compute_mesh, participating, stack); + py::object py_local_results; + if (local_results_success) { + py_local_results = torch::jit::createPyObjectForStack(std::move(*stack)); + } else { + get_py_op_info_if_needed(); + py_local_results = checked_vectorcall( + get_dtensor_get_local_results_slow_path().ptr(), + py_op.ptr(), + args.ptr(), + py_op_info.ptr()); + } + + const auto& operator_name = op.operator_name(); + // Simple analysis of function schema to determine if this is an + // inplace variant. It might not be entirely correct, but it's good + // enough for now. + const bool is_inplace_op = + !operator_name.name.empty() && operator_name.name.back() == '_'; + // Simple analysis of function schema to determine if this is an + // ou variant. It might not be entirely correct, but it's good + // enough for now. + const bool is_out_variant_op = !is_inplace_op && + operator_name.overload_name.find("out") != std::string::npos; + + // Fast path for default or view ops. + const auto output_spec = + cached_sharding.attr(dtensor_interned_strings.output_spec); + if (!is_inplace_op && !is_out_variant_op && + !(output_spec.is_none() && + (op.operator_name().name == "aten::equal" && + is_default_overload(op.operator_name().overload_name)))) { + const auto wrap = get_dtensor_dispatcher_wrap(); + auto wrapped_result = checked_vectorcall( + wrap.ptr(), py_local_results.ptr(), output_spec.ptr()); + if (!participating) { + stack->clear(); + return wrapped_result; + } + + // Direct C++ implementation of return_and_correct_aliasing for view ops. + + // py::tuple's default constructor allocates a size-0 tuple, so we + // wrap in optional to get a detectable empty state. + std::optional wrapped_result_tuple; + if (PyTuple_Check(wrapped_result.ptr())) { + wrapped_result_tuple = py::reinterpret_borrow(wrapped_result); + } + const auto& returns = op.schema().returns(); + const auto num_arguments = op.schema().arguments().size(); + for (const auto arg_idx : c10::irange(num_arguments)) { + const auto& arg_schema = op.schema().arguments()[arg_idx]; + const auto* arg_alias_info = arg_schema.alias_info(); + if (!arg_alias_info || arg_alias_info->isWrite()) { + continue; + } + // If we ever get here, it's a view op. Therefore, it does not + // have mutable output aliases, so we skip that portion of + // return_and_correct_aliasing. Furthermore, we *only* want to + // return_and_correct_aliasing if it's a view op, so we do not + // need to port the mutable output aliases portion of + // return_and_correct_aliasing at all. + const c10::IValue& arg_iv = + saved_args.at(saved_args.size() - num_arguments + arg_idx); + if (!arg_iv.isTensor()) { + continue; + } + const auto& arg = arg_iv.toTensor(); + int ret_idx = 0; + for (const auto& ret_schema : returns) { + const auto* ret_alias_info = ret_schema.alias_info(); + if (!ret_alias_info) { + ret_idx++; + continue; + } + if (sets_intersect( + arg_alias_info->beforeSets(), ret_alias_info->beforeSets())) { + py::object ret; + if (wrapped_result_tuple.has_value()) { + ret = wrapped_result_tuple.value()[ret_idx]; + } else { + TORCH_INTERNAL_ASSERT(ret_idx == 0); + ret = wrapped_result; + } + if (PyList_Check(ret.ptr())) { + py::list ret_list = py::reinterpret_borrow(ret); + for (const auto& r : ret_list) { + auto tensor = py::cast(r); + functionalize_unsafe_set(tensor, arg); + } + } else { + auto tensor = py::cast(ret); + functionalize_unsafe_set(tensor, arg); + } + } + ret_idx++; + } + } + stack->clear(); + return wrapped_result; + } + + auto dispatch = get_dtensor_dispatch(); + auto result = checked_vectorcall( + dispatch.ptr(), + py_op.ptr(), + args.ptr(), + kwargs.ptr(), + compute_mesh.ptr(), + cached_sharding.ptr(), + py_local_results.ptr(), + participating ? Py_True : Py_False, + is_inplace_op ? Py_True : Py_False, + is_out_variant_op ? Py_True : Py_False); + stack->clear(); + return result; +} + // DTensor-specific variant of make_wrapper_subclass to minimize DTensor // overhead. static PyObject* THPVariable_dtensor_new( @@ -1008,27 +1705,44 @@ static PyObject* THPVariable_dtensor_new( END_HANDLE_TH_ERRORS } +struct NativeRuntimeSchemaInfo { + py::object static_kwargkey; + size_t static_argnum; +}; + +NativeRuntimeSchemaInfo unpack_runtime_schema_info( + py::handle runtime_schema_info, + size_t num_args) { + NativeRuntimeSchemaInfo result; + if (!runtime_schema_info) { + result.static_argnum = num_args; + } else { + result.static_argnum = py::cast( + runtime_schema_info.attr(dtensor_interned_strings.static_argnum)); + result.static_kwargkey = + runtime_schema_info.attr(dtensor_interned_strings.static_kwargkey); + TORCH_CHECK( + result.static_kwargkey.is_none() || + PyList_Check(result.static_kwargkey.ptr()), + "RuntimeSchemaInfo.static_kwargkey must be a list!"); + } + return result; +} + static bool DTensor_OpSchema_recompute_comparison_key_impl( PyObject* self, const py::tuple& args_schema) { - py::object static_kwargkey; - size_t static_argnum = 0; const py::handle self_handle = py::handle(self); - const py::handle schema_info = + const auto schema_info = self_handle.attr(dtensor_interned_strings.schema_info); - if (checked_not(schema_info.ptr())) { - static_argnum = args_schema.size(); - static_kwargkey = py::none(); - } else { - static_argnum = py::cast( - schema_info.attr(dtensor_interned_strings.static_argnum)); - static_kwargkey = - schema_info.attr(dtensor_interned_strings.static_kwargkey); - } + NativeRuntimeSchemaInfo native_info = unpack_runtime_schema_info( + checked_not(schema_info.ptr()) ? py::handle() : py::handle(schema_info), + args_schema.size()); c10::SmallVector args_to_hash; size_t idx = 0; for (const auto& e : args_schema) { - if (idx >= static_argnum || arg_type_tensor_or_tensor_list_like(e)) { + if (idx >= native_info.static_argnum || + arg_type_tensor_or_tensor_list_like(e)) { if (PyList_Check(e.ptr())) { args_to_hash.push_back( py::reinterpret_steal(PyList_AsTuple(e.ptr()))); @@ -1043,24 +1757,19 @@ static bool DTensor_OpSchema_recompute_comparison_key_impl( args_to_hash_tup[idx] = std::move(args_to_hash[idx]); } PyObject* comparison_key = nullptr; - if (!static_kwargkey.is_none()) { - if (!PyList_Check(static_kwargkey.ptr())) { - PyErr_SetString( - PyExc_TypeError, "self.schema_info.static_kwargkey must be a list!"); - return false; - } - py::list static_kwargkey_list = - py::reinterpret_borrow(static_kwargkey); + if (native_info.static_kwargkey && !native_info.static_kwargkey.is_none()) { + py::list static_kwargkey = + py::reinterpret_borrow(native_info.static_kwargkey); auto raw_kwargs_schema = self_handle.attr(dtensor_interned_strings.kwargs_schema); if (!PyDict_Check(raw_kwargs_schema.ptr())) { PyErr_SetString(PyExc_TypeError, "self.kwargs_schema must be a dict!"); return false; } - py::tuple kwargs_to_hash(static_kwargkey_list.size()); + py::tuple kwargs_to_hash(static_kwargkey.size()); int idx = 0; auto kwargs_schema = py::reinterpret_borrow(raw_kwargs_schema); - for (const auto& k : static_kwargkey_list) { + for (const auto& k : static_kwargkey) { PyObject* item = PyDict_GetItemWithError(kwargs_schema.ptr(), k.ptr()); if (item) { kwargs_to_hash[idx++] = py::reinterpret_borrow(item); @@ -1255,6 +1964,370 @@ static PyObject* DTensor_compute_global_tensor_info( END_HANDLE_TH_ERRORS } +enum class TensorFlavor { + NON_TENSOR, + EXACTLY_DTENSOR, + EXACTLY_TENSOR, + DTENSOR_SUBCLASS, + NON_DTENSOR_TENSOR_SUBCLASS, +}; + +static std::pair check_for_dtensor_or_tensor( + const at::Tensor& tensor) { + if (!tensor.defined()) { + return {TensorFlavor::NON_TENSOR, py::object()}; + } + + // I don't think we need to check for wrapped_number() tensors here; + // the try_replicate_spec_for_scalar_tensor stuff in our caller + // specifically handles 1-element tensors. + + torch::jit::guardAgainstNamedTensor(tensor); + auto py_tensor = py::cast(tensor); + + const auto dtensor = get_dtensor_class(); + auto* const obj_type = Py_TYPE(py_tensor.ptr()); + if (obj_type == (PyTypeObject*)dtensor.ptr()) { + return {TensorFlavor::EXACTLY_DTENSOR, std::move(py_tensor)}; + } + // Fast path for plain old Tensors. + if (THPVariable_CheckTypeExact(obj_type)) { + return {TensorFlavor::EXACTLY_TENSOR, std::move(py_tensor)}; + } + if (py::isinstance(py_tensor, dtensor)) { + return {TensorFlavor::DTENSOR_SUBCLASS, std::move(py_tensor)}; + } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + THPVariableClass && py::isinstance(py_tensor, THPVariableClass)); + return {TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS, std::move(py_tensor)}; +} + +static std::pair check_for_dtensor_or_tensor( + const c10::IValue& iv) { + if (!iv.isTensor()) { + return {TensorFlavor::NON_TENSOR, py::object()}; + } + + return check_for_dtensor_or_tensor(iv.toTensor()); +} + +static c10::List replace_dtensors_with_local_tensor( + const c10::List& tl) { + c10::List local_list(tl.elementType()); + local_list.reserve(tl.size()); + for (const auto& elt : tl) { + const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(elt); + if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || + tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { + local_list.push_back(THPVariable_Unpack( + py_tensor.attr(dtensor_interned_strings._local_tensor).ptr())); + } else { + local_list.push_back(elt); + } + } + return local_list; +} + +static void replace_dtensors_with_local_tensor(torch::jit::Stack& stack) { + for (auto& arg : stack) { + if (arg.isList()) { + arg = replace_dtensors_with_local_tensor(arg.toList()); + continue; + } + const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(arg); + if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || + tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { + arg = THPVariable_Unpack( + py_tensor.attr(dtensor_interned_strings._local_tensor).ptr()); + } + } +} + +static py::object try_find_mesh_from_args( + const c10::OperatorHandle& op, + const OperatorArgsKwargsView& args_kwargs) { + for (auto argument_it = args_kwargs.args_begin(); + argument_it != args_kwargs.args_end(); + ++argument_it) { + const auto [tensor_flavor, py_tensor] = + check_for_dtensor_or_tensor(*argument_it); + if (tensor_flavor == TensorFlavor::EXACTLY_DTENSOR || + tensor_flavor == TensorFlavor::DTENSOR_SUBCLASS) { + return py::reinterpret_borrow( + py_tensor.attr(dtensor_interned_strings.device_mesh)); + } + } + TORCH_CHECK_VALUE( + false, "Cannot find device mesh from args for op : ", op.operator_name()); +} + +static /*DTensorSpec*/ py::object try_replicate_spec_for_scalar_tensor( + bool allow_implicit_replication, + py::handle op_call, + py::handle py_tensor, + py::handle compute_mesh) { + const Tensor& tensor_arg = THPVariable_Unpack(py_tensor.ptr()); + const bool numel_is_one = tensor_arg.numel() == 1; + if (numel_is_one && tensor_arg.dim() == 1) { + TORCH_WARN( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed environment."); + } + + TORCH_CHECK( + numel_is_one || allow_implicit_replication, + py::str(op_call), + " got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!"); + + // scalar tensor can be safely treated as replicated. + const auto num_placements = + py::cast(compute_mesh.attr(dtensor_interned_strings.ndim)); + py::tuple placements_tuple(num_placements); + py::object replicate = get_replicate_class()(); + for (const auto idx : c10::irange(num_placements)) { + PyTuple_SET_ITEM( + placements_tuple.ptr(), + idx, + py::reinterpret_borrow(replicate).release().ptr()); + } + + return checked_vectorcall( + get_dtensor_spec_class().ptr(), + compute_mesh.ptr(), + placements_tuple.ptr(), + checked_vectorcall( + get_tensor_meta_class().ptr(), + py_tensor.attr(dtensor_interned_strings.shape).ptr(), + py_tensor.attr(dtensor_interned_strings.stride)().ptr(), + py_tensor.attr(dtensor_interned_strings.dtype).ptr()) + .ptr()); +} + +// May return unset object, in which case there was no runtime schema +// info. +static py::object get_runtime_schema_info_for_op(py::handle py_op) { + const auto op_dispatcher = get_dtensor_op_dispatcher(); + const auto sharding_propagator = + op_dispatcher.attr(dtensor_interned_strings.sharding_propagator); + const py::dict op_to_schema_info = py::reinterpret_borrow( + sharding_propagator.attr(dtensor_interned_strings.op_to_schema_info)); + + PyObject* runtime_schema_info = + PyDict_GetItemWithError(op_to_schema_info.ptr(), py_op.ptr()); + if (!runtime_schema_info && PyErr_Occurred()) { + throw py::error_already_set(); + } + return py::reinterpret_borrow(runtime_schema_info); +} + +static bool contains_any_symint(const py::tuple& tup) { + for (const auto& s : tup) { + if (THPUtils_checkLong(s.ptr())) { + continue; + } + if (torch::is_symint(s)) { + return true; + } + } + return false; +} + +static bool dtensor_spec_has_symints(py::handle spec) { + const auto tensor_meta = spec.attr(dtensor_interned_strings.tensor_meta); + if (tensor_meta.is_none()) { + return false; + } + py::object raw_shape = tensor_meta.attr(dtensor_interned_strings.shape); + if (!PyTuple_Check(raw_shape.ptr())) { + PyErr_SetString(PyExc_TypeError, "TensorMeta.shape must be a tuple!"); + throw py::error_already_set(); + } + const auto shape = py::reinterpret_steal(raw_shape.release()); + return contains_any_symint(shape); +} + +static std::optional> +create_native_op_schema( + const c10::OperatorHandle& op, + py::handle py_op, + torch::jit::Stack* stack) { + // fused schema part of unwrap_to_op_info + recompute_comparison_key, + // operating on IValues instead of Python stuff. + + py::object runtime_schema_info = get_runtime_schema_info_for_op(py_op); + if (runtime_schema_info && + checked_istrue(py::handle(runtime_schema_info) + .attr(dtensor_interned_strings.needs_pytree) + .ptr())) { + // Punting on pytree flattening in the fast path on IValues for + // now since only a minority of ops need it. + return std::nullopt; + } + + OperatorArgsKwargsView args_kwargs(op, *stack); + auto native_info = unpack_runtime_schema_info( + py::handle(runtime_schema_info), args_kwargs.num_positional_args()); + + c10::SmallVector comparison_key; + std::size_t comparison_key_hash = 0; + + py::object compute_mesh = py::none(); + + const auto handle_non_dtensor_arg = + [&comparison_key, &comparison_key_hash, &native_info]( + size_t idx, c10::IValue arg) { + if (idx >= native_info.static_argnum) { + if (arg.isList()) { + const auto& list = arg.toList(); + if (list.empty()) { + arg = c10::ivalue::Tuple::create({}); + } else { + // WARNING: here we rely on c10::List being represented + // by a contiguous array of IValue for efficiency! + arg = c10::ivalue::Tuple::create(c10::ArrayRef( + &(*list.begin()).get(), list.size())); + } + } else if (arg.isTensor() && !arg.toTensor().defined()) { + // Coerce undefined Tensor to None, just as we do when + // converting IValues to PyObject. Otherwise comparison + // doesn't work. (undefined Tensors can get here because + // check_for_dtensor_or_tensor calls them non-Tensors, but + // doesn't have a way to do the coercion for us.) + arg = c10::IValue(); + } + comparison_key_hash = + c10::hash_combine(comparison_key_hash, c10::IValue::hash(arg)); + comparison_key.emplace_back(std::move(arg)); + } + }; + const auto handle_dtensor_arg = [&comparison_key, + &comparison_key_hash](py::object arg) { + comparison_key_hash = c10::hash_combine( + comparison_key_hash, static_cast(py::hash(arg))); + comparison_key.emplace_back(std::move(arg)); + }; + + Py_ssize_t idx = 0; + const bool allow_implicit_replication = + at::get_dtensor_allow_implicit_replication(); + for (auto argument_it = args_kwargs.args_begin(); + argument_it != args_kwargs.args_end(); + ++argument_it) { + const auto& arg = *argument_it; + const auto [tensor_flavor, py_tensor] = check_for_dtensor_or_tensor(arg); + switch (tensor_flavor) { + case TensorFlavor::EXACTLY_DTENSOR: + case TensorFlavor::DTENSOR_SUBCLASS: { + py::object spec = py_tensor.attr(dtensor_interned_strings._spec); + if (dtensor_spec_has_symints(spec)) { + // Symints are unhashable, so we can't use the cache for + // sharding propagation. bail out to slow path. + return std::nullopt; + } + handle_dtensor_arg(std::move(spec)); + if (compute_mesh.is_none()) { + compute_mesh = py::reinterpret_borrow( + py_tensor.attr(dtensor_interned_strings.device_mesh)); + } + break; + } + case TensorFlavor::EXACTLY_TENSOR: + case TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS: { + if (compute_mesh.is_none()) { + compute_mesh = try_find_mesh_from_args(op, args_kwargs); + } + handle_dtensor_arg(try_replicate_spec_for_scalar_tensor( + allow_implicit_replication, py_op, py_tensor, compute_mesh)); + break; + } + case TensorFlavor::NON_TENSOR: { + // non DTensor/Tensor args (i.e. int/float/bool), just add to + // local_args + handle_non_dtensor_arg(idx, arg); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "can't happen"); + break; + } + idx++; + } + + TORCH_CHECK( + !compute_mesh.is_none(), + "found no DeviceMesh from dtensor args for ", + op.operator_name()); + + if (native_info.static_kwargkey && !native_info.static_kwargkey.is_none()) { + // Separator to disambiguate kwargs from args in comparison and hashing. + static constexpr int64_t kwargs_separator = 0x0011223344556677LL; + comparison_key.emplace_back(static_cast(kwargs_separator)); + comparison_key_hash = hash_combine(comparison_key_hash, kwargs_separator); + + for (auto argument_it = args_kwargs.kwargs_begin(); + argument_it != args_kwargs.kwargs_end(); + ++argument_it) { + // Rather than hash/compare the string key, we can just use the + // index of the kwarg in the schema! + const auto underlying_index = argument_it.underlying_index(); + comparison_key.emplace_back(c10::IValue(underlying_index)); + comparison_key_hash = hash_combine( + comparison_key_hash, c10::IValue::hash(comparison_key.back().iv)); + const auto [tensor_flavor, py_tensor] = + check_for_dtensor_or_tensor(*argument_it); + switch (tensor_flavor) { + case TensorFlavor::EXACTLY_DTENSOR: + case TensorFlavor::DTENSOR_SUBCLASS: { + handle_dtensor_arg(py_tensor.attr(dtensor_interned_strings._spec)); + break; + } + case TensorFlavor::EXACTLY_TENSOR: + case TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS: { + handle_dtensor_arg(try_replicate_spec_for_scalar_tensor( + allow_implicit_replication, py_op, py_tensor, compute_mesh)); + break; + } + case TensorFlavor::NON_TENSOR: { + handle_non_dtensor_arg(native_info.static_argnum, *argument_it); + break; + } + default: + TORCH_INTERNAL_ASSERT(false, "can't happen"); + break; + } + } + } + + return std::make_pair( + NativeOpSchema( + op, + std::move(comparison_key), + comparison_key_hash, + args_kwargs.num_positional_args()), + std::move(compute_mesh)); +} + +static PyObject* get_DTensor_sharding_propagator_cache_stats( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto& cache = get_thread_local_native_sharding_propagator_cache(); + py::tuple result(2); + result[0] = cache.hits(); + result[1] = cache.misses(); + return result.release().ptr(); + END_HANDLE_TH_ERRORS +} + +static PyObject* clear_DTensor_sharding_propagator_cache( + PyObject* self, + PyObject* noargs) { + native_sharding_propagator_cache_DO_NOT_USE.reset(); + Py_RETURN_NONE; +} + using getter = PyObject* (*)(PyObject*, void*); using setter = int (*)(PyObject*, PyObject*, void*); @@ -2228,7 +3301,7 @@ static PyMethodDef extra_methods[] = { {nullptr}}; // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) -static PyMethodDef extra_functions[] = { +static PyMethodDef extra_dtensor_functions[] = { {"_DTensor_OpSchema_post_init", DTensor_OpSchema_post_init, METH_O, @@ -2241,6 +3314,14 @@ static PyMethodDef extra_functions[] = { castPyCFunctionFast(DTensor_compute_global_tensor_info), METH_FASTCALL, compute_global_tensor_info_doc}, + {"_get_DTensor_sharding_propagator_cache_stats", + get_DTensor_sharding_propagator_cache_stats, + METH_NOARGS, + nullptr}, + {"_clear_DTensor_sharding_propagator_cache", + clear_DTensor_sharding_propagator_cache, + METH_NOARGS, + nullptr}, {nullptr}}; struct THPVariableMeta { @@ -2601,13 +3682,22 @@ bool THPVariable_initModule(PyObject* module) { PyModule_AddObject(module, "TensorBase", (PyObject*)&THPVariableType); Py_INCREF(&THPVariableType); PyModule_AddObject(module, "_TensorBase", (PyObject*)&THPVariableType); +#ifdef USE_DISTRIBUTED + PyModule_AddObject( + module, + "__DTensor_fastpath_cache_cleanup", + py::capsule( + []() { cleanup_thread_local_native_sharding_propagator_caches(); }) + .release() + .ptr()); + if (!intern_dtensor_strings()) { + return false; + } + PyModule_AddFunctions(module, extra_dtensor_functions); +#endif torch::autograd::initTorchFunctions(module); torch::autograd::initTensorImplConversion(module); torch::utils::validate_numpy_for_dlpack_deleter_bug(); - if (!intern_dtensor_strings()) { - return false; - } - PyModule_AddFunctions(module, extra_functions); return true; } diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index af733f2ad1769..5b6f089990693 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -90,6 +90,15 @@ void pushPyOutToStack( py::object out, const char* msg); +py::handle get_dtensor_class(); + +py::object dispatchDTensorOp( + const c10::OperatorHandle& op, + py::handle py_op, + py::handle args, + py::handle kwargs, + torch::jit::Stack* stack); + inline PyObject* THPVariable_WrapList( const torch::autograd::variable_list& inputs) { PyObject* pyinput = PyList_New(static_cast(inputs.size())); diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index d866e4f434448..c3525ac9c8a20 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -679,7 +679,7 @@ Value* emitBuiltinCall( at::ArrayRef args, at::ArrayRef kwargs, const std::optional& self) { - const auto& variants = getAllOperatorsFor(name); + auto variants = getAllOperatorsFor(name); const auto& builtin_functions = getAllBuiltinFunctionsFor(name); // first let's set the graph's version diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index ac99385401be4..513258236ac4b 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -617,7 +617,7 @@ void AliasDb::analyzeImpl(Node* node) { oss << input->type()->str() << ", "; } oss << "\n\nCandidates:"; - const auto& candidates = getAllOperatorsFor(node->kind()); + auto candidates = getAllOperatorsFor(node->kind()); for (const auto& candidate : candidates) { oss << "\n\t" << candidate->schema(); } diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 08bfe47382952..9b00a703e352e 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1088,7 +1088,7 @@ const FunctionSchema* Node::maybeSchema() const { const Operator* Node::maybeOperator() const { if (!op_) { - const auto& candidates = getAllOperatorsFor(kind()); + auto candidates = getAllOperatorsFor(kind()); for (const auto& candidate : candidates) { if (matches(candidate->schema())) { op_ = candidate.get(); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 8dc4cb7ac9349..a7f16a7dc5a04 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1693,7 +1693,7 @@ void initJITBindings(PyObject* module) { [](const std::string& op_name, const std::string& overload_name) { try { auto symbol = Symbol::fromQualString(op_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { return op->schema(); @@ -1714,7 +1714,7 @@ void initJITBindings(PyObject* module) { const std::string& overload_name) -> std::optional { try { auto symbol = Symbol::fromQualString(op_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); bool allow_numbers_as_tensors = opAllowsNumbersAsTensors(symbol); for (const auto& op : operations) { if (op->schema().overload_name() == overload_name) { @@ -2138,7 +2138,7 @@ void initJITBindings(PyObject* module) { m.def("_jit_get_custom_class_schemas", customClassSchemasForBCCheck); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); - const auto& operations = getAllOperatorsFor(symbol); + auto operations = getAllOperatorsFor(symbol); return fmap(operations, [](const std::shared_ptr& op) { return op->schema(); }); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 35dead2a395c9..6f9dec70cddc9 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -53,6 +53,16 @@ struct OperatorRegistry { to_register.clear(); } + const std::vector>& getOperatorsWithLockHeld( + Symbol name) { + registerPendingOperators(); + static std::vector> empty; + auto it = operators.find(name); + if (it != operators.end()) + return it->second; + return empty; + } + public: void registerOperator(Operator&& op) { std::lock_guard guard(lock); @@ -143,14 +153,35 @@ struct OperatorRegistry { return it->second; } - const std::vector>& getOperators(Symbol name) { + // This function returns internal lock-protected state. We need to + // copy it to avoid race conditions. + std::vector> getOperators(Symbol name) { std::lock_guard guard(lock); - registerPendingOperators(); - static std::vector> empty; - auto it = operators.find(name); - if (it != operators.end()) - return it->second; - return empty; + return getOperatorsWithLockHeld(name); + } + + std::vector> getSortedOperators(Symbol name) { + std::lock_guard guard(lock); + const auto& unsortedOps = getOperatorsWithLockHeld(name); + // Depending on the order of registration, aten or jit ops may be + // registered first. This sorting is helpful in cases where + // deterministic (i.e. not dependent on build config) behavior is + // desired; e.g. torch.ops.aten.* uses this function, and tries to + // find the "first" op that matches input args. Without the sorting, + // the "first" op may change depending on registration order. + std::vector> sortedOps; + sortedOps.reserve(unsortedOps.size()); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return op->isC10Op(); }); + std::copy_if( + unsortedOps.begin(), + unsortedOps.end(), + std::back_inserter(sortedOps), + [](const std::shared_ptr& op) { return !op->isC10Op(); }); + return sortedOps; } std::vector findSimilarOperators(Symbol input_op) { @@ -387,35 +418,16 @@ void deregisterOperator(const FunctionSchema& schema) { getRegistry().deregisterOperator(schema); } -const std::vector> getAllOperators() { +std::vector> getAllOperators() { return getRegistry().getAllOperators(); } -const std::vector>& getAllOperatorsFor(Symbol name) { +std::vector> getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } std::vector> getAllSortedOperatorsFor(Symbol name) { - const auto& unsortedOps = getAllOperatorsFor(name); - // Depending on the order of registration, aten or jit ops may be - // registered first. This sorting is helpful in cases where - // deterministic (i.e. not dependent on build config) behavior is - // desired; e.g. torch.ops.aten.* uses this function, and tries to - // find the "first" op that matches input args. Without the sorting, - // the "first" op may change depending on registration order. - std::vector> sortedOps; - sortedOps.reserve(unsortedOps.size()); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return op->isC10Op(); }); - std::copy_if( - unsortedOps.begin(), - unsortedOps.end(), - std::back_inserter(sortedOps), - [](const std::shared_ptr& op) { return !op->isC10Op(); }); - return sortedOps; + return getRegistry().getSortedOperators(name); } std::shared_ptr findOperatorFor(const c10::OperatorName& full_name) { diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index bde3825f5ea38..6b6972deeebf0 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -260,8 +260,9 @@ struct TORCH_API Operator { TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); -TORCH_API const std::vector> getAllOperators(); -TORCH_API const std::vector>& getAllOperatorsFor( +TORCH_API std::vector> getAllOperators(); +// This function returns a copy for thread safety. +TORCH_API std::vector> getAllOperatorsFor( Symbol name); // Returns operators in the order which OpOverloadPacket resolves them. TORCH_API std::vector> getAllSortedOperatorsFor( diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp index 74f87e46757ea..b1f0f410f14fe 100644 --- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp +++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp @@ -79,7 +79,7 @@ auto compilation_unit = std::make_shared(); const std::optional getInplaceVariant( const FunctionSchema& base_schema) { - auto& inplace_variants = + auto inplace_variants = getAllOperatorsFor(c10::Symbol::fromQualString(base_schema.name() + "_")); for (const auto& variant : inplace_variants) { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 0a046523127d5..e89f7887320a0 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -12,6 +13,7 @@ #include #include #include +#include #include #include @@ -301,6 +303,16 @@ static py::object maybe_get_registered_torch_dispatch_rule( return result; } +static bool is_dtensor(PyObject* obj) { +#ifdef USE_DISTRIBUTED + const py::handle dtensor = get_dtensor_class(); + return (PyObject*)Py_TYPE(obj) == dtensor.ptr() || + py::isinstance(py::handle(obj), dtensor); +#else + return false; +#endif +} + // NB: Invariant: if you run this function, you MUST test if the returned // py::object is nullptr, as this will occur WITHOUT error condition being set. // And if an error happens, this function is responsible for throwing a C++ @@ -313,8 +325,8 @@ static py::object dispatch_on_subclass( PyObject* torch_api_function, bool is_torch_function, const char* torch_function_name_str, - std::optional maybe_mode_key = - std::nullopt) { + const c10::OperatorHandle* opt_op, + torch::jit::Stack* opt_stack) { py::object ret; for (auto& arg : overloaded_args) { py::object torch_function = @@ -367,13 +379,39 @@ static py::object dispatch_on_subclass( } } - ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( - torch_function.ptr(), - torch_api_function, - py_types.ptr(), - args, - kwargs, - NULL)); + if (!is_torch_function && is_dtensor(arg)) { + if (opt_op && opt_stack) { + ret = dispatchDTensorOp( + *opt_op, torch_api_function, args, kwargs, opt_stack); + } else { + // Slow path -- reconstruct C++ data structures since they were not + // provided. + auto schema = py::cast( + py::handle(torch_api_function).attr("_schema")); + auto opt_op_handle = + c10::Dispatcher::singleton().findOp(schema.operator_name()); + TORCH_CHECK( + opt_op_handle.has_value(), + "could not look up op for ", + schema.operator_name()); + const auto& op_handle = *opt_op_handle; + auto stack = torch::jit::createStackForSchema( + op_handle.schema(), + py::reinterpret_borrow(args), + py::reinterpret_borrow(kwargs), + std::nullopt); + ret = dispatchDTensorOp( + op_handle, torch_api_function, args, kwargs, &stack); + } + } else { + ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs( + torch_function.ptr(), + torch_api_function, + py_types.ptr(), + args, + kwargs, + NULL)); + } if (ret.ptr() == nullptr) { throw python_error(); } @@ -480,6 +518,28 @@ auto handle_torch_function_no_python_arg_parser( PyObject* torch_api_function, const char* module_name, TorchFunctionName torch_function_name) -> PyObject* { + return handle_torch_function_no_python_arg_parser( + overloaded_args, + args, + kwargs, + func_name, + torch_api_function, + module_name, + nullptr, + nullptr, + torch_function_name); +} + +auto handle_torch_function_no_python_arg_parser( + at::ArrayRef overloaded_args, + PyObject* args, + PyObject* kwargs, + const char* func_name, + PyObject* torch_api_function, + const char* module_name, + const c10::OperatorHandle* opt_op, + torch::jit::Stack* opt_stack, + TorchFunctionName torch_function_name) -> PyObject* { const char* torch_function_name_str = nullptr; switch (torch_function_name) { case TorchFunctionName::TorchFunction: @@ -579,7 +639,9 @@ auto handle_torch_function_no_python_arg_parser( py_types, torch_api_function, is_torch_function, - torch_function_name_str); + torch_function_name_str, + opt_op, + opt_stack); if (curr_ret.ptr() != nullptr) { ret = curr_ret; } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 3ee12f14528e2..4a73a21916776 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -1287,6 +1287,18 @@ auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser( TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) -> PyObject*; +auto handle_torch_function_no_python_arg_parser( + at::ArrayRef overloaded_args, + PyObject* args, + PyObject* kwargs, + const char* func_name, + PyObject* torch_api_function, + const char* module_name, + const c10::OperatorHandle* opt_op, + torch::jit::Stack* opt_stack, + TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction) + -> PyObject*; + // Used for getters of Tensor properties auto handle_torch_function_getter( THPVariable* self, diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 59692d9237b66..819e16ca99698 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -391,7 +391,6 @@ def __init__(self) -> None: # Weak references to the topmost AC module currently active self._ac_mod: Optional[weakref.ref] = None self._orig_resize = torch.UntypedStorage.resize_ - self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch self._depth = 0 def _update_snap( diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index a6b6e39511974..fb072d8dce629 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +import copy import inspect import warnings from collections.abc import Callable, Sequence @@ -96,16 +97,23 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] ) tensor_stride = tuple(tensor_stride) grad_placements = grad_placements or dtensor_spec.placements - grad_spec = DTensorSpec( - mesh, - grad_placements, - tensor_meta=TensorMeta( - shape=dtensor_meta.shape, - stride=tensor_stride, - dtype=dtensor_meta.dtype, - ), - ) - + if ( + tensor_stride == dtensor_meta.stride + and grad_placements == dtensor_spec.placements + ): + # Avoid actual sharing of specs in case they're modified during (e.g.) + # sharding propagation. + grad_spec = copy.copy(dtensor_spec) + else: + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) return ( # pyrefly: ignore [bad-argument-type] DTensor( @@ -338,14 +346,14 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): ) @classmethod - @torch._disable_dynamo - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] - return DTensor._op_dispatcher.dispatch( - func, - args, - kwargs or {}, + # We just need to have an implementation here; the __torch_dispatch__ machinery + # calls into a specific C++ fast path that doesn't call here. + # See #167051 for details + # python_arg_parser.cpp: dispatch_on_subclass() + # -> python_variable.cpp: dispatchDTensorOp() + raise NotImplementedError( + "DTensor.__torch_dispatch__ should not actually get called" ) @staticmethod diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index b883c954de3b6..aaa5d25c79ba7 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -10,9 +10,15 @@ import torch.distributed.tensor._api as dtensor import torch.distributed.tensor._random as random from torch._library.utils import fill_defaults +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OutputSharding, + OutputSpecType, +) from torch.distributed.tensor._random import is_rng_supported_mesh from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor._sharding_prop import ShardingPropagator @@ -125,6 +131,8 @@ class OpDispatcher: def __init__(self) -> None: self.sharding_propagator = ShardingPropagator() + # NOTE: must stay in sync with is_random_op in + # torch/csrc/autograd/python_variable.cpp self._random_ops = { aten.native_dropout.default, aten.normal_.default, @@ -147,6 +155,17 @@ def __init__(self) -> None: aten.as_strided.default: as_strided_handler, } + # ******************************************************************************************** + # def dispatch(...) + # + # NOTE: this class no longer contains the top-level dispatch entrypoint! + # See #167051 for details + # + # The entrypoint has been moved to C++, and it handles common cases and then calls back into + # OpDispatcher python to handle corner cases. + # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp + # ******************************************************************************************** + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave @@ -159,26 +178,39 @@ def _allow_implicit_replication(self) -> bool: def _allow_implicit_replication(self, value: bool) -> None: return torch._C._set_dtensor_allow_implicit_replication(value) - def dispatch( + def _propagate_op_sharding_dispatch_slow_path( self, op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object], + op_info: OpInfo, + # The logic here is a bit messy. There are several reasons why the + # C++ fastpath may have bailed out. If we just cache missed, we will + # come here because we need to actually calculate the real thing. + # There's no need to have a SECOND Python cache lookup; the C++ native + # cache completely subsumes it. But sometimes, we will have failed + # to compute the cache key in C++ entirely. In this case, we DO need + # to do a cache lookup in Python, as the missing cache key in C++ + # means we don't have access to it all. Furthermore, without duping + # this function, we need to do the try_cache test inside of the + # try-except block so that either case hits the inference mode / + # exception rewrapping case. + # + # This should be cleaned up. First, ensuring the C++ codepath can + # always compute a key will be a big help. Second, we should properly + # fastpath inference mode composite implicit autograd so that you + # don't have to throw an exception even in "fastpath". + try_cache: bool, ) -> object: - """ - Main dispatching logic. Follows precedence order: - (1) custom_op_handler - (2) registered sharding strategy, then rule - (3) composite implicit autograd decomposition - """ - if op_call in self._custom_op_handlers: - return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] - - # extract local tensor and sharding infos to a OpInfo - op_info = self.unwrap_to_op_info(op_call, args, kwargs) - try: - self.sharding_propagator.propagate(op_info) + # We have basically inlined propagate() here, but WITHOUT the + # output_sharding assignment + if try_cache and not _are_we_tracing(): + return self.sharding_propagator.propagate_op_sharding(op_info.schema) + else: + return self.sharding_propagator.propagate_op_sharding_non_cached( + op_info.schema + ) except NotImplementedError: if torch._C._dispatch_has_kernel_for_dispatch_key( op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd @@ -195,6 +227,12 @@ def dispatch( f"{e}\n\nSharding propagation failed for {op_info.schema}" ) from e + def _dispatch_get_local_results_slow_path( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + op_info: OpInfo, + ) -> object: output_sharding = op_info.output_sharding assert output_sharding is not None, "output sharding should not be None" @@ -266,7 +304,7 @@ def dispatch( # 2. if the return type is Tensor or List[Tensor], return empty # tensor(s) with correct dtype. spec = output_sharding.output_spec - ret_list = op_info.schema.op._schema.returns + ret_list = op_call._schema.returns if spec is None: # For a scalar return type, the non-participating device has None @@ -301,6 +339,23 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: raise NotImplementedError( f"return type {ret_type} in DTensor op is not supported" ) + return local_results + + def _dispatch_fast_path_python_tail( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + compute_mesh: DeviceMesh, + output_sharding: OutputSharding, + local_results: object, + participating: bool, + is_inplace_op: bool, + is_out_variant_op: bool, + ) -> object: + """ + Tail of main dispatching logic, called from C++ fast path. + """ if output_sharding.output_spec is None: if op_call == aten.equal.default: @@ -310,12 +365,12 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: assert local_results is None or isinstance(local_results, bool) r = torch.tensor( int(local_results) if local_results is not None else 1, - device=mesh.device_type, + device=compute_mesh.device_type, ) dist.all_reduce(r, op=dist.ReduceOp.MIN) local_results = bool(r.item()) - if op_info.schema.is_inplace_op(): + if is_inplace_op: # inplace op should return self instead of re-wrapping if output_sharding.output_spec is not None: output_spec = output_sharding.output_spec @@ -349,7 +404,7 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: return args[0] else: return None - elif op_info.schema.is_out_variant_op(): + elif is_out_variant_op: # out variant could possibly have multiple out args (i.e. lu_unpack.out) output_specs = ( (output_sharding.output_spec,) @@ -368,8 +423,9 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor: assert len(out_dts) >= 1, "out variant should have at least one out arg" return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] else: + assert op_call == aten.equal.default, op_call ret = self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] - if participating and op_info.schema.is_view_op(): + if participating and op_call._schema._is_view_op(): return return_and_correct_aliasing(op_call, args, kwargs, ret) else: return ret @@ -436,6 +492,15 @@ def unwrap_to_op_info( op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object], + ) -> OpInfo: + return self._unwrap_to_op_info_impl(op_call, args, kwargs, True) + + def _unwrap_to_op_info_impl( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + create_schema: bool, ) -> OpInfo: # get runtime schema info to determine whether to use pytree to flatten inputs runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( @@ -512,7 +577,9 @@ def unwrap_to_op_info( ), kwargs_schema, schema_info=runtime_schema_info, - ), + ) + if create_schema + else None, # type: ignore[arg-type] args_schema, tuple(local_args), local_kwargs, diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index ede7515efd102..f3dc04ef10f97 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -345,6 +345,10 @@ def spec_to_strategy(spec: object) -> object: ) def propagate(self, op_info: OpInfo) -> None: + # NB: The logic here is duplicated in _propagate_op_sharding_dispatch_slow_path. + # Ideally, this function would be deleted, but there are a handful of + # one off call sites here that aren't cleaned up. + # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, diff --git a/torch/distributed/tensor/debug/__init__.py b/torch/distributed/tensor/debug/__init__.py index a74f1449ad125..e6aeca3b93a12 100644 --- a/torch/distributed/tensor/debug/__init__.py +++ b/torch/distributed/tensor/debug/__init__.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import torch._C from torch.distributed.tensor.debug._comm_mode import CommDebugMode from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding @@ -6,11 +7,12 @@ __all__ = ["CommDebugMode", "visualize_sharding"] -def _get_sharding_prop_cache_info(): +def _get_python_sharding_prop_cache_info(): """ - Get the cache info for the sharding propagation cache, used for debugging purpose only. + Get the cache info for the Python sharding propagation cache, used for debugging purpose only. This would return a named tuple showing hits, misses, maxsize and cursize of the sharding - propagator cache. + propagator cache. Note that directly calling into the sharding propagator does not share cache + state with the DTensor dispatch fast path! """ from torch.distributed.tensor._api import DTensor @@ -19,9 +21,17 @@ def _get_sharding_prop_cache_info(): ) -def _clear_sharding_prop_cache(): +def _get_fast_path_sharding_prop_cache_stats(): """ - Clears the cache for the sharding propagation cache, used for debugging purpose only. + Get a tuple (hits, misses) for the fast path sharding propagation cache, used for debugging + only. + """ + return torch._C._get_DTensor_sharding_propagator_cache_stats() + + +def _clear_python_sharding_prop_cache(): + """ + Clears the cache for the Python sharding propagation cache, used for debugging purpose only. """ from torch.distributed.tensor._api import DTensor @@ -30,6 +40,13 @@ def _clear_sharding_prop_cache(): ) +def _clear_fast_path_sharding_prop_cache(): + """ + Clears the cache for the fast path sharding propagation cache, used for debugging purpose only. + """ + torch._C._clear_DTensor_sharding_propagator_cache() + + # Set namespace for exposed private names CommDebugMode.__module__ = "torch.distributed.tensor.debug" visualize_sharding.__module__ = "torch.distributed.tensor.debug" From 5d34e5eb803d591720e675fe0f3fcf2d4d4594a0 Mon Sep 17 00:00:00 2001 From: Dev Sashidhar Date: Fri, 21 Nov 2025 15:12:49 +0000 Subject: [PATCH 162/230] Fix unused gradient tracking to respect create_graph (#168295) Fixes https://github.com/pytorch/pytorch/issues/168059 PyTorch was incorrectly setting requires_grad=True on unused gradients even when create_graph=False. This caused unnecessary autograd tracking and extra memory usage. This PR updates the logic to set requires_grad based on the value of create_graph, ensuring that unused gradients are tracked only when explicitly requested. This aligns torch.autograd.grad( ) behavior with expectations. Includes test coverage in test_unused_grad_requires_grad_with_materialize Pull Request resolved: https://github.com/pytorch/pytorch/pull/168295 Approved by: https://github.com/soulitzer --- test/test_autograd.py | 12 ++++++++++++ torch/autograd/__init__.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 5960ac8add36d..bc6967cdfb038 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -403,6 +403,18 @@ def backward(ctx, g0, g1): out = Func.apply(a)[0] out.backward() + def test_unused_grad_requires_grad_with_materialize(self): + x = torch.ones(10, requires_grad=True) + y = torch.ones(10, requires_grad=True) + z = (x**2).sum() + + g = torch.autograd.grad( + z, (x, y), allow_unused=True, materialize_grads=True, create_graph=False + ) + + self.assertFalse(g[0].requires_grad) + self.assertFalse(g[1].requires_grad) + def test_legacy_function_deprecation_exception(self): # Trigger exception class MyFunction(Function): diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 5c8d2664ed7db..cfab4fa5e2d5f 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -532,7 +532,7 @@ def vjp(gO): result = tuple( output if output is not None - else torch.zeros_like(input, requires_grad=True) + else torch.zeros_like(input, requires_grad=create_graph) for (output, input) in zip(result, inputs) ) return result From a69d3cf1ba36131fc53a3653b8014ce2c11b0ff0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 21 Nov 2025 06:08:00 -0800 Subject: [PATCH 163/230] [BE] C++20 template instantiation adjustments (#168132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolve some dependent name disambiguation issues, which results in compilation errors when compiled with MSVC using C++20 standard. Failure look something like the following ``` C:\actions-runner\_work\pytorch\pytorch\aten\src\ATen\native\cuda\Nonzero.cu(215): error: type name is not allowed (void*)pinned_num_nonzeros_h.const_data_ptr(), ^ C:\actions-runner\_work\pytorch\pytorch\aten\src\ATen\native\cuda\Nonzero.cu(215): error: expected an expression (void*)pinned_num_nonzeros_h.const_data_ptr(), ^ ``` And resolved by adding a `template` in front of it The only C++20 specific change in the [standard](https://en.cppreference.com/w/cpp/language/dependent_name.html) that I can spot are the following: > The following expressions are type-dependent : > - an expression whose any subexpression is a type-dependent expression this, if the class is a dependent type. > - an [identifier expression](https://en.cppreference.com/w/cpp/language/name.html) that is not a [concept-id](https://en.cppreference.com/w/cpp/language/constraints.html) and(since C++20) > - contains an identifier for which name lookup finds at least one dependent declaration > - contains a dependent [template-id](https://en.cppreference.com/w/cpp/language/templates.html#template-id) Also replace `*((int*)ptr + offs)` with `static_cast(ptr)[offs]` Pull Request resolved: https://github.com/pytorch/pytorch/pull/168132 Approved by: https://github.com/ngimel, https://github.com/Skylion007, https://github.com/cyyever ghstack dependencies: #168165 --- aten/src/ATen/native/cuda/LossCTC.cu | 30 +++++++++---------- aten/src/ATen/native/cuda/Nonzero.cu | 7 ++--- .../ATen/native/cuda/UpSampleBilinear2d.cu | 2 +- .../src/ATen/native/cuda/UpSampleNearest3d.cu | 2 +- .../nested/cuda/NestedTensorBinaryOps.cu | 2 +- aten/src/ATen/native/sparse/cuda/SoftMax.cu | 14 ++++----- .../native/sparse/cuda/SparseCsrTensorMath.cu | 2 +- 7 files changed, 29 insertions(+), 30 deletions(-) diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index c6d3c25200d50..4c5eabd049687 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -242,7 +242,7 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const int64_t max_target_length = 0; auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong)); - auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr(); + auto tg_batch_offsets_data = tg_batch_offsets.template mutable_data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; for (int64_t i = 0; i < batch_size; i++) { @@ -304,12 +304,12 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const ctc_loss_log_alpha_gpu_kernel<<>>( log_alpha.mutable_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), max_target_length, neg_log_likelihood.mutable_data_ptr(), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, BLANK); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(neg_log_likelihood, log_alpha); @@ -613,7 +613,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ int64_t max_target_length; auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong))); - auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr(); + auto tg_batch_offsets_data = tg_batch_offsets.template mutable_data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; max_target_length = 0; @@ -663,11 +663,11 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ dim3 grid(1, (batch_size+threads_batch-1)/threads_batch); ctc_loss_backward_log_beta_gpu_kernel<<>> (log_beta.mutable_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), max_target_length, log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, BLANK); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -717,14 +717,14 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (grad.mutable_data_ptr(), grad_out.const_data_ptr(), grad_out.stride(0), log_alpha.const_data_ptr(), log_beta.const_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), neg_log_likelihood.const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // small problem, use naive algorithm @@ -740,14 +740,14 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (grad.mutable_data_ptr(), grad_out.const_data_ptr(), grad_out.stride(0), log_alpha.const_data_ptr(), log_beta.const_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), max_target_length, neg_log_likelihood.const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, num_labels, BLANK, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors } @@ -765,7 +765,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (batch_size+threads_batch-1)/threads_batch); ctc_loss_zero_padded_gradients<<>>( grad.mutable_data_ptr(), - input_lengths_t.const_data_ptr(), + input_lengths_t.template const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index ed32e9ac45b30..d4eb1b792e7f1 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -212,7 +212,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { std::nullopt /* memory format */ ); at::cuda::memcpy_and_sync( - pinned_num_nonzeros_h.data_ptr(), + pinned_num_nonzeros_h.template data_ptr(), num_nonzeros.get(), sizeof(int) * num_chunks, cudaMemcpyDeviceToHost, @@ -220,7 +220,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { int64_t num_nonzeros_h = 0; for (int64_t idx = 0; idx < num_chunks; idx++) { - num_nonzeros_h += (int)*(pinned_num_nonzeros_h.const_data_ptr() + idx); + num_nonzeros_h += pinned_num_nonzeros_h.template const_data_ptr()[idx]; } // num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr()); // expected output size is num_nonzeros x ndim @@ -267,8 +267,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { ((int*)num_nonzeros.get()) + idx, remaining, stream)); - curr_nonzeros += - (int)*(pinned_num_nonzeros_h.const_data_ptr() + idx); + curr_nonzeros += pinned_num_nonzeros_h.template const_data_ptr()[idx]; } if (num_nonzeros_h > 0 && self.dim() > 1) { TensorDims dims; diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index b46bbaa6500b9..5ccc1143d4dc1 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -806,7 +806,7 @@ static void upsample_gen2d_aa_out_cuda_template( using accscalar_t = at::acc_type; auto idata = input.packed_accessor64(); - auto odata = output_c.packed_accessor64(); + auto odata = output_c.template packed_accessor64(); const accscalar_t height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales_h); diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index aae4625f6a39e..159de54156dd6 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -189,7 +189,7 @@ static void upsample_nearest3d_out_cuda_template( using accscalar_t = at::acc_type; auto idata = input.const_data_ptr(); - auto odata = output_c.mutable_data_ptr(); + auto odata = output_c.template mutable_data_ptr(); const float depth_scale = compute_scales_value(scales_d, input_depth, output_depth); const float height_scale = compute_scales_value(scales_h, input_height, output_height); diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu index e624295642422..203dafdccfcc6 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu @@ -88,7 +88,7 @@ void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Ten const scalar_t* self_data_ptr = self_buffer.const_data_ptr(); const scalar_t* other_data_ptr = other.const_data_ptr(); scalar_t* result_data_ptr = result_buffer.data_ptr(); - int64_t* result_offsets_ptr = result_offsets.data_ptr(); + int64_t* result_offsets_ptr = result_offsets.template data_ptr(); nested_op_dense_kernelLauncher( self_data_ptr, diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu index 7e3b502bf6f41..2ee8de3fd5edf 100644 --- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu +++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu @@ -307,7 +307,7 @@ std::tuple compute_pool_max( int64_t* offsets_ptr = offsets.data_ptr(); auto sorted_indices = at::empty({nnz}, indices.options()); - thrust_ptr sorted_indices_thrust_ptr(sorted_indices.data_ptr()); + thrust_ptr sorted_indices_thrust_ptr(sorted_indices.template data_ptr()); thrust::sequence( policy, sorted_indices_thrust_ptr, sorted_indices_thrust_ptr + nnz, 0); @@ -326,17 +326,17 @@ std::tuple compute_pool_max( sorted_indices_thrust_ptr + nnz, thrust::make_constant_iterator(int64_t(1)), thrust::make_discard_iterator(), - thrust_ptr(pool_sizes.data_ptr()), + thrust_ptr(pool_sizes.template data_ptr()), [offsets_ptr] __device__(int64_t x, int64_t y) { return offsets_ptr[x] == offsets_ptr[y]; }); auto new_sz = thrust::distance( - thrust_ptr(pool_sizes.data_ptr()), new_end.second); + thrust_ptr(pool_sizes.template data_ptr()), new_end.second); pool_sizes.resize_({new_sz}); auto pool_offsets = pool_sizes.clone(); thrust_ptr pool_offsets_thrust_ptr( - pool_offsets.data_ptr()); + pool_offsets.template data_ptr()); thrust::exclusive_scan( policy, pool_offsets_thrust_ptr, @@ -353,9 +353,9 @@ std::tuple compute_pool_max( auto mx_buffer_ptr = mx_buffer.data_ptr(); - auto pool_sizes_ptr = pool_sizes.data_ptr(); - auto sorted_indices_ptr = sorted_indices.data_ptr(); - auto pool_offsets_ptr = pool_offsets.data_ptr(); + auto pool_sizes_ptr = pool_sizes.template data_ptr(); + auto sorted_indices_ptr = sorted_indices.template data_ptr(); + auto pool_offsets_ptr = pool_offsets.template data_ptr(); thrust::for_each( policy, diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu index f8923dd1a61c1..bb4b095d7f12e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu @@ -500,7 +500,7 @@ Tensor reduce_sparse_csr_dim0_cuda_template(const Tensor& sparse, ReductionOp ro AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "reduce_sparse_csr_dim0_cuda_indices", [&]() { index_t* col_indices_ptr = col_indices.data_ptr(); - index_t* new_col_indices_ptr = new_col_indices.data_ptr(); + index_t* new_col_indices_ptr = new_col_indices.template data_ptr(); reduce_sparse_csr_dim0_cuda_kernel<<>>(new_values_acc_ptr, new_col_indices_ptr, new_nnz, From f6fb8dd0583e53e1e2a2ce4bfb4a6f706d68f0ce Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Fri, 21 Nov 2025 15:29:57 +0000 Subject: [PATCH 164/230] Use r7i.4xlarge for B200 build (#167078) The build system is oversized for what is necessary. Reduce the size to optimize costs. The default workflow runner is `linux.r7i.4xlarge` so we are just removing the runner definition in the workflow so that it uses the default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167078 Approved by: https://github.com/nWEIdia, https://github.com/seemethere --- .github/workflows/b200-distributed.yml | 2 +- .github/workflows/b200-symm-mem.yml | 2 +- .github/workflows/operator_microbenchmark.yml | 15 ++++++++++++++- .github/workflows/test-b200.yml | 4 ++-- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/.github/workflows/b200-distributed.yml b/.github/workflows/b200-distributed.yml index 596a31431e61b..bb85a4ddfc85e 100644 --- a/.github/workflows/b200-distributed.yml +++ b/.github/workflows/b200-distributed.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' diff --git a/.github/workflows/b200-symm-mem.yml b/.github/workflows/b200-symm-mem.yml index 7fa8a8a730447..ba28066dd5602 100644 --- a/.github/workflows/b200-symm-mem.yml +++ b/.github/workflows/b200-symm-mem.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index dd5cd832570f9..cd27b3a8a97db 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -18,11 +18,22 @@ permissions: contents: read jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + # H100 A100 runners opmicrobenchmark-build: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 @@ -51,8 +62,10 @@ jobs: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build-b200 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: - runner: linux.12xlarge.memory + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml index 07fd9b18fdada..7cc935f46d6c8 100644 --- a/.github/workflows/test-b200.yml +++ b/.github/workflows/test-b200.yml @@ -54,7 +54,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' @@ -75,4 +75,4 @@ jobs: docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only - secrets: inherit \ No newline at end of file + secrets: inherit From 008ac433b06c4177e6b6b6d2a63fc4aebbc1fe74 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Fri, 21 Nov 2025 06:16:57 +0000 Subject: [PATCH 165/230] [Inductor XPU GEMM] Step 1/N: Refactor cutlass configuration. (#160174) This PR is the first step toward implementing RFC #160175. Currently, all Cutlass-related Torch Inductor configs are located in `torch._inductor.config.cuda`. This PR refactors the device-agnostic Cutlass configurations into `torch._inductor.config.cutlass`, so they can be shared and reused by XPU as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160174 Approved by: https://github.com/EikanWang, https://github.com/mlazos, https://github.com/jansel --- benchmarks/inductor_backends/cutlass.py | 2 +- test/inductor/test_cutlass_backend.py | 110 +++++++++--------- torch/_inductor/codecache.py | 30 ++--- .../codegen/cuda/cuda_cpp_scheduling.py | 2 +- torch/_inductor/codegen/cuda/cuda_template.py | 2 +- torch/_inductor/codegen/cuda/cutlass_cache.py | 2 +- torch/_inductor/codegen/cuda/cutlass_utils.py | 4 +- torch/_inductor/codegen/cuda/gemm_template.py | 24 ++-- torch/_inductor/config.py | 73 +++++++----- torch/_inductor/fuzzer.py | 2 +- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/utils.py | 8 +- 12 files changed, 141 insertions(+), 122 deletions(-) diff --git a/benchmarks/inductor_backends/cutlass.py b/benchmarks/inductor_backends/cutlass.py index b2ed506302aec..af06333038947 100644 --- a/benchmarks/inductor_backends/cutlass.py +++ b/benchmarks/inductor_backends/cutlass.py @@ -125,7 +125,7 @@ def name(self) -> str: def to_options(self) -> dict[str, Any]: return { **super().to_options(), - "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level, + "cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level, } diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 55f8dd5d24ebc..673d3e87d2a5f 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -133,10 +133,10 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, - "cuda.cutlass_epilogue_fusion_enabled": True, + "cutlass.cutlass_tma_only": True, + "cutlass.cutlass_epilogue_fusion_enabled": True, } ) @@ -144,9 +144,9 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, + "cutlass.cutlass_tma_only": True, } ) @@ -234,8 +234,8 @@ def mm(a, b): "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_backend_min_gemm_size": 100000, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_backend_min_gemm_size": 100000, + "cutlass.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -287,7 +287,7 @@ def test_cutlass_backend_subproc_mm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.mm)(a, b) @@ -324,7 +324,7 @@ def test_cutlass_backend_subproc_addmm(self, dtype): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): for x_shape in x_shapes: @@ -354,7 +354,7 @@ def test_cutlass_backend_subproc_bmm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.bmm)(a, b) @@ -386,7 +386,7 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): from torch._inductor.utils import run_and_get_code @@ -428,8 +428,8 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 1, - "cuda.cutlass_max_profiling_swizzle_options": [ + "cutlass.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_swizzle_options": [ 1, 2, 4, @@ -505,7 +505,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -595,9 +595,9 @@ def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cuda.cutlass_tma_only": True, + "cutlass.cutlass_tma_only": True, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -677,7 +677,7 @@ def forward(self, x, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -746,7 +746,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): expected = [model(*input) for input in inputs] @@ -775,8 +775,8 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): for M, K, N in ( @@ -819,7 +819,7 @@ def test_streamk_with_dynamic( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): @@ -849,8 +849,8 @@ def test_streamk_with_static( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, - "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cutlass.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): _ = compiled_model(a, b) @@ -884,7 +884,7 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 4, + "cutlass.cutlass_max_profiling_configs": 4, "cuda.version": "12.2", # required to enable the Kernels we need } ): @@ -983,7 +983,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) @@ -1002,7 +1002,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1040,7 +1040,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1073,8 +1073,8 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", + "cutlass.cutlass_max_profiling_configs": 1, } ): model = MyModel() @@ -1117,7 +1117,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "autotune_local_cache": True, } ): @@ -1157,9 +1157,9 @@ def my_addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "", - "cuda.cutlass_op_denylist_regex": "pingpong", + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "", + "cutlass.cutlass_op_denylist_regex": "pingpong", } ): with mock.patch( @@ -1202,9 +1202,9 @@ def addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, - "cuda.cutlass_op_allowlist_regex": "pingpong", - "cuda.cutlass_op_denylist_regex": None, + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_op_allowlist_regex": "pingpong", + "cutlass.cutlass_op_denylist_regex": None, } ): with mock.patch( @@ -1273,7 +1273,7 @@ def run_test(use_fast_accum): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -1350,7 +1350,7 @@ def test_cutlass_backend_shape_coverage_mm( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ), mock.patch( @@ -1461,8 +1461,8 @@ def test_standalone_runner(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, - "cuda.generate_test_runner": True, # put standalone runner in the generated code + "cutlass.cutlass_max_profiling_configs": 2, + "cutlass.generate_test_runner": True, # put standalone runner in the generated code } ): from tempfile import NamedTemporaryFile @@ -1544,7 +1544,7 @@ def mm(a, b): { "max_autotune": True, "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, # needed for log searching "fx_graph_cache": False, "fx_graph_remote_cache": False, @@ -1608,8 +1608,8 @@ def counting_render(self, *args, **kwargs): "max_autotune_gemm_backends": "CUTLASS", "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.enable_caching_codegen": True, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled_model = torch.compile(model, fullgraph=True) @@ -1660,10 +1660,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, + "cutlass.enable_caching_codegen": True, } ): # Get expected results @@ -1721,10 +1721,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cuda.enable_caching_codegen": True, + "cutlass.enable_caching_codegen": True, } ): # Get expected results @@ -1752,7 +1752,7 @@ def test_cutlass_backend_matmul_same_tensor(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1771,7 +1771,7 @@ def test_cutlass_backend_matmul_nonzero_offset(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cuda.cutlass_max_profiling_configs": 2, + "cutlass.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1795,7 +1795,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1817,7 +1817,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1845,7 +1845,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1871,7 +1871,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cuda.cutlass_max_profiling_configs": 1, + "cutlass.cutlass_max_profiling_configs": 1, } ): if use_aoti: @@ -1968,7 +1968,7 @@ def forward(self, a, b, extra_args): # baseline is cutlass kernel + triton # matches expected casting behavior - with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}): + with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}): ref_result = torch.compile(model)(a, b, extra_args) self.assertEqual( @@ -2368,7 +2368,7 @@ def test_config_number_post_filtering(self) -> None: "max_autotune_gemm_backends": "CUTLASS", # needed for log searching "force_disable_caches": True, - "cuda.cutlass_max_profiling_swizzle_options": [2], + "cutlass.cutlass_max_profiling_swizzle_options": [2], } ): with mock.patch( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index a30644312332b..2542d5ecefd3f 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -34,7 +34,7 @@ from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType -from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import override, Self import torch @@ -3741,7 +3741,7 @@ def _load_triton_kernel_from_source( return getattr(PyCodeCache.load(source_code), kernel_name) -def _cuda_compiler() -> str | None: +def _cuda_compiler() -> Optional[str]: if cuda_env.nvcc_exist(config.cuda.cuda_cxx): return config.cuda.cuda_cxx if config.is_fbcode(): @@ -3759,7 +3759,7 @@ def _cutlass_path() -> str: return parutil.get_dir_path("cutlass-4-headers") else: - return config.cuda.cutlass_dir + return config.cutlass.cutlass_dir def _cutlass_paths() -> list[str]: @@ -3807,7 +3807,7 @@ def cutlass_key() -> bytes: return resource_file.read().encode() combined_hash = hashlib.sha256() - build_code_hash([config.cuda.cutlass_dir], "", combined_hash) + build_code_hash([config.cutlass.cutlass_dir], "", combined_hash) return combined_hash.digest() @@ -3877,14 +3877,14 @@ def _nvcc_compiler_options() -> list[str]: "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", - config.cuda.compile_opt_level, + config.cutlass.compile_opt_level, "-std=c++17", "--expt-relaxed-constexpr", "-DNDEBUG", ] if config.is_fbcode(): options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) - if config.cuda.enable_debug_info: + if config.cutlass.enable_debug_info: options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) if config.cuda.enable_ptxas_info: options.extend( @@ -3896,7 +3896,7 @@ def _nvcc_compiler_options() -> list[str]: "--source-in-ptx", ] ) # Annotate the ptx file with source information - if config.cuda.use_fast_math: + if config.cutlass.use_fast_math: options.extend( [ "--use_fast_math", @@ -4100,7 +4100,7 @@ def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: Returns the hash key of source code, and the path to the file. """ - if config.cuda.cutlass_hash_with_compile_cmd: + if config.cutlass.cutlass_hash_with_compile_cmd: cuda_command = repr( cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) ) @@ -4151,7 +4151,7 @@ def compile( output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext error_path = binary_error_path(output_path) binary_remote_cache = cls.get_kernel_binary_remote_cache( - caching_enabled=config.cuda.use_binary_remote_cache + caching_enabled=config.cutlass.use_binary_remote_cache and not config.force_disable_caches, caching_available=config.is_fbcode(), ) @@ -4166,13 +4166,13 @@ def compile( cmd_parts, error_output = json.loads(error_json) if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): # This ensures that a local error is uploaded to the remote cache, # as we make no assumptions about the remote cache having the same # information as the local cache binary_remote_cache.put( - error_path, config.cuda.binary_remote_cache_force_write + error_path, config.cutlass.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, error_json @@ -4236,11 +4236,11 @@ def compile( # Upload to remote cache if enabled if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): # will log on errors, but not fail out binary_remote_cache.put( - output_path, config.cuda.binary_remote_cache_force_write + output_path, config.cutlass.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, None @@ -4293,10 +4293,10 @@ def _record_cuda_compile_error( # Upload to remote cache directly from memory if enabled if ( binary_remote_cache is not None - and config.cuda.upload_to_binary_remote_cache + and config.cutlass.upload_to_binary_remote_cache ): binary_remote_cache.put( - error_path, config.cuda.binary_remote_cache_force_write + error_path, config.cutlass.binary_remote_cache_force_write ) diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 2496860ca1f7c..16b09d4ba80eb 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -257,7 +257,7 @@ def _can_fuse_epilogue_impl( ) return False elif ( - not config.cuda.cutlass_epilogue_fusion_enabled + not config.cutlass.cutlass_epilogue_fusion_enabled or not config.epilogue_fusion ): why("cutlass epilogue fusion is not enabled") diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 79dfa9c6c391f..92c86120570d6 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -110,7 +110,7 @@ def generate_code_and_args( args are different. """ key: Optional[str] = None - if config.cuda.enable_caching_codegen: + if config.cutlass.enable_caching_codegen: key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py index 66db98867b413..cad4a37902304 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -75,7 +75,7 @@ def maybe_fetch_ops() -> Optional[list[Any]]: # get_cuda_version might return "12.4.0" or "12.4" # but we want to use "12.4" version: str = ".".join(get_cuda_version().split(".")[:2]) - instantiation_level: str = config.cuda.cutlass_instantiation_level + instantiation_level: str = config.cutlass.cutlass_instantiation_level # filename and filepath request_key: str = get_config_request_key(arch, version, instantiation_level) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index fa46e8766cd58..3ce3a49bb94e9 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -98,7 +98,7 @@ def path_join(path0, path1): # contains both cutlass and cutlass_library # we need cutlass for eVT - cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python") torch_root = os.path.abspath(os.path.dirname(torch.__file__)) mock_src_path = os.path.join( torch_root, @@ -252,7 +252,7 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]: ) return {} arch = _normalize_cuda_arch(arch) - instantiation_level: str = config.cuda.cutlass_instantiation_level + instantiation_level: str = config.cutlass.cutlass_instantiation_level args = CUTLASSArgs( architectures=arch, cuda_version=version, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index c4b7188bd9e62..9148ee7877d03 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -19,7 +19,7 @@ from torch._inductor.utils import clear_on_fresh_cache from ... import ir -from ...config import cuda as inductor_cuda_config +from ...config import cutlass as inductor_cutlass_config from ...ir import ( Buffer, ChoiceCaller, @@ -578,7 +578,7 @@ def _add_cutlass_gemm_choices( for name, op in ops: for ( swizzle - ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + ) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options: description = f"{name} swizzle={swizzle}" self.maybe_append_choice( choices, @@ -635,7 +635,7 @@ def header(self) -> IndentedBuffer: #include "cutlass/util/tensor_view_io.h" """ ) - if inductor_cuda_config.generate_test_runner and not is_dynamic( + if inductor_cutlass_config.generate_test_runner and not is_dynamic( *self.input_nodes, self.output_node ): res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) @@ -953,7 +953,7 @@ def filter_op( ) return None - if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op): return None # Set epilogue. @@ -975,14 +975,16 @@ def filter_op( return None # Apply regex filters at the end when configuration name doesn't change anymore - if inductor_cuda_config.cutlass_op_allowlist_regex: + if inductor_cutlass_config.cutlass_op_allowlist_regex: if not re.search( - inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() + inductor_cutlass_config.cutlass_op_allowlist_regex, + op.configuration_name(), ): return None - if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if inductor_cutlass_config.cutlass_op_denylist_regex is not None: if re.search( - inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + inductor_cutlass_config.cutlass_op_denylist_regex, + op.configuration_name(), ): return None @@ -1035,7 +1037,7 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: time.time() - start_time, ) sorted_res = sorted(res.items()) - ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs] if len(self.filtered_ops_cache) < 50: self.filtered_ops_cache[self.cache_key] = ret_res else: @@ -1277,7 +1279,9 @@ def render( # type: ignore[override] } options.update(dict(zip(extra_names, extra_inputs))) res = self._template_from_string(self._get_template()).render(**options) - if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + if inductor_cutlass_config.generate_test_runner and not is_dynamic( + X, W, Y, Bias + ): test_runner_code = self._template_from_string( GEMM_STANDALONE_RUNNER_TEMPLATE ).render(**options) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index af466dc61031a..f3592b93469cd 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1822,28 +1822,13 @@ class aot_inductor_mode: compile_standalone: bool = False -class cuda: - """Settings for cuda backend, today this consists of cutlass""" - - # CUDA arch to use for CUDA template kernel compilation. - # e.g. "70", "75", "80", "90", etc. - # When arch is None, Inductor uses torch.cuda.get_device_capability(0). - arch: Optional[str] = None - - # CUDA version to use for CUDA template kernel compilation. - # e.g. "11.4", "12.1", etc. - # When version is None, Inductor uses torch.version.cuda. - version: Optional[str] = None +class cutlass: + """ + Config specific to cutlass backend. + """ - # Optimization level for the host compiler. compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1" - # Whether to enable device LTO (link-time-optimization). - enable_cuda_lto = False - - # Whether to keep intermediate files dring compilation. - enable_ptxas_info = False - # Whether to enable debug info, e.g. line number, cutlass debug info. enable_debug_info = False @@ -1855,7 +1840,10 @@ class cuda: cutlass_dir = os.path.realpath( os.environ.get( "TORCHINDUCTOR_CUTLASS_DIR", - os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), + os.path.join( + os.path.dirname(torch.__file__), + "../third_party/cutlass/", + ), ) ) @@ -1875,14 +1863,6 @@ class cuda: # Whether to only use TMA-compatible kernels in CUTLASS cutlass_tma_only = False - # Path to CUDA NVCC. - # NVCC search order: - # 1) cuda_cxx set in this config - # 2) CUDACXX environment variable - # 3) CUDA_HOME environment variable - # 4) default system search PATH. - cuda_cxx: Optional[str] = None - # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. cutlass_backend_min_gemm_size: int = 1 @@ -1952,6 +1932,41 @@ class cuda: enable_caching_codegen: bool = True +class cuda(cutlass): + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + +class xpu(cutlass): + # Xe arch to use for SYCL template kernel compilation. + # eg. 12, 20, which corresponding to Xe12(PVC) and Xe20 (BMG) + arch: Optional[str] = None + # oneAPI version to use for SYCL template kernel compilation. + # e.g. "20250201". + version: Optional[str] = None + + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. # If empty, the `native` arch is used @@ -2160,7 +2175,7 @@ class trace: # trace functions are not relevant to config caching "trace", # uses absolute path - "cuda.cutlass_dir", + "cutlass.cutlass_dir", # not relevant "worker_start_method", "compile_threads", diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 152dce2026766..2d288e683be5a 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -480,7 +480,7 @@ def keys(self) -> KeysView[ComboType]: "aot_inductor.presets": DEFAULT, # Typing "cuda.arch": DEFAULT, # Out of Scope "cuda.version": DEFAULT, # Out of Scope - "cuda.cutlass_dir": DEFAULT, # Out of Scope + "cutlass.cutlass_dir": DEFAULT, # Out of Scope "cuda.cuda_cxx": DEFAULT, # Out of Scope "rocm.arch": DEFAULT, # Out of Scope "rocm.ck_supported_arch": DEFAULT, # Out of Scope diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 28cdfbf0cc7ea..625f35ba36c06 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3570,8 +3570,8 @@ def prescreen_choices( candidates = [] if ( - config.cuda.cutlass_prescreening - and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1 + config.cutlass.cutlass_prescreening + and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1 ): candidates.extend( [ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f029a2e73f038..59db1aeb12325 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2023,7 +2023,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) - if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size: return False from .codegen.cuda.cutlass_utils import try_import_cutlass @@ -2044,9 +2044,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " - "_inductor.config.cuda.cutlass_dir %s is set correctly. " + "_inductor.config.cutlass.cutlass_dir %s is set correctly. " "Skipping CUTLASS backend for now.", - config.cuda.cutlass_dir, + config.cutlass.cutlass_dir, ) return False return res @@ -2054,7 +2054,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: def _use_cutlass_for_op(op_name: str) -> bool: """Check if CUTLASS should be used for the given operation.""" - enabled_ops = config.cuda.cutlass_enabled_ops.upper() + enabled_ops = config.cutlass.cutlass_enabled_ops.upper() if enabled_ops == "ALL": return True return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] From 7556637e289d00f8aec58252c3b5b45dcfd6eb61 Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Fri, 21 Nov 2025 06:16:57 +0000 Subject: [PATCH 166/230] [Inductor XPU GEMM] Step 2/N: Move out cutlass files from torch/_inductor/codegen/cuda (#160685) This PR is the second step toward implementing RFC https://github.com/pytorch/pytorch/issues/160175. Currently, all Cutlass-related code are located in `torch/_inductor/codege/cuda` This PR moves those file to `torch/_inductor/codegen/cutlass` so they can be shared and reused by XPU as well. Signed-off-by: xinan.lin Pull Request resolved: https://github.com/pytorch/pytorch/pull/160685 Approved by: https://github.com/EikanWang, https://github.com/mlazos ghstack dependencies: #160174 --- test/inductor/test_cutlass_backend.py | 27 +++++++++---------- test/inductor/test_cutlass_evt.py | 14 +++++----- test/test_public_bindings.py | 4 +-- torch/_inductor/codegen/common.py | 2 +- .../codegen/cuda/cuda_cpp_scheduling.py | 6 ++--- .../__init__.py | 0 .../cutlass_cache.py => cutlass/cache.py} | 8 +++--- .../codegen/{cuda => cutlass}/cuda_kernel.py | 6 ++--- .../{cuda => cutlass}/cuda_template.py | 2 +- .../{cuda => cutlass}/gemm_template.py | 14 +++++----- .../lib_extensions}/__init__.py | 0 .../cutlass_mock_imports/__init__.py | 0 .../cutlass_mock_imports/cuda/__init__.py | 0 .../cutlass_mock_imports/cuda/cuda.py | 0 .../cutlass_mock_imports/cuda/cudart.py | 0 .../cutlass_mock_imports/pydot/__init__.py | 0 .../cutlass_mock_imports/scipy/__init__.py | 0 .../cutlass_mock_imports/scipy/special.py | 0 .../lib_extensions}/evt_extensions.py | 2 +- .../gemm_operation_extensions.py | 2 +- .../python_evt.py} | 0 .../{cuda => cutlass}/serialization.py | 2 +- .../cutlass_utils.py => cutlass/utils.py} | 6 ++--- torch/_inductor/ir.py | 2 +- torch/_inductor/kernel/bmm.py | 2 +- torch/_inductor/kernel/mm.py | 2 +- torch/_inductor/select_algorithm.py | 16 ++++++----- torch/_inductor/utils.py | 2 +- 28 files changed, 62 insertions(+), 57 deletions(-) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass}/__init__.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_cache.py => cutlass/cache.py} (94%) rename torch/_inductor/codegen/{cuda => cutlass}/cuda_kernel.py (99%) rename torch/_inductor/codegen/{cuda => cutlass}/cuda_template.py (99%) rename torch/_inductor/codegen/{cuda => cutlass}/gemm_template.py (99%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions/cutlass_mock_imports => cutlass/lib_extensions}/__init__.py (100%) create mode 100644 torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/cuda/__init__.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/cuda/cuda.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/cuda/cudart.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/pydot/__init__.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/scipy/__init__.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/cutlass_mock_imports/scipy/special.py (100%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/evt_extensions.py (99%) rename torch/_inductor/codegen/{cuda/cutlass_lib_extensions => cutlass/lib_extensions}/gemm_operation_extensions.py (99%) rename torch/_inductor/codegen/{cuda/cutlass_python_evt.py => cutlass/python_evt.py} (100%) rename torch/_inductor/codegen/{cuda => cutlass}/serialization.py (99%) rename torch/_inductor/codegen/{cuda/cutlass_utils.py => cutlass/utils.py} (99%) diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 673d3e87d2a5f..212795c2d4925 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -14,7 +14,9 @@ from typing import Optional from torch._dynamo.exc import BackendCompilerFailed -from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.codegen.cutlass.serialization import ( + get_cutlass_operation_serializer, +) from torch._inductor.utils import clear_caches from torch.export import Dim from torch.testing._internal.logging_utils import log_settings @@ -32,11 +34,8 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller -from torch._inductor.codegen.cuda.cutlass_utils import ( - _gen_ops_cached, - get_max_alignment, -) +from torch._inductor.codegen.cutlass.cuda_kernel import CUDATemplateCaller +from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment from torch._inductor.exc import InductorError from torch._inductor.ir import FixedLayout from torch._inductor.select_algorithm import NoValidChoicesError @@ -206,7 +205,7 @@ def run_evt_test(self, model, op, shape, num_fusions=1): def test_check_paths(self): cutlass_mock_imports_path = os.path.join( os.path.dirname(torch.__file__), - "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports", + "_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports", ) cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda") cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot") @@ -251,7 +250,7 @@ def mm(a, b): @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_import_cutlass(self): - from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + from torch._inductor.codegen.cutlass.utils import try_import_cutlass self.assertTrue(try_import_cutlass()) @@ -259,7 +258,7 @@ def test_import_cutlass(self): import cutlass_library # noqa: F401 def test_cutlass_key(self): - from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + from torch._inductor.codegen.cutlass.utils import try_import_cutlass self.assertTrue(try_import_cutlass()) from torch._inductor.codecache import cutlass_key @@ -1467,7 +1466,7 @@ def test_standalone_runner(self): ): from tempfile import NamedTemporaryFile - from torch._inductor.codegen.cuda.cutlass_utils import ( + from torch._inductor.codegen.cutlass.utils import ( cuda_standalone_runner_compile_command, CUDACompileSourceCapturingContext, ) @@ -1553,7 +1552,7 @@ def mm(a, b): with ( log_settings("+inductor"), self.assertLogs( - logger="torch._inductor.codegen.cuda", level=logging.DEBUG + logger="torch._inductor.codegen.cutlass", level=logging.DEBUG ) as test_log, ): Y_compiled = torch.compile(mm, dynamic=False)(a, b) @@ -1591,7 +1590,7 @@ def forward(self, A, B): expected = model(A, B) # Track render calls - from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 @@ -1645,7 +1644,7 @@ def forward(self, a, b, c, d): d = torch.randn(64, 128).cuda().half().t() # Track render calls - from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 @@ -1706,7 +1705,7 @@ def forward(self, a, b): b = torch.randn(32, 64).cuda().half().t() # Track render calls - from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index 862aeb5db1c88..dd296b7f75ac7 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -5,7 +5,7 @@ import torch from torch._dynamo.test_case import TestCase -from torch._inductor.codegen.cuda.cutlass_utils import ( +from torch._inductor.codegen.cutlass.utils import ( torch_dtype_to_cutlass_type, try_import_cutlass, ) @@ -28,7 +28,7 @@ DataType = cutlass_lib.DataType from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor - from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( + from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import ( _render_argument_type, _trace, trace, @@ -107,7 +107,7 @@ class TestCutlassEVT(TestCase): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_accumulator_return(self): - from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -164,7 +164,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_disjoint_read_indexing(self): - from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -213,7 +213,7 @@ def inner_fn_buf4(index): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_broadcasting(self): - from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -273,7 +273,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen(self): - from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -329,7 +329,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_example_tensor_creation(self): - from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( + from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import ( create_example_tensors, ) from torch._inductor.virtualized import V diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 7a9f8f3aa317f..d175a205935a7 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -292,7 +292,7 @@ def onerror(modname): # do not get imported by public code. # DO NOT add public modules here. private_allowlist = { - "torch._inductor.codegen.cuda.cuda_kernel", + "torch._inductor.codegen.cutlass.cuda_kernel", # TODO(#133647): Remove the onnx._internal entries after # onnx and onnxscript are installed in CI. "torch.onnx._internal.exporter", @@ -357,7 +357,7 @@ def onerror(modname): "torch.testing._internal.distributed.rpc.rpc_test", "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", - "torch._inductor.codegen.cuda.cuda_template", + "torch._inductor.codegen.cutlass.cuda_template", "torch._inductor.codegen.cutedsl._cutedsl_utils", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.codegen.cpp_template", diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 8b5e68780cb28..617b0a91d67a0 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -2657,7 +2657,7 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A """ from ..bounds import ValueRangeAnalysis from ..select_algorithm import TritonTemplateKernel - from .cuda.cuda_kernel import CUDATemplateKernel + from .cutlass.cuda_kernel import CUDATemplateKernel if isinstance(V.kernel, TritonTemplateKernel): return ValueRanges.unknown() diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 16b09d4ba80eb..591a95b18f252 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import cast -from torch._inductor.codegen.cuda.cutlass_python_evt import ( +from torch._inductor.codegen.cutlass.python_evt import ( CutlassEVTCodegen, MockCutlassHandler, ) @@ -267,9 +267,7 @@ def _can_fuse_epilogue_impl( return False try: - from torch._inductor.codegen.cuda.cutlass_python_evt import ( - CutlassEVTCodegen, - ) + from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen CutlassEVTCodegen.ir_to_evt_python_code( cuda_template_buffer.get_name(), diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/torch/_inductor/codegen/cutlass/__init__.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py rename to torch/_inductor/codegen/cutlass/__init__.py diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cutlass/cache.py similarity index 94% rename from torch/_inductor/codegen/cuda/cutlass_cache.py rename to torch/_inductor/codegen/cutlass/cache.py index cad4a37902304..9de1a6257c2d1 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cutlass/cache.py @@ -10,9 +10,11 @@ import torch._inductor.config as config from torch._inductor.codecache import cutlass_key -from torch._inductor.codegen.cuda import cutlass_utils, serialization from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version -from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.codegen.cutlass import serialization, utils +from torch._inductor.codegen.cutlass.serialization import ( + get_cutlass_operation_serializer, +) from torch._inductor.runtime.cache_dir_utils import cache_dir from torch._inductor.utils import clear_on_fresh_cache @@ -39,7 +41,7 @@ def get_file_hash(file_module): return hashlib.sha256(f.read()).hexdigest() serialization_hash = get_file_hash(serialization) - cutlass_utils_hash = get_file_hash(cutlass_utils) + cutlass_utils_hash = get_file_hash(utils) hash_target = "-".join( [ diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cutlass/cuda_kernel.py similarity index 99% rename from torch/_inductor/codegen/cuda/cuda_kernel.py rename to torch/_inductor/codegen/cutlass/cuda_kernel.py index 97643ef00a7bd..9622dc759a6c4 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cutlass/cuda_kernel.py @@ -16,7 +16,7 @@ from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder from torch.utils._sympy.value_ranges import ValueRanges -from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE +from .utils import DTYPE_TO_CUTLASS_TYPE if TYPE_CHECKING: @@ -47,7 +47,7 @@ if TYPE_CHECKING: - from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + from torch._inductor.codegen.cutlass.cuda_template import CUDATemplate log = logging.getLogger(__name__) @@ -424,7 +424,7 @@ def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: # Helper method, called into from CUTLASSGemmTemplate if node is None: return default_dtype - from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + from torch._inductor.codegen.cutlass.cuda_template import CUTLASSTemplate return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cutlass/cuda_template.py similarity index 99% rename from torch/_inductor/codegen/cuda/cuda_template.py rename to torch/_inductor/codegen/cutlass/cuda_template.py index 92c86120570d6..384713f157062 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cutlass/cuda_template.py @@ -20,7 +20,7 @@ from ...virtualized import V from ..common import KernelTemplate from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel -from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE +from .utils import DTYPE_TO_CUTLASS_TYPE if TYPE_CHECKING: diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cutlass/gemm_template.py similarity index 99% rename from torch/_inductor/codegen/cuda/gemm_template.py rename to torch/_inductor/codegen/cutlass/gemm_template.py index 9148ee7877d03..58f9622571dcc 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cutlass/gemm_template.py @@ -11,7 +11,7 @@ import torch import torch.utils._pytree as pytree from torch._inductor.autotune_process import TensorMeta -from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.codegen.cutlass.cache import maybe_fetch_ops from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode @@ -32,11 +32,11 @@ from ...utils import is_dynamic, Placeholder from ...virtualized import V from ..common import IndentedBuffer -from . import cutlass_utils +from . import utils as cutlass_utils from .cuda_kernel import CUDATemplateKernel from .cuda_template import CUTLASSTemplate -from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt -from .cutlass_utils import ( +from .python_evt import CutlassEVTCodegen, scaled_mm_evt +from .utils import ( ACCUMULATOR_DTYPES, dtype_match, torch_dtype_to_cutlass_type, @@ -1474,7 +1474,7 @@ def _render_evt( output_dtype: torch.dtype, accumulator_dtype: torch.dtype, ) -> tuple[str, str, str, EVTArgRenames]: - from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + from .lib_extensions.evt_extensions import create_example_tensors, trace acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) output_dtype = torch_dtype_to_cutlass_type(output_dtype) @@ -1561,7 +1561,7 @@ def _define_gemm_instance( assert cutlass_utils.try_import_cutlass() import cutlass_library.library as cutlass_lib - from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions + from .lib_extensions import gemm_operation_extensions as gemm_extensions emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] @@ -1701,6 +1701,8 @@ def clone_with_transposed_stride(node: IRNode) -> IRNode: class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + """CUTLASS 2x GEMM Template, which is used to generate CUTLASS GEMM kernels""" + def __init__( self, input_nodes: list[Buffer], diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/__init__.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py rename to torch/_inductor/codegen/cutlass/lib_extensions/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/__init__.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/__init__.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cuda.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cuda.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cudart.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cudart.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/pydot/__init__.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/pydot/__init__.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/__init__.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/__init__.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/special.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py rename to torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/special.py diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py similarity index 99% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py rename to torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py index 472438fec90e3..c1daa78228ba1 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py @@ -10,7 +10,7 @@ ) from torch.utils._ordered_set import OrderedSet -from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass +from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py similarity index 99% rename from torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py rename to torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py index 95af1a968a97c..d10669d40bea0 100644 --- a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py @@ -1,5 +1,5 @@ # mypy: ignore-errors -from ..cutlass_utils import try_import_cutlass +from ..utils import try_import_cutlass # copied / modified from original at diff --git a/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/torch/_inductor/codegen/cutlass/python_evt.py similarity index 100% rename from torch/_inductor/codegen/cuda/cutlass_python_evt.py rename to torch/_inductor/codegen/cutlass/python_evt.py diff --git a/torch/_inductor/codegen/cuda/serialization.py b/torch/_inductor/codegen/cutlass/serialization.py similarity index 99% rename from torch/_inductor/codegen/cuda/serialization.py rename to torch/_inductor/codegen/cutlass/serialization.py index a17f04b0a1b5a..39184e4e6e2c6 100644 --- a/torch/_inductor/codegen/cuda/serialization.py +++ b/torch/_inductor/codegen/cutlass/serialization.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Optional -from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass +from torch._inductor.codegen.cutlass.utils import try_import_cutlass class CUTLASSOperationSerializer: diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cutlass/utils.py similarity index 99% rename from torch/_inductor/codegen/cuda/cutlass_utils.py rename to torch/_inductor/codegen/cutlass/utils.py index 3ce3a49bb94e9..56e02edbb99d5 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cutlass/utils.py @@ -23,7 +23,7 @@ from ...runtime.runtime_utils import cache_dir from ...virtualized import V from ..cpp_utils import DTYPE_TO_CPP -from .cuda_env import get_cuda_arch, get_cuda_version +from ..cuda.cuda_env import get_cuda_arch, get_cuda_version log = logging.getLogger(__name__) @@ -104,8 +104,8 @@ def path_join(path0, path1): torch_root, "_inductor", "codegen", - "cuda", - "cutlass_lib_extensions", + "cutlass", + "lib_extensions", "cutlass_mock_imports", ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 67e0174443882..d13182e717494 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -119,7 +119,7 @@ from torch.fx.experimental.symbolic_shapes import SympyBoolean from torch.fx.node import Argument - from .codegen.cuda.cuda_template import CUDATemplate + from .codegen.cutlass.cuda_template import CUDATemplate from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering from .utils import IndentedBuffer diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index a155d35b5d059..7aeed4d8b92a9 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -262,7 +262,7 @@ def _to_dtype(x): and use_cutlass_template(layout, m, n, k) and _use_cutlass_for_op(name) ): - from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + from ..codegen.cutlass.gemm_template import CUTLASS3xGemmTemplate CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, kernel_inputs.nodes() diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 5b57c458f46e6..1dd6c2fbfcd75 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -20,7 +20,7 @@ from torch.torch_version import TorchVersion from .. import config as inductor_config, distributed_autotune -from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.cutlass.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 625f35ba36c06..aab68f3d0a744 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2707,7 +2707,7 @@ def __call__( best_config_future=None, return_choice=False, # TODO: return_choice is temporary and will be refactored soon ): - from .codegen.cuda.cuda_kernel import CUDATemplateCaller + from .codegen.cutlass.cuda_kernel import CUDATemplateCaller # Run preprocessing functions on choices for preprocessing_fn in self.preprocessing_fns: @@ -3223,7 +3223,7 @@ def wait_on_futures(): "select_algorithm_num_precompilation_exceptions" ] += 1 exceptions.append((futures[future], e)) - from torch._inductor.codegen.cuda.cuda_kernel import ( + from torch._inductor.codegen.cutlass.cuda_kernel import ( CUDATemplateCaller, ) @@ -3410,7 +3410,9 @@ def benchmark_choices( try: timing = cls.benchmark_choice(choice, autotune_args) except CUDACompileError: - from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + from torch._inductor.codegen.cutlass.cuda_kernel import ( + CUDATemplateCaller, + ) if not isinstance(choice, CUDATemplateCaller): log.exception( @@ -3421,7 +3423,9 @@ def benchmark_choices( log.warning("Not yet implemented", exc_info=True) timing = float("inf") except RuntimeError as e: - from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + from torch._inductor.codegen.cutlass.cuda_kernel import ( + CUDATemplateCaller, + ) msg = str(e) if "invalid argument" in msg: @@ -3566,7 +3570,7 @@ def prescreen_choices( return prescreen_winners # prescreen cutlass - from .codegen.cuda.cuda_kernel import CUDATemplateCaller + from .codegen.cutlass.cuda_kernel import CUDATemplateCaller candidates = [] if ( @@ -3600,7 +3604,7 @@ def prune_choices_postscreen( """ Prune the choices after prescreening. """ - from .codegen.cuda.cuda_kernel import CUDATemplateCaller + from .codegen.cutlass.cuda_kernel import CUDATemplateCaller prescreen_key = f"{name}:{inputs_key}" diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 59db1aeb12325..4f3ff7e01879f 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2025,7 +2025,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size: return False - from .codegen.cuda.cutlass_utils import try_import_cutlass + from .codegen.cutlass.utils import try_import_cutlass # Do not use cutlass template on ROCm if torch.version.hip: From f3b068633f1ea37c014abfce33c40e1e3b484452 Mon Sep 17 00:00:00 2001 From: Chinmay Kuchinad Date: Fri, 21 Nov 2025 16:17:10 +0000 Subject: [PATCH 167/230] Skipping few distributed tests for 2 GPU setups (#168265) Skipping test/distributed/tensor/test_math_ops.py::DistMathOpsTestWithLocalTensor::test_std and test/distributed/tensor/test_redistribute.py::RedistributeTest::test_redistribute_shard_dim_change_float32 if less than 4 gpus are present. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168265 Approved by: https://github.com/jeffdaily --- test/distributed/tensor/test_dtensor_dispatch_overhead.py | 2 ++ test/distributed/tensor/test_math_ops.py | 1 + test/distributed/tensor/test_redistribute.py | 2 ++ 3 files changed, 5 insertions(+) diff --git a/test/distributed/tensor/test_dtensor_dispatch_overhead.py b/test/distributed/tensor/test_dtensor_dispatch_overhead.py index 7d08725205e60..ab9b578b80f93 100644 --- a/test/distributed/tensor/test_dtensor_dispatch_overhead.py +++ b/test/distributed/tensor/test_dtensor_dispatch_overhead.py @@ -10,6 +10,7 @@ import torch from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import distribute_tensor, DTensor, Shard +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -65,6 +66,7 @@ class DistOpDispatchOverHead(DTensorTestBase): def world_size(self) -> int: return 4 + @skip_if_lt_x_gpu(4) @with_comms def test_dtensor_add_op_dispatch_overhead(self): if torch.cuda.is_available(): diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 5eb92a44188e6..2922c5ff85960 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -1037,6 +1037,7 @@ def test_matching_partial_reduction_ops(self): self.assertTrue(out_with_redistribute.placements[0].is_replicate()) self.assertEqual(out_without_redistribute, out_with_redistribute) + @skip_if_lt_x_gpu(4) @with_comms def test_std(self): mesh = DeviceMesh(self.device_type, torch.arange(4).reshape(2, 2)) diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 86bb567a39616..ec1d69e9b02e6 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -23,6 +23,7 @@ from torch.distributed.tensor._dtensor_spec import ShardOrderEntry from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -507,6 +508,7 @@ def test_redistribute_uneven_sharding(self): dt_full_tensor = dt.full_tensor() self.assertEqual(dt_full_tensor, input_tensor) + @skip_if_lt_x_gpu(4) @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_redistribute_shard_dim_change(self, dtype): From 2d7ea6c78a5d07f127663c4c35c7ef1904bb9005 Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Fri, 21 Nov 2025 16:27:32 +0000 Subject: [PATCH 168/230] Add rocm-navi31 to the upload test stats file (#168359) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168359 Approved by: https://github.com/jeffdaily --- .github/workflows/upload-test-stats.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index b3d8073aad3b3..3a0567f33c8cc 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -18,6 +18,7 @@ on: - rocm-mi200 - rocm-mi300 - rocm-mi355 + - rocm-navi31 - inductor-micro-benchmark - inductor-micro-benchmark-x86 - inductor-cu124 From 1871a24a798ae6870606f9753bb947c7e8a97611 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Thu, 20 Nov 2025 11:33:18 -0800 Subject: [PATCH 169/230] Add shim for getCurrentBlasHandle (#168276) 1. Start a cuda shim_common.cpp 2. Add a getCurrentBlasHandle shim, which is needed to support vLLM cuda kernels. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168276 Approved by: https://github.com/eqy, https://github.com/albanD --- build_variables.bzl | 1 + .../csrc/test_cublas_handle.cu | 16 ++++++++++++++++ .../libtorch_agnostic_2_10/ops.py | 7 +++++++ .../libtorch_agnostic_2_10_extension/setup.py | 2 +- test/cpp_extensions/test_libtorch_agnostic.py | 9 +++++++++ torch/csrc/cuda/shim_common.cpp | 9 +++++++++ torch/csrc/stable/c/shim.h | 7 +++++++ 7 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu create mode 100644 torch/csrc/cuda/shim_common.cpp diff --git a/build_variables.bzl b/build_variables.bzl index 258e739300c1e..420616311b32c 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -731,6 +731,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/memory_snapshot.cpp", "torch/csrc/cuda/CUDAPluggableAllocator.cpp", + "torch/csrc/cuda/shim_common.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp", "torch/csrc/inductor/aoti_torch/shim_cuda.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu new file mode 100644 index 0000000000000..439cb8e24ddb0 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu @@ -0,0 +1,16 @@ +#include +#include + +void* my_get_curr_cuda_blas_handle() { + void* ret_handle; + TORCH_ERROR_CODE_CHECK(torch_get_current_cuda_blas_handle(&ret_handle)); + return ret_handle; +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_get_curr_cuda_blas_handle() -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_get_curr_cuda_blas_handle", TORCH_BOX(&my_get_curr_cuda_blas_handle)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index 8d05741869ebd..102e22e668cdf 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -239,3 +239,10 @@ def get_template_any_data_ptr(t, dtype, mutable) -> int: return torch.ops.libtorch_agnostic_2_10.get_template_any_data_ptr.default( t, dtype, mutable ) + + +def my_get_curr_cuda_blas_handle() -> int: + """ + Return the current cuBlasHandle_t pointer value. + """ + return torch.ops.libtorch_agnostic_2_10.my_get_curr_cuda_blas_handle.default() diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index ff2aeff5e932b..7bc37ba238139 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -45,7 +45,7 @@ def get_extension(): # allow including if torch.cuda.is_available(): extra_compile_args["cxx"].append("-DLAE_USE_CUDA") - extra_compile_args["nvcc"] = ["-O2"] + extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"] extension = CUDAExtension sources.extend(CSRC_DIR.glob("**/*.cu")) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index f24731ee5666a..55681a45e4445 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -825,6 +825,15 @@ def test_get_template_any_data_ptr(self, device): t, rdtype, mutable ) + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_get_curr_cuda_blas_handle(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + res = libtorch_agnostic.ops.my_get_curr_cuda_blas_handle() + expected = torch.cuda.current_blas_handle() + self.assertEqual(res, expected) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/cuda/shim_common.cpp b/torch/csrc/cuda/shim_common.cpp new file mode 100644 index 0000000000000..cb5f28dba0152 --- /dev/null +++ b/torch/csrc/cuda/shim_common.cpp @@ -0,0 +1,9 @@ +#include +#include +#include + +AOTITorchError torch_get_current_cuda_blas_handle(void** ret_handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + *(cublasHandle_t*)(ret_handle) = at::cuda::getCurrentCUDABlasHandle(); + }); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 99b3b435cf550..83bdfd59096fe 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -103,6 +103,13 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr( const void** ret_data_ptr // returns borrowed reference ); +#ifdef USE_CUDA + +AOTI_TORCH_EXPORT AOTITorchError +torch_get_current_cuda_blas_handle(void** ret_handle); + +#endif // USE_CUDA + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus From 107ab1c58606b01f9785e35f9ce3afd39aead824 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Fri, 21 Nov 2025 16:55:44 +0000 Subject: [PATCH 170/230] control_plane: add handler for WaitCounters (#167871) Summary: This adds a DebugServer handler for WaitCounters such that we can access all wait counters live via HTTP. To do so we register a wait counter backend that tracks all counter values in a shared synchronized map. When creating a counter this will acquire the lock to add it to the global map but during runtime it only uses atomic operation. Test Plan: ``` //caffe2/test/distributed/elastic:test_control_plane ``` Differential Revision: D87095718 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167871 Approved by: https://github.com/fduwjj --- build_variables.bzl | 1 + .../distributed/elastic/test_control_plane.py | 71 ++++++++- test/distributed/test_debug.py | 1 + .../c10d/control_plane/Handlers.cpp | 18 +++ .../c10d/control_plane/WaitCounterHandler.cpp | 138 ++++++++++++++++++ .../c10d/control_plane/WaitCounterHandler.hpp | 15 ++ torch/distributed/debug/_frontend.py | 8 + 7 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp create mode 100644 torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp diff --git a/build_variables.bzl b/build_variables.bzl index 420616311b32c..ba856c5a97ba4 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -521,6 +521,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp", "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp", "torch/csrc/distributed/c10d/debug.cpp", diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 9b31cf3b1755b..d6168a71f752b 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -6,6 +6,7 @@ import pickle import socket import tempfile +import unittest from contextlib import contextmanager from urllib3.connection import HTTPConnection @@ -15,7 +16,13 @@ TORCH_WORKER_SERVER_SOCKET, worker_main, ) -from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase +from torch.monitor import _WaitCounter +from torch.testing._internal.common_utils import ( + IS_FBCODE, + requires_cuda, + run_tests, + TestCase, +) class UnixHTTPConnection(HTTPConnection): @@ -216,6 +223,68 @@ def test_get_handler_names(self) -> None: names = _get_handler_names() self.assertIn("ping", names) + @unittest.skipIf(IS_FBCODE, "disabled in FBCODE") + def test_wait_counter_values(self) -> None: + """ + Test that WaitCounter values are properly tracked and returned by the handler. + + Note: This test may trigger an ASAN heap-use-after-free error during process + shutdown due to static destruction order issues with boost regex in the logging + framework. The test assertions pass successfully before this shutdown error occurs. + """ + with local_worker_server() as pool: + # Create and use a WaitCounter with a specific name + counter_name = "test_counter" + counter = _WaitCounter(counter_name) + + # Use the counter multiple times to generate metrics + # Note: Using minimal/no sleep to avoid timing issues + for i in range(3): + with counter.guard(): + pass # Minimal work + + # Query the wait counter values + resp = pool.request("POST", "/handler/wait_counter_values") + self.assertEqual(resp.status, 200) + + # Parse the JSON response + data = json.loads(resp.data) + # Should be a dictionary + self.assertIsInstance(data, dict) + + # Verify our test counter appears in the response + self.assertIn( + counter_name, + data, + f"Counter '{counter_name}' not found in response. Available counters: {list(data.keys())}", + ) + + # Verify the counter has expected metrics + counter_data = data[counter_name] + self.assertIn("active_count", counter_data) + self.assertIn("total_calls", counter_data) + self.assertIn("total_time_us", counter_data) + self.assertIn("max_time_us", counter_data) + + # Verify the counter was called 3 times + self.assertEqual( + counter_data["total_calls"], + 3, + f"Expected 3 calls, got {counter_data['total_calls']}", + ) + + # Verify active_count is 0 (no active waiters) + self.assertEqual( + counter_data["active_count"], + 0, + f"Expected 0 active, got {counter_data['active_count']}", + ) + + # total_time_us and max_time_us may be 0 or very small for fast operations + # Just verify they exist and are non-negative + self.assertGreaterEqual(counter_data["total_time_us"], 0) + self.assertGreaterEqual(counter_data["max_time_us"], 0) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py index ff6a203bcf160..e1612d7639a13 100644 --- a/test/distributed/test_debug.py +++ b/test/distributed/test_debug.py @@ -40,6 +40,7 @@ def fetch(path: str) -> str: self.assertIn("View 0", fetch("/profile?duration=0.01")) self.assertIn("test_basics", fetch("/stacks")) self.assertIn("pg_status", fetch("/fr_trace")) + self.assertIn("Rank 0", fetch("/wait_counters")) if torch.cuda.is_available(): self.assertIn("pg_status", fetch("/fr_trace_nccl")) diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index fe8f831a23bb1..5e5c3195046cb 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -11,6 +11,8 @@ #include #include +#include + namespace c10d::control_plane { namespace { @@ -73,6 +75,22 @@ RegisterHandler frTracehandler( res.setStatus(200); }); +RegisterHandler waitCounterHandler{ + "wait_counter_values", + [](const Request&, Response& res) { + // Get all wait counter values from our tracking backend + res.setContent(getWaitCounterValuesJson(), "application/json"); + res.setStatus(200); + }}; + +#if !defined(FBCODE_CAFFE2) +// Initialize the wait counter backend +[[maybe_unused]] static bool init_backend = []() { + ensureWaitCounterBackendRegistered(); + return true; +}(); +#endif + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp new file mode 100644 index 0000000000000..194901cea6837 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp @@ -0,0 +1,138 @@ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace c10d::control_plane { + +namespace { + +// Data structure to hold counter metrics +struct CounterData { + std::atomic active_count{0}; + std::atomic total_calls{0}; + std::atomic total_time_us{0}; + std::atomic max_time_us{0}; +}; + +// Holder struct for the counter data map +struct CounterDataMapHolder { + c10::Synchronized< + std::unordered_map>> + map; +}; + +// Leaky singleton to avoid static destruction order issues +CounterDataMapHolder* getCounterDataMapHolder() { + static CounterDataMapHolder* holder = new CounterDataMapHolder(); + return holder; +} + +// Backend implementation that tracks counter metrics +class TrackingBackend : public c10::monitor::detail::WaitCounterBackendIf { + public: + explicit TrackingBackend(std::string key) : key_(std::move(key)) { + // Get or create counter data for this key + getCounterDataMapHolder()->map.withLock([&](auto& map) { + auto it = map.find(key_); + if (it == map.end()) { + data_ = std::make_shared(); + map[key_] = data_; + } else { + data_ = it->second; + } + }); + } + + intptr_t start(std::chrono::steady_clock::time_point now) noexcept override { + data_->active_count.fetch_add(1, std::memory_order_relaxed); + data_->total_calls.fetch_add(1, std::memory_order_relaxed); + // Return the start time as the context + return static_cast( + std::chrono::duration_cast( + now.time_since_epoch()) + .count()); + } + + void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept + override { + // Calculate duration from the stored start time + auto start_ns = std::chrono::nanoseconds(ctx); + auto start_time = std::chrono::steady_clock::time_point(start_ns); + auto duration_us = + std::chrono::duration_cast(now - start_time) + .count(); + + data_->active_count.fetch_sub(1, std::memory_order_relaxed); + data_->total_time_us.fetch_add(duration_us, std::memory_order_relaxed); + + // Update max_time_us using compare-and-swap + int64_t current_max = data_->max_time_us.load(std::memory_order_relaxed); + while (duration_us > current_max) { + if (data_->max_time_us.compare_exchange_weak( + current_max, duration_us, std::memory_order_relaxed)) { + break; + } + } + } + + private: + std::string key_; + std::shared_ptr data_; +}; + +// Factory for creating tracking backends +class TrackingBackendFactory + : public c10::monitor::detail::WaitCounterBackendFactoryIf { + public: + std::unique_ptr create( + std::string_view key) noexcept override { + return std::make_unique(std::string(key)); + } +}; + +} // namespace + +// Ensures the wait counter backend is registered +// NOTE: This function is in the c10d::control_plane namespace, NOT anonymous +void ensureWaitCounterBackendRegistered() { + static c10::once_flag once; + c10::call_once(once, []() { + c10::monitor::detail::registerWaitCounterBackend( + std::make_unique()); + }); +} + +// Returns all wait counter values as a JSON string +// NOTE: This function is in the c10d::control_plane namespace, NOT anonymous +std::string getWaitCounterValuesJson() { + nlohmann::json j = nlohmann::json::object(); + + getCounterDataMapHolder()->map.withLock([&](const auto& map) { + for (const auto& [name, data] : map) { + nlohmann::json counter_obj = nlohmann::json::object(); + counter_obj["active_count"] = + data->active_count.load(std::memory_order_relaxed); + counter_obj["total_calls"] = + data->total_calls.load(std::memory_order_relaxed); + counter_obj["total_time_us"] = + data->total_time_us.load(std::memory_order_relaxed); + counter_obj["max_time_us"] = + data->max_time_us.load(std::memory_order_relaxed); + j[name] = std::move(counter_obj); + } + }); + + return j.dump(); +} + +} // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp new file mode 100644 index 0000000000000..417e4d21edbd0 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace c10d { +namespace control_plane { + +// Returns all wait counter values as a JSON string +std::string getWaitCounterValuesJson(); + +// Ensures the wait counter backend is registered +void ensureWaitCounterBackendRegistered(); + +} // namespace control_plane +} // namespace c10d diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py index 622c41ca8bd64..10dae4c2802cd 100644 --- a/torch/distributed/debug/_frontend.py +++ b/torch/distributed/debug/_frontend.py @@ -96,6 +96,7 @@ def format_json(blob: str): FlightRecorder FlightRecorder NCCL torch profiler + Wait Counters
@@ -257,6 +258,7 @@ def __init__(self, port: int): "/fr_trace": self._handle_fr_trace, "/fr_trace_nccl": self._handle_fr_trace_nccl, "/profile": self._handle_profiler, + "/wait_counters": self._handle_wait_counters, } # Create HTTP server @@ -346,6 +348,12 @@ def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: return self._render_template("profile.html", addrs=addrs, resps=resps) + def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("wait_counter_values") + return self._render_template( + "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps + ) + def main(port: int) -> None: server = FrontendServer(port=port) From 2f9040434579fa068e8a013d6d9d8bec75dec3ad Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 21 Nov 2025 17:53:36 +0000 Subject: [PATCH 171/230] [MPS] fix broadcasting issues for mul on sparse tensors (#168112) fix broadcasting issues for mul on sparse tensors and enable the test responsible for testing it Pull Request resolved: https://github.com/pytorch/pytorch/pull/168112 Approved by: https://github.com/malfet --- .../native/sparse/mps/SparseMPSTensorMath.mm | 113 +++++++++++++----- .../sparse/mps/kernels/SparseTensorMath.metal | 2 +- test/test_sparse.py | 1 - 3 files changed, 81 insertions(+), 35 deletions(-) diff --git a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm index 3da1cb5da53c8..3b8fd096f495c 100644 --- a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm +++ b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm @@ -488,32 +488,58 @@ Tensor addmm_sparse_dense_mps( TORCH_CHECK(t_.sparse_dim() == src_.sparse_dim(), "mul(sparse, sparse): must have same sparse_dim, got ", t_.sparse_dim(), " vs ", src_.sparse_dim()); - TORCH_CHECK(t_.sizes().equals(src_.sizes()), - "mul(sparse, sparse): sizes must match exactly (no broadcasting)."); - // Coalesce and early-exit on structurally empty operands + // Coalesce and structural info auto lhs = t_.coalesce(); auto rhs = src_.coalesce(); const int64_t lhs_nnz = lhs._nnz(); const int64_t rhs_nnz = rhs._nnz(); - if (!lhs_nnz || !rhs_nnz) { - r_.resize_as_(lhs); - return r_.zero_(); - } + const int64_t sd = lhs.sparse_dim(); // dtype checks and promotion auto commonDtype = at::result_type(lhs, rhs); TORCH_CHECK(canCast(commonDtype, r_.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r_.scalar_type()); - const int64_t ndim_i = lhs.sparse_dim(); + // sparse sizes must match exactly, dense tails may broadcast + TORCH_CHECK(lhs.sizes().slice(0, sd).equals(rhs.sizes().slice(0, sd)), + "mul(sparse, sparse): sparse sizes must match exactly."); + + // dense tails and broadcasted dense tail + auto lhs_dense = lhs.sizes().slice(sd); + auto rhs_dense = rhs.sizes().slice(sd); + std::vector out_dense_vec = at::infer_size(lhs_dense, rhs_dense); + at::IntArrayRef out_dense(out_dense_vec); + + // full output sizes: [sparse_sizes] + [out_dense] + std::vector out_sizes; + out_sizes.reserve(sd + static_cast(out_dense.size())); + out_sizes.insert(out_sizes.end(), lhs.sizes().begin(), lhs.sizes().begin() + sd); + out_sizes.insert(out_sizes.end(), out_dense.begin(), out_dense.end()); + r_.sparse_resize_(out_sizes, sd, static_cast(out_dense.size())); + + const auto device = r_.device(); + + // if either is structurally empty, produce an empty sparse result with correct shape + if (!lhs_nnz || !rhs_nnz) { + Tensor out_indices = at::empty({sd, 0}, at::device(device).dtype(at::kLong)); + + std::vector out_val_sizes; + out_val_sizes.reserve(1 + out_dense.size()); + out_val_sizes.push_back(0); + out_val_sizes.insert(out_val_sizes.end(), out_dense.begin(), out_dense.end()); + + Tensor out_values = at::empty(out_val_sizes, at::device(device).dtype(r_.scalar_type())); + + alias_into_sparse(r_, out_indices, out_values); + r_._coalesced_(true); + return r_; + } - // ndim_i == 0, at most one structural entry - if (ndim_i == 0) { - r_.resize_as_(lhs); + if (sd == 0) { const bool has = (lhs_nnz && rhs_nnz); - auto out_indices = lhs._indices().narrow(1, 0, has ? 1 : 0); + auto out_indices = at::empty({0, has ? 1 : 0}, lhs._indices().options()); Tensor lhs_vals = lhs._values().to(commonDtype); Tensor rhs_vals = rhs._values().to(commonDtype); @@ -531,7 +557,6 @@ Tensor addmm_sparse_dense_mps( } // General path, intersect keys, then gather + multiply on GPU - const auto device = r_.device(); auto stream = getCurrentMPSStream(); auto lhs_indices = lhs._indices().contiguous(); @@ -540,8 +565,8 @@ Tensor addmm_sparse_dense_mps( auto rhs_values = rhs._values().to(commonDtype).contiguous(); // Flatten sparse indices to keys - auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i)); - auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i)); + auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, sd)); + auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, sd)); // Intersect sorted keys (search the shorter in the longer) const bool A_is_lhs = (lhs_nnz <= rhs_nnz); @@ -555,35 +580,49 @@ Tensor addmm_sparse_dense_mps( const auto M = static_cast(M_int64); // number of structural matches - r_.resize_as_(lhs); + auto lhs_match = outA_idx.narrow(0, 0, M_int64); + auto rhs_match = outB_idx.narrow(0, 0, M_int64); - auto out_indices = at::empty({ndim_i, static_cast(M)}, at::device(device).dtype(at::kLong)); - auto lhs_match = outA_idx.narrow(0, 0, M); - auto rhs_match = outB_idx.narrow(0, 0, M); - auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec(); int64_t cols64 = 1; - for (auto s : dense_sizes_vec) cols64 *= s; + for (auto s : out_dense) cols64 *= s; const uint32_t cols = static_cast(std::max(cols64, 1)); - auto to2d = [&](Tensor t, int64_t nnz) -> Tensor { - const int64_t t_cols = t.numel() / nnz; - if (t_cols == cols64) { - return t.view({nnz, cols64}); + // to broadcast [nnz, *in_dense] -> [nnz, *out_dense] -> [nnz, cols] + auto broadcast_to_out2d = [&](const Tensor& vals, int64_t nnz, at::IntArrayRef in_dense) -> Tensor { + const int64_t d_in = in_dense.size(); + const int64_t d_out = out_dense.size(); + + std::vector view_shape; + view_shape.reserve(1 + d_out); + view_shape.push_back(nnz); + for (int64_t i = 0; i < d_out - d_in; ++i) { + view_shape.push_back(1); } - return t.view({nnz, 1}).expand({nnz, cols64}).contiguous(); + view_shape.insert(view_shape.end(), in_dense.begin(), in_dense.end()); + + std::vector expand_shape; + expand_shape.reserve(1 + d_out); + expand_shape.push_back(nnz); + expand_shape.insert(expand_shape.end(), out_dense.begin(), out_dense.end()); + + Tensor v = vals.view(view_shape).expand(expand_shape); + return (cols64 > 0) ? v.contiguous().view({nnz, cols64}) + : v.contiguous().view({nnz, 0}); }; - // make both sides 2d [nnz, cols] buffers so the kernel can index it - auto lhs_vals2d = to2d(lhs_values, lhs_nnz); - auto rhs_vals2d = to2d(rhs_values, rhs_nnz); + // make both sides broadcasted 2d [nnz, cols] buffers so the kernel can index it + auto lhs_vals2d = broadcast_to_out2d(lhs_values, lhs_nnz, lhs_dense); + auto rhs_vals2d = broadcast_to_out2d(rhs_values, rhs_nnz, rhs_dense); std::vector out_val_sizes; - out_val_sizes.reserve(1 + dense_sizes_vec.size()); + out_val_sizes.reserve(1 + out_dense.size()); out_val_sizes.push_back(static_cast(M)); - out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end()); + out_val_sizes.insert(out_val_sizes.end(), out_dense.begin(), out_dense.end()); auto out_values = at::empty(out_val_sizes, lhs_values.options()); - if (M > 0) { + Tensor out_indices; + if (M > 0 && cols64 > 0) { + out_indices = at::empty({sd, M}, at::device(device).dtype(at::kLong)); dispatch_sync_with_rethrow(stream->queue(), ^() { @autoreleasepool { auto pso = lib.getPipelineStateForFunc( @@ -602,11 +641,19 @@ Tensor addmm_sparse_dense_mps( lhs_match, rhs_match, lhs_indices, out_indices, out_values, - std::array{static_cast(ndim_i), static_cast(lhs_nnz)}, + std::array{static_cast(sd), static_cast(lhs_nnz)}, std::array{M, cols}); [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; } }); + } else if (M > 0) { + // just select the matching coordinates + Tensor src_indices_for_out = A_is_lhs ? lhs_indices : rhs_indices; + Tensor src_match_for_out = A_is_lhs ? lhs_match : rhs_match; + out_indices = src_indices_for_out.index_select(1, src_match_for_out); + } else { + // M == 0 + out_indices = at::empty({sd, 0}, at::device(device).dtype(at::kLong)); } if (r_.scalar_type() != commonDtype) { diff --git a/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal b/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal index dbd1a4548f9ee..96993de59e5f3 100644 --- a/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal +++ b/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal @@ -313,7 +313,7 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2); constant uint2& dims_output [[buffer(8)]], \ uint3 gid [[thread_position_in_grid]]); -INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL); +INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_FUSED_GATHER_MUL); #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \ diff --git a/test/test_sparse.py b/test/test_sparse.py index 21530352cef9a..dfe127092aead 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3918,7 +3918,6 @@ def _test_mul_skips(self, device, dtype, coalesced): self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs") @coalescedonoff - @expectedFailureMPS # NOTE: addcmul_out is not implemented for bool. @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16)) @dtypesIfMPS(*all_mps_types()) From 28e8803e61e0cad2be0f45b6a81871c72de97a66 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 21 Nov 2025 17:54:51 +0000 Subject: [PATCH 172/230] [MPS] enable sparse mm test (#168156) Enable sparse mm test Pull Request resolved: https://github.com/pytorch/pytorch/pull/168156 Approved by: https://github.com/malfet --- test/test_sparse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_sparse.py b/test/test_sparse.py index dfe127092aead..58398a915ff17 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -1680,7 +1680,6 @@ def fn(S, D1, D2, beta=beta, alpha=alpha): test_shape(7, 8, 9, 20, True, (1, 1)) @coalescedonoff - @expectedFailureMPS @dtypes(torch.double) @dtypesIfMPS(torch.float32) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") @@ -1698,7 +1697,9 @@ def test_shape(d1, d2, d3, nnz, transposed): def fn(S, D): return torch.sparse.mm(S, D) - gradcheck(fn, (S, D), masked=True) + + kwargs = {"eps": 1e-4, "atol": 2e-5} if device == "mps:0" else {} + gradcheck(fn, (S, D), masked=True, **kwargs) test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) From e13220b0b43e8a331d55c44a1fd70fd3d8845058 Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 21 Nov 2025 18:31:54 +0000 Subject: [PATCH 173/230] [CUDA] Update minimum NVIDIA driver version requirement in Green Context test (#168188) We seem to have runners that have a system toolkit version of 12.4 + driver version 550 which is causing the test to fail? Pull Request resolved: https://github.com/pytorch/pytorch/pull/168188 Approved by: https://github.com/Skylion007 --- torch/testing/_internal/common_cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index e7c673caeaf17..0fe9813d51b34 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -89,7 +89,7 @@ def evaluate_platform_supports_green_context(): driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run) if driver_version is None: return False - return int(driver_version.split('.')[0]) >= 550 + return int(driver_version.split('.')[0]) >= 570 PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) From 7717bbaccdd38ce0098765e4e655f33d8002c87c Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Fri, 21 Nov 2025 18:37:19 +0000 Subject: [PATCH 174/230] Add template for add_overflows (#168035) Summary: Check for non uint64_t add overflows. See usage in D87115901. Afterwards, update the pytorch pin in executorch and then land the security patch. Test Plan: CI Differential Revision: D87272275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168035 Approved by: https://github.com/larryliu0820 --- c10/util/safe_numerics.h | 45 ++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index 32ffca52e4864..bfdb968ff96ab 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -3,6 +3,7 @@ #include #include +#include // GCC has __builtin_mul_overflow from before it supported __has_builtin #ifdef _MSC_VER @@ -15,31 +16,45 @@ namespace c10 { -C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { +template , int> = 0> +C10_ALWAYS_INLINE bool add_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() return __builtin_add_overflow(a, b, out); #else - unsigned long long tmp; -#if defined(_M_IX86) || defined(_M_X64) - auto carry = _addcarry_u64(0, a, b, &tmp); -#else - tmp = a + b; - unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); - auto carry = vector >> 63; -#endif - *out = tmp; - return carry; + if constexpr (std::is_signed_v) { + // For signed types, detect overflow by checking sign changes + volatile T tmp = a + b; + *out = tmp; + + // If both operands have the same sign, check if result changed sign + // unexpectedly. + if ((a > 0) == (b > 0)) { + if ((a > 0) && (tmp <= 0)) { + return true; // Positive overflow + } + if ((a < 0) && (tmp >= 0)) { + return true; // Negative overflow + } + } + return false; + } else { + // For unsigned types, overflow causes wrap-around + volatile T tmp = a + b; + *out = tmp; + return (tmp < a || tmp < b); + } #endif } -template +C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return add_overflows(a, b, out); +} + +template , int> = 0> C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() return __builtin_mul_overflow(a, b, out); #else - static_assert( - std::is_integral_v, "mul_overflows only supports integral types"); - if constexpr (std::is_signed_v) { // For signed types, use the division-based check volatile T tmp = a * b; From 402968ee90990c31bd9e1459f3544df8473c203d Mon Sep 17 00:00:00 2001 From: eqy Date: Fri, 21 Nov 2025 18:51:47 +0000 Subject: [PATCH 175/230] [cuDNN][TF32][DTensor][TEST] Turn off TF32 for DTensor conv test (#168187) For #168085, using CI to verify that the failure is actually suppressed We don't seem to see this internally, likely due to being on a newer cuDNN version which uses updated heuristics and goes to a different kernel. Ideally we would like to use the `tf32_on_and_off` decorator instead, but this doesn't seem to play nicely with the DTensor test base class? Pull Request resolved: https://github.com/pytorch/pytorch/pull/168187 Approved by: https://github.com/albanD --- test/distributed/tensor/test_convolution_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 0a06bd66df5e8..ed1cc60802e70 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -14,6 +14,7 @@ Shard, ) from torch.nn import functional as F +from torch.testing._internal.common_cuda import with_tf32_off from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -230,6 +231,7 @@ def test_conv3d(self): out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)]) self.assertEqual(out_dt, out) + @with_tf32_off @with_comms def test_conv2d_no_bias_compile(self): """Test Conv2d with bias=False in compile mode (Issue #167091) @@ -262,7 +264,7 @@ def conv_fn(x, w): self.assertEqual(result_compiled.shape, torch.Size([1, 8, 5, 5])) # Verify numerical correctness - torch.testing.assert_close(result_compiled.to_local(), result_eager.to_local()) + self.assertEqual(result_compiled.to_local(), result_eager.to_local()) @with_comms def test_conv2d_no_bias_backward(self): From 80b57a655e07dc26437e601477168be135c4c4db Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Thu, 20 Nov 2025 16:45:53 -0800 Subject: [PATCH 176/230] Add allgather_base and reduce_scatter_base collective implementations to Local Tensor (#168314) Both collectives are required to run FSDPv2 under local tensor Pull Request resolved: https://github.com/pytorch/pytorch/pull/168314 Approved by: https://github.com/dolpm, https://github.com/ezyang --- test/distributed/test_local_tensor.py | 64 +++++++++++++++++ torch/distributed/_local_tensor/__init__.py | 4 ++ torch/distributed/_local_tensor/_c10d.py | 76 +++++++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index d4c1a7333bf34..a4773a3f8da72 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -373,6 +373,70 @@ def test_all_gather_collective(self): self.assertEqual(tensor_list[1], different_tensors[1]) self.assertEqual(tensor_list[2], different_tensors[2]) + def test_reduce_scatter_tensor_collective(self): + """Test that reduce_scatter_tensor collective operation works correctly with LocalTensor.""" + # Create different tensors for each rank + different_tensors = { + 0: torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + 1: torch.tensor([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]), + 2: torch.tensor([[100.0, 200.0], [300.0, 400.0], [500.0, 600.0]]), + } + + fake_pg = torch.distributed.distributed_c10d._get_default_group() + + # Test reduce_scatter_tensor + with LocalTensorMode(self.world_size): + lt_reduce_scatter = LocalTensor(different_tensors) + lt_reduce_scatter_size = lt_reduce_scatter.size() + lt_output_tensor = torch.zeros( + lt_reduce_scatter_size[0] // fake_pg.size(), + *lt_reduce_scatter_size[1:], + dtype=lt_reduce_scatter.dtype, + device=lt_reduce_scatter.device, + ) + + dist.reduce_scatter_tensor( + lt_output_tensor, lt_reduce_scatter, group=fake_pg + ) + + expected_output = LocalTensor( + { + 0: torch.tensor([[111.0, 222.0]]), + 1: torch.tensor([[333.0, 444.0]]), + 2: torch.tensor([[555.0, 666.0]]), + } + ) + print(lt_output_tensor) + self.assertEqual(lt_output_tensor, expected_output) + + def test_all_gather_into_tensor_collective(self): + """Test that all_gather_into_tensor collective operation works correctly with LocalTensor.""" + # Create different tensors for each rank + different_tensors = { + 0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]), + 2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]), + } + + fake_pg = torch.distributed.distributed_c10d._get_default_group() + + # Test all_gather_into_tensor + with LocalTensorMode(self.world_size): + lt_gather = LocalTensor(different_tensors) + lt_gather_size = lt_gather.size() + lt_output_tensor = torch.zeros( + lt_gather_size[0] * fake_pg.size(), + *lt_gather_size[1:], + dtype=lt_gather.dtype, + device=lt_gather.device, + ) + + dist.all_gather_into_tensor(lt_output_tensor, lt_gather, group=fake_pg) + + expected_output = torch.cat(list(different_tensors.values())) + + self.assertEqual(lt_output_tensor, expected_output) + def test_all_to_all_single_collective(self): """Test that all_to_all_single collective operation works correctly with LocalTensor.""" from torch.distributed._functional_collectives import all_to_all_single diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index 194127b725fa0..cc4a47f299444 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -1253,6 +1253,10 @@ def __torch_dispatch__( return _c10d._local_all_gather_(*args, **kwargs) elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default: return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d._allgather_base_.default: + return _c10d._local_allgather_base_(*args, **kwargs) + elif func is torch.ops.c10d._reduce_scatter_base_.default: + return _c10d._local_reduce_scatter_base_(*args, **kwargs) elif func is torch.ops.c10d.gather_.default: return _c10d._local_gather_(*args, **kwargs) elif func is torch.ops.c10d.alltoall_.default: diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 873da1ad5c626..a6a8c41103c9f 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -486,6 +486,82 @@ def _local_reduce_scatter_tensor_coalesced_( return work_so +def _local_allgather_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup + # process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + gathered_tensor = torch.cat(gathered_tensors, dim=0) + + for rank_i in group_ranks: + output_tensor._local_tensors[rank_i].copy_(gathered_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_reduce_scatter_base_( # type: ignore[no-untyped-def] + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, + # __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, + # bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)" + + from . import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + reduced_tensor = _local_reduce(reduce_op, gathered_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank_i in enumerate(group_ranks): + output_tensor._local_tensors[rank_i].copy_(scattered_tensor[i].clone()) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + def _local_all_gather_( output_tensors: list[list[torch.Tensor]], input_tensors: list[torch.Tensor], From 08bfadf971742edf63d8e3eb16f151de0c9dc41b Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 20 Nov 2025 21:42:44 -0800 Subject: [PATCH 177/230] [DTensor] compute shape and offset for arbitrary _StridedShard (#168146) resolve https://github.com/pytorch/pytorch/issues/167859 for _StridedShard, compute_local_shape_and_global_offset was landed to consider fsdp2 + tp: (_StridedShard(0, split_factor=mesh.size(k)), Shard(0)). Need to extend it to arbitrary _StridedShard for example, `_StridedShard(dim=0, split_factor=batch_size), _StridedShard(dim=0, split_factor=batch_size * seq_len / device_mesh.size(0))` This PR ensure correct local shape for DTensor views with _StridedShard Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/168146 Approved by: https://github.com/wconstab --- test/distributed/tensor/test_utils.py | 471 +++++++++++------- torch/distributed/tensor/_api.py | 2 +- .../distributed/tensor/_ops/_common_rules.py | 5 +- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- torch/distributed/tensor/_sharding_prop.py | 2 +- torch/distributed/tensor/_utils.py | 227 ++++----- torch/distributed/tensor/placement_types.py | 9 +- 7 files changed, 400 insertions(+), 318 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 11b70c8554e52..5f3225d174cb2 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -16,7 +16,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._utils import ( _compute_local_shape_and_global_offset, - _explicit_order_placements, compute_global_tensor_info, compute_global_tensor_shape, compute_local_shape_and_global_offset, @@ -46,85 +45,6 @@ class LocalTest(TestCase): - def test_explicit_order_placements(self): - # mesh_shape: ShapeType, placements: Sequence[Placement] - test_cases = [ - { - "mesh_shape": [2, 4], - "placements": [Replicate(), Replicate()], - "ordered": [(0, Replicate()), (1, Replicate())], - }, - { - "mesh_shape": [3, 2], - "placements": [Shard(0), Replicate()], - "ordered": [(0, Shard(0)), (1, Replicate())], - }, - { - "mesh_shape": [2, 4], - "placements": [_StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(1, Shard(0)), (0, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=12), - _StridedShard(0, split_factor=4), - Shard(0), - ], - "ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], - }, - ] - for test_case in test_cases: - actual = _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - expected = test_case["ordered"] - - self.assertEqual( - actual, - expected, - f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", - ) - - error_cases = [ - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], - "exception_type": RuntimeError, - "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=3), - Shard(0), - Shard(0), - ], - "exception_type": NotImplementedError, - "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", - }, - { - "mesh_shape": [2, 3], - "placements": [ - Shard(0), - ], - "exception_type": RuntimeError, - "exception_text": "Expected one placement per mesh dim", - }, - ] - for test_case in error_cases: - with self.assertRaisesRegex( - test_case["exception_type"], test_case["exception_text"] - ): - _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - def test_compute_local_shape_and_global_offset_uneven(self): # This case is not only 'uneven' bug also has an empty shard # (e.g. most DP ranks have local shape 18,4096, one has 8,4096, one has 0,4096 @@ -151,6 +71,225 @@ def test_compute_local_shape_and_global_offset_uneven(self): self.assertEqual(local_shape, (expected_shard_size, 4096)) self.assertEqual(global_offset, (expected_shard_offset, 0)) + # S, S uneven without empty + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 5 + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # S, S uneven with empty + global_shape = (13, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 4 + if dp_rank in (0, 1, 2): + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 1 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 1 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, Shard + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [_StridedShard(0, split_factor=TP), Shard(0)] + TP_shard_size = int(global_shape[0] / TP) + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + expected_shard_size = 3 + expected_shard_offset = ( + tp_rank * TP_shard_size + expected_shard_size * dp_rank + ) + if dp_rank == 3: + expected_shard_size = 0 + expected_shard_offset = 18 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, SS + global_shape = (39, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [ + _StridedShard(0, split_factor=3), + _StridedShard(0, split_factor=4), + ] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 8 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 4 * dp_rank + 2 + expected_shard_size = 4 + else: + assert dp_rank == 3 + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS) + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), _StridedShard(0, split_factor=2)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 2 + expected_shard_size = 2 + else: + assert dp_rank == 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 1 + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS, Shard) + global_shape = (39, 2) + mesh0, mesh1, mesh2 = 4, 2, 3 + mesh_shape = (mesh0, mesh1, mesh2) + placements = [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] + for my_coordinate in itertools.product( + range(mesh0), range(mesh1), range(mesh2) + ): + mesh0_rank, mesh1_rank, mesh2_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if mesh0_rank in (0, 1, 2): + if mesh1_rank == 0: + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 2 + expected_shard_size = 2 + else: + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 2 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 8 + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_size = 0 + expected_shard_offset = global_shape[0] + else: + assert mesh0_rank == 3 + if mesh1_rank == 0: + if mesh2_rank in (0, 1): + expected_shard_offset = 10 * mesh0_rank + 2 * mesh2_rank + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 1 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 7 + expected_shard_size = 2 + else: + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + class UtilTest(DTensorTestBase): @property @@ -292,6 +431,78 @@ def test_compute_local_shape_and_global_offset_2D(self): global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], ) + @with_comms + def test_compute_local_shape_and_global_offset_3D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 2 + mesh_size_1 = 2 + mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank = global_mesh.get_coordinate() + self.assertEqual(local_shape, [2, 2 * self.world_size]) + self.assertEqual( + global_offset, (4 * mesh0_rank + 8 * mesh1_rank + 2 * mesh2_rank, 0) + ) + + @with_comms + def test_compute_local_shape_and_global_offset_4D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 1 + mesh_size_1 = 2 + mesh_size_2 = 2 + mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(1), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + local_shape, (2 * mesh_size_1 * mesh_size_3, 2 * mesh_size_0 * mesh_size_2) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 4 * mesh0_rank, 8 * mesh3_rank + 4 * mesh1_rank), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + local_shape, (2 * mesh_size_1, 2 * mesh_size_2 * mesh_size_3 * mesh_size_0) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 0 * mesh0_rank + 4 * mesh3_rank, 4 * mesh1_rank), + ) + @with_comms def test_fsdp_tp_meta_compute(self): # FSDP + TP sharding @@ -362,106 +573,6 @@ def test_hsdp_tp_meta_compute(self): self.assertEqual(local_shape, expected_local_shape) self.assertEqual(global_offset, expected_global_offset) - # TODO: remove this test once we support general meta compute on strided sharding - @with_comms - def test_strided_sharding_assumption_in_meta_compute(self): - # current ``compute_local_shape_and_global_offset`` does not allow Shard(i) - # placement to appear after the strided sharding part has ended. This test - # check that ``compute_local_shape_and_global_offset`` does not allow placements - # that violate the assumption and does not forbid the allowed ones. - - # Test 0: 2-D mesh - mesh_size_0 = 2 - mesh_size_1 = self.world_size // mesh_size_0 - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1), - mesh_dim_names=("mesh-0", "mesh-1"), - ) - global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) - - for shard_dim in [0, 1]: - placements = [ - _StridedShard(shard_dim, split_factor=mesh_size_1), - Shard(shard_dim), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 1: 3-D mesh - mesh_size_0 = 2 - mesh_size_1 = 2 - mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), - ) - - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(1), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 2: 4-D mesh - mesh_size_0 = 1 - mesh_size_1 = 2 - mesh_size_2 = 2 - mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), - ) - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(1), - ] - local_shape, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - expected_local_shape = ( - 2 * mesh_size_1 * mesh_size_3, - 2 * mesh_size_0 * mesh_size_2, - ) - self.assertEqual(local_shape, expected_local_shape) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - class UtilSingleDeviceTest(TestCase): def test_compute_global_tensor_info_unsupported_placement(self): diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index fb072d8dce629..dabf9f6f194ce 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -1071,7 +1071,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # get local tensor shape local_shape, _ = compute_local_shape_and_global_offset( - size, device_mesh, placements + size, device_mesh, placements, skip_offset=True ) # initialize the local tensor diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 1e7ff648f7fbd..2d4a311b4bedd 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -168,7 +168,10 @@ def merge_sharding(dim: str, a: int, b: int) -> int: assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape local_shape, _ = compute_local_shape_and_global_offset( - global_shape, input_spec.mesh, input_spec.placements + global_shape, + input_spec.mesh, + input_spec.placements, + skip_offset=True, ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) # pyrefly: ignore [bad-argument-type] diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 49152a1bee13a..5ccf3c37c7855 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -1090,7 +1090,7 @@ def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: meta: TensorMeta = spec.output_specs.tensor_meta local_stride = compute_local_stride(meta.stride, mesh, placements) local_shape, _ = compute_local_shape_and_global_offset( - meta.shape, mesh, placements + meta.shape, mesh, placements, skip_offset=True ) return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index f3dc04ef10f97..2db44f387e4eb 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -660,7 +660,7 @@ def _adjust_shape_and_stride_args( # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( - out_tensor_meta.shape, spec.mesh, spec.placements + out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True ) # adjust the stride arg for aten.new_empty_strided.default diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 74ad2aaa80434..d7ee355500528 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,5 +1,4 @@ import threading -from collections import defaultdict from collections.abc import Sequence from typing import cast, Optional @@ -7,6 +6,7 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType +from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -17,7 +17,6 @@ Replicate, Shard, ) -from torch.utils._typing_utils import not_none class ExplicitRedistributionContext: @@ -56,61 +55,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): ExplicitRedistributionContext._local._active = self._prev -def _explicit_order_placements( - mesh_shape: ShapeType, placements: Sequence[Placement] -) -> Sequence[tuple[int, Placement]]: - """ - Replace Strided Shards with regular shards in an adjusted order. - - Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. - - ex. - [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> - [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] - - """ - if not len(placements) == len(mesh_shape): - raise RuntimeError( - "Expected one placement per mesh dim, " - f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." - ) - ordered = [] - deferred_strided_placements = defaultdict(list) - strided_part_ended_for_dim = set() - for mesh_dim, p in enumerate(placements): - if isinstance(p, _StridedShard): - # validate the stride is the correct multiple of the meshdim and the earlier shard - deferred_strided_placements[p.dim].append((mesh_dim, p)) - - else: - ordered.append((mesh_dim, p)) - if isinstance(p, Shard): - if p.dim in strided_part_ended_for_dim: - raise NotImplementedError( - f"Strided sharding does not allow Shard() to appear after " - f"the strided part has ended. {p} at mesh dim {mesh_dim} in " - f"{placements} violates this assumption." - ) - - if p.dim in deferred_strided_placements: - strided_part_ended_for_dim.add(p.dim) - strided_placements = deferred_strided_placements.pop(p.dim) - aggregate_size = mesh_shape[mesh_dim] - while len(strided_placements) > 0: - strided_mesh_dim, strided = strided_placements.pop() - if not strided.split_factor == aggregate_size: - raise RuntimeError( - f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" - f" == aggregate mesh size ({aggregate_size})" - ) - aggregate_size *= mesh_shape[strided_mesh_dim] - ordered.append((strided_mesh_dim, Shard(p.dim))) - - return ordered - - def compute_local_shape_and_global_offset( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Compute the local tensor shape and the global offsets into the original tensor @@ -143,24 +92,55 @@ def compute_local_shape_and_global_offset( global_shape (ShapeType): The global shape of the DTensor. mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. Return: local_shape: the shape of the DTensor's _local_tensor on the current rank. global_offset: a tuple of offsets for each dimension of the global tensor shape, - identifying how this shard fits into the global tensor in each dimension. + identifying how this shard fits into the global tensor in each dimension. If + skip_offset is True, this will be an empty tuple. """ return _compute_local_shape_and_global_offset( - global_shape, mesh.shape, mesh.get_coordinate(), placements + global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset ) +@maybe_run_for_local_tensor +def _compute_offsets( + placement, + shard_offsets: int, + shard_size: int, + zero_global_offset: int, + previous_offsets, +) -> torch.Tensor: + if shard_size == 0: + return torch.arange(zero_global_offset, zero_global_offset + 1) + if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): + index = torch.arange(shard_offsets, shard_offsets + shard_size) + else: + assert isinstance(shard_offsets, list) + index = torch.tensor(shard_offsets) + if previous_offsets is None: + return index + else: + return previous_offsets[index] + + +@maybe_run_for_local_tensor +def _get_first_offset(offsets: torch.Tensor) -> int: + return int(offsets[0]) + + # accept 'plain data types' to enable simpler unit testing without creating device mesh def _compute_local_shape_and_global_offset( global_shape: ShapeType, mesh_shape: ShapeType, my_coordinate: Optional[list[int]], placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Suppose you have a full tensor with size global_shape, and you have sharded @@ -176,85 +156,72 @@ def _compute_local_shape_and_global_offset( This function is fairly simple if your tensor is evenly sharded; the complication is around uneven splits. There is also some complication for handling StridedShard, which changes the order you should apply sharding. + + Args: + global_shape (ShapeType): The global shape of the tensor. + mesh_shape (ShapeType): The shape of the device mesh. + my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. + placements (Sequence[Placement]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Returns: + tuple: A tuple containing: + - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. + - global_offset (tuple[int, ...]): The offsets for each dimension identifying where + this shard begins in the global tensor. If skip_offset is True, this will be an + empty tuple. """ + empty_offset = () if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((0,), ()) - - # StridedShard implies a non-standard order to apply shards; get the - # correct order to start applying splits - ordered_placements = _explicit_order_placements(mesh_shape, placements) + return ((0,), empty_offset) local_shape = list(global_shape) - # We'll compute the data for where the shard begins on a per-dim basis. - # However, a single dim can be sharded multiple times, so we will end up - # doing a Sum(size*stride) like computation to determine the location of our - # shard for each of the shardings on that dim. - global_offset = [0] * len(global_shape) - - for mesh_dim, placement in ordered_placements: + # Perform shard from left to right. For example, + # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] + # placements: S(0), SS(0, split_factor=2) + # mesh_shape: (2, 2) + # After S(0), shard_dim_to_global_offsets are + # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] + # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] + # After SS(0, split_factor=2), shard_dim_to_global_offsets are + # {0: [0, 2]} on my_coordinate [0, 0] + # {0: [1, 3]} on my_coordinate [0, 1] + # {0: [4, 6]} on my_coordinate [1, 0] + # {0: [5, 7]} on my_coordinate [1, 1] + shard_dim_to_global_offsets = {} + for mesh_dim, placement in enumerate(placements): mesh_dim_size = mesh_shape[mesh_dim] - if isinstance(placement, Shard): - shard_dim = placement.dim - assert shard_dim < len(local_shape), ( - f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - ) - shard_size, shard_offset = placement._local_shard_size_and_offset( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[mesh_dim], - ) - - local_shape[shard_dim] = shard_size - - shard_global_offset = global_offset[shard_dim] + not_none(shard_offset) - - zero_global_offset = global_shape[shard_dim] - if isinstance(shard_global_offset, torch.SymInt) and not isinstance( - zero_global_offset, torch.SymInt - ): - zero_global_offset = torch.SymInt(zero_global_offset) - - global_offset[shard_dim] = torch.sym_ite( - shard_size == 0, - # Special case to fill in a standardized non-garbage value for - # the global_offset of zero-sized shards. This value is out - # of bounds of the tensor, so it won't conflict with any real - # offsets. DCP may rely on this value to de-duplicate shards. - # Note that you can end up with zero-size shards that are - # still otherwise in bounds for the tensor (TODO: give an - # example). - zero_global_offset, - # As we successively shard the same dimension, we keep - # advancing our pointer beyond our original offset until we - # get to the final chunk start. - shard_global_offset, - ) - - # NOTE: the offset compute relies on the local shard index and it has no - # problem when strided sharding is not present. To correctly compute, we assume - # that the ``_StridedShard.split_factor`` field encodes how many partitions - # each local tensor will be further split into when sharding on higher mesh - # dimensions. However, this number is only correct if the DTensor is not - # sharded after the strided sharding completes. For example, - # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements - # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on - # device mesh dim-2, and last on mesh dim-1. We define the - # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding - # part because strided sharding happens on mesh dim-1 and it was caused by - # the fact that sharding on dim-2 occurred ahead. In this case, there's no - # further sharding after this strided sharding part and ``split_factor`` - # correctly encodes the number. Another example is - # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's - # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh - # dim-2. This violates our assumption that no further sharding shall occur - # after the strided sharding part and ``split_factor`` won't correctly - # encode the number of further split. So far, the only case where _StridedShard - # placement would appear is FSDP2 + TP on 2D mesh and the above case could only - # happen on mesh of 3 or more dimensions. - # TODO: change this function to correctly address this. - # TODO: this logic can be applied to contiguous sharding as well + if not isinstance(placement, (Shard, _StridedShard)): + continue + shard_dim = placement.dim + zero_global_offset = global_shape[shard_dim] + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + shard_size, shard_offsets = placement._local_shard_size_and_offset( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[mesh_dim], + ) + local_shape[shard_dim] = shard_size + if skip_offset: + continue + shard_dim_to_global_offsets[shard_dim] = _compute_offsets( + placement, + shard_offsets, + shard_size, + zero_global_offset, + shard_dim_to_global_offsets.get(shard_dim), + ) + if skip_offset: + return tuple(local_shape), empty_offset + global_offset = [0] * len(global_shape) + for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): + global_offset[shard_dim] = _get_first_offset(global_offsets) return tuple(local_shape), tuple(global_offset) diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 726abc5971376..65da0a7b1823b 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -684,12 +684,13 @@ def _to_replicate_tensor( def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: return len(sharded_indices[rank]) - def _local_shard_size_and_offset( + # delete pyre-ignore once separating _StridedShard from Shard + def _local_shard_size_and_offset( # pyre-ignore[bad-override] self, curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, Optional[int]]: + ) -> tuple[int, list[int]]: # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] @@ -707,9 +708,9 @@ def _local_shard_size_and_offset( sharded_indices = [shard.view(-1) for shard in sharded_indices] local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) + offsets = sharded_indices[rank].tolist() - # offsets from _StridedShard is never used - return local_shard_size, None + return local_shard_size, offsets class Replicate(torch._C._distributed.Replicate): From 8f8082d7f077db8fc5d55a601fd8a1f3a4b236f7 Mon Sep 17 00:00:00 2001 From: Joona Havukainen Date: Fri, 21 Nov 2025 19:49:15 +0000 Subject: [PATCH 178/230] Fix memory leak test for SDPA op call (#168040) Fixes #158330 Currently the test fails in the nightly because it keeps committing more work without waiting for previous work to finish and clean-up. So it is not testing the correct thing. Each iteration adds 2MB to the driver allocated memory as it holds on to the needed inputs, outputs and graph definition while the computation is in-flight. Either waiting for about ~0.2s after each op call or calling device to synchronize allows the computation to wrap-up and the underlying memory to get released, displaying the expected result that the memory before and after the op call is the same. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168040 Approved by: https://github.com/kulinseth --- test/test_mps.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_mps.py b/test/test_mps.py index a84ac7d355169..51f2637e4d55e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9667,10 +9667,12 @@ def get_mps_memory_usage(): memory_footprints = [] for _ in range(100): output = F.scaled_dot_product_attention(query, key, value) + # syncronize to wait for the GPU computation to return + torch.mps.synchronize() current_mem, driver_mem = get_mps_memory_usage() memory_footprints.append((current_mem, driver_mem)) - # 5 MB different maximum allowed value(could be decreased even more) - torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1) + # 1 kB different maximum allowed value + torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=1e-3, rtol=1e-3) def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype): if layout == "contiguous": From 739acb8a7ee27f23b99c9b12345766ac89e24acc Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 17 Nov 2025 17:56:06 -0800 Subject: [PATCH 179/230] [dynamo, nested graph breaks] fix FOR_ITER iterator push and zip strict; enable nested graph break tests on test_functions.py (#167694) The iterator push in FOR_ITER must be done before attempting to call next_variable because if next_variable graph breaks with nested graph breaks on, the iterator must be on the stack to resume properly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167694 Approved by: https://github.com/StrongerXi --- test/dynamo/test_functions.py | 12 ++++++++---- torch/_dynamo/symbolic_convert.py | 7 ++++--- torch/_dynamo/variables/builtin.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index bac435cebfdfc..840d4b32ab389 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -153,7 +153,7 @@ def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2): return torch.cos(x * y) + c -class FunctionTests(torch._dynamo.test_case.TestCase): +class FunctionTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): @make_test def test_inline_jit_annotations(x): x = inline_script_if_tracing(x) @@ -4221,7 +4221,7 @@ def forward(self): return self.m() -class DefaultsTests(torch._dynamo.test_case.TestCase): +class DefaultsTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): def test_func_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg @@ -4749,7 +4749,7 @@ def fn(x, ys, zs): x = x.clone() for y, z in zip(ys, zs, strict=True): x += y * z - return x + return x, zip(ys, zs) opt_fn = torch.compile(fn, backend="eager") nopython_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -4758,7 +4758,11 @@ def fn(x, ys, zs): ys = [1.0, 2.0, 3.0] zs = [2.0, 5.0, 8.0] - self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs)) + ref = fn(x, ys, zs) + res = opt_fn(x, ys, zs) + self.assertEqual(ref[0], res[0]) + self.assertEqual(list(ref[1]), list(res[1])) + self.assertIsInstance(res[1], zip) # If nopython, should raise UserError with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 18f053a2ca675..dab2393a86259 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2050,22 +2050,23 @@ def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None: def FOR_ITER(self, inst: Instruction) -> None: it = self.pop().realize() + self.push(it) try: val = it.next_variable(self) - self.push(it) self.push(val) except (StopIteration, exc.ObservedUserStopIteration) as e: if isinstance(e, exc.ObservedUserStopIteration): exc.handle_observed_exception(self) - # leave iterator upon exhaustion in 3.12 if sys.version_info >= (3, 12): # CPython 3.12 actually jumps to the instruction after the END_FOR # and performs the action of END_FOR as part of FOR_ITER. We jump # to the END_FOR and run it, so we need to make sure 2 values are # on the stack for it to pop. - self.push(it) self.push(ConstantVariable.create(None)) + else: + # pop the iterator in Python < 3.12 + self.pop() self.jump(inst) def _create_exception_type(self, val: VariableTracker) -> VariableTracker: diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 746db0f3dfd62..ae6678628634a 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2280,11 +2280,11 @@ def call_zip( "1 kwargs (`strict`)", f"{len(kwargs)} kwargs", ) - strict = kwargs.pop("strict", False) + strict = kwargs.pop("strict", ConstantVariable.create(False)) iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args] return variables.ZipVariable( iter_args, - strict=strict, # type: ignore[arg-type] + strict=strict.as_python_constant(), mutation_type=ValueMutationNew(), ) From 044143a1be4f75d1bac01db369d8f75e817e7420 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 20 Nov 2025 17:23:20 +0000 Subject: [PATCH 180/230] Fix `hash(Size([SymInt, ...]))` on Python 3.14+ (#168256) Fixes PyTorch issue #168254 Python 3.14 introduced an optimization to tuple hashing by adding hash caching in the tuple structure. In older versions, Python would recompute the hash on every call. Now, the computed hash is stored in the tuple's `ob_hash` field and reused on subsequent calls. `torch.Size(...)` has this field set to `0` in some scenarios and causes the cache to behave incorrectly. To fix this, we reset the cache right before delegating the call to CPython `tuple_hash`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168256 Approved by: https://github.com/williamwen42 --- test/test_dynamic_shapes.py | 7 +++++++ torch/csrc/Size.cpp | 28 +++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 41ce5af6a28be..5c721395bf9cd 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -511,6 +511,13 @@ def test_meta_symint(self): r = torch.empty(a0, device="meta") self.assertIsInstance(r.shape[0], SymInt) + def test_hash_size(self): + # See issue #168254 + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 2) + r = torch.empty(a0, device="meta") + self.assertRaises(TypeError, lambda: hash(r.shape)) + def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 62bc48fa9b983..ea39424cf8ea7 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -219,6 +219,23 @@ static PySequenceMethods THPSize_as_sequence = { nullptr /* sq_contains */ }; +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 +static Py_hash_t THPSize_hash(PyObject* self) { + /* + Python 3.14 introduce a caching mechanism for tuple hashing which is stored + in the `ob_hash` field. The caching mechanism relies on a sentinel value (-1) + to indicate the hash has not yet been computed. + For some unknown reason, this field is initialized with 0 when Size is + created, which causes the caching logic to behave incorrectly. + */ + PyTupleObject* v = _PyTuple_CAST(self); + // reset ob_hash and force hash to be recomputed + Py_hash_t sentinel = -1; + v->ob_hash = sentinel; + return PyTuple_Type.tp_hash(self); +} +#endif + static PyMappingMethods THPSize_as_mapping = { nullptr, /* mp_length */ wrap_tuple_fn, @@ -284,7 +301,11 @@ PyTypeObject THPSizeType = { &THPSize_as_number, /* tp_as_number */ &THPSize_as_sequence, /* tp_as_sequence */ &THPSize_as_mapping, /* tp_as_mapping */ - nullptr, /* tp_hash */ +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 + &THPSize_hash, /* tp_hash */ +#else + nullptr, /* tp_hash */ +#endif nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ @@ -294,7 +315,12 @@ PyTypeObject THPSizeType = { nullptr, /* tp_doc */ nullptr, /* tp_traverse */ nullptr, /* tp_clear */ +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 + // if tp_hash is defined, one must also defines tp_richcompare + PyTuple_Type.tp_richcompare, /* tp_richcompare */ +#else nullptr, /* tp_richcompare */ +#endif 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ From 82e9ae95498b673dced0b51c871808e35895f4e2 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Fri, 21 Nov 2025 20:16:38 +0000 Subject: [PATCH 181/230] Forward fix numpy binary check after #168270 (#168374) Looks like https://pypi.org/project/numpy/1.23.5/#files is available for Python 3.11 hence use this version instead of 1.21.2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168374 Approved by: https://github.com/seemethere --- .circleci/scripts/binary_linux_test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 58d0af29e133b..3771ecc108f87 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -53,9 +53,9 @@ if [[ "$PACKAGE_TYPE" != libtorch ]]; then # numpy tests: # We test 1 version no numpy. 1 version with numpy 1.x and rest with numpy 2.x if [[ "\$python_nodot" = *311* ]]; then - retry pip install -q protobuf typing-extensions + retry pip install -q numpy==1.23.5 protobuf typing-extensions elif [[ "\$python_nodot" = *312* ]]; then - retry pip install -q numpy==1.21.2 protobuf typing-extensions + retry pip install -q protobuf typing-extensions else retry pip install -q numpy protobuf typing-extensions fi From b8e682385a9ba62b45ee6924b3d4054b60d4030f Mon Sep 17 00:00:00 2001 From: Chris Leonard Date: Fri, 21 Nov 2025 20:30:21 +0000 Subject: [PATCH 182/230] Fix arg parser one pos arg (#163081) Fixes #130609 tensor.rehape() was running with extra arguments whenever the first argument was a tuple. This is because methods with one sequence arguments can be passed in as a Sequence (tuple, list, etc.) or variable argument (*args). But when multiple arguments where passed in and the first argument was a sequence, it would run and just ignore the other arguments. For example, if someone wrote ``` x = torch.ones((4, 3)) y = x.reshape((2, 6), torch.float32) ``` the process would run without any errors but the 'torch.float32' wouldn't do anything. This could be problematic if the developer thinks they are changing the the tensor dtype (or anything else they think they are doing) but really nothing is happening other than reshaping (silent bugs). This PR fixes this issue for any method that has one integer sequence argument, not just reshape() (including tile() and view()) **Warning:** this could break downstream code if someone has already implemented something like this. However, they could also have a silent bug in their code that goes unnoticed because no errors are thrown so it is still best to fix this issue. @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/163081 Approved by: https://github.com/isuruf, https://github.com/malfet, https://github.com/albanD --- test/test_torch.py | 3 +++ test/test_view_ops.py | 10 ++++++++++ torch/csrc/utils/python_arg_parser.cpp | 4 +++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/test_torch.py b/test/test_torch.py index 9b9cc2cfc58f9..66b2002a36d1e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7464,6 +7464,9 @@ def test_parsing_intlist(self): "missing 1 required positional arguments", lambda: torch.tensor().new_zeros((5, 5), 0)) + # ensure ones() throws an error when extra positional (non-keyword) arguments are given. + self.assertRaises(TypeError, lambda: torch.ones((3, 3), torch.float32)) + def test_from_buffer(self): a = bytearray([1, 2, 3, 4]) self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 980439b7a6967..2a127241c49bd 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1183,6 +1183,8 @@ def test_reshape(self, device): self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + # ensure reshape() throws an error if extra positional arguments are given. + self.assertRaises(TypeError, lambda: x.reshape((9,), torch.float32)) y = torch.randn(4, 4, 4, device=device)[:, 0, :] # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape @@ -1726,6 +1728,9 @@ def can_broadcast(s0, s1): r"must match the existing size \(\d\)", ): torch.broadcast_to(t, s1) + # ensure broadcast_to() throws an error when extra positional arguments are given. + t = torch.tensor([1, 2, 3]) + self.assertRaises(TypeError, lambda: t.broadcast_to((3, 3), torch.float32)) def test_view(self, device): tensor = torch.rand(15, device=device) @@ -1812,6 +1817,11 @@ def test_view(self, device): self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) + # ensure view() throws an error if extra positional arguments are given. + self.assertRaises( + TypeError, lambda: tensor.view((tensor.numel(),), torch.float32) + ) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_reshape_view_semantics(self, device, dtype): tensor = make_tensor((15, 4), dtype=dtype, device=device) diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index e89f7887320a0..5fa0986cc814d 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -1687,7 +1687,9 @@ bool FunctionSignature::parse( if (max_pos_args == 1 && (params[0].type_ == ParameterType::INT_LIST || params[0].type_ == ParameterType::SYM_INT_LIST)) { - allow_varargs_intlist = true; + int64_t failed_idx = -1; + allow_varargs_intlist = is_int_or_symint_list( + args, params[0].size, &failed_idx, &overloaded_args); } if (static_cast(nargs) > max_pos_args && !allow_varargs_intlist) { From 9141f03c32ac954298a70dc1a9b7eeacb46971a9 Mon Sep 17 00:00:00 2001 From: jmaczan Date: Fri, 21 Nov 2025 20:37:18 +0000 Subject: [PATCH 183/230] 20x less memory use and 37.25% speedup in min_cut_rematerialization_partition when using the new dp knapsack solver, compared to existing default one (dp) (#160914) The goal: **Reduce memory usage** in [min_cut_rematerialization_partition](https://github.com/pytorch/pytorch/blob/1f1900369435933a013df9a7d5e07c75c1cebb5d/torch/_functorch/partitioners.py#L2601) with new knapsack implementation Rationale: The @Chillee's comment in original dp_knapsack suggests improving the code with [Hirschberg algorithm and sliding window](https://codeforces.com/blog/entry/47247?#comment-316200). In this PR, I create a new knapsack implementation which is based on the original implementation and uses Hirschberg + sliding window Existing `dp_knapsack` implementation instantiates full dp table of shape (n, W), where n is number of memory elements and W is quantized memory length. The new `dp_knapsack_sliding_hirschberg` uses only (2, W) memory - it needs only 2 dp profiles (1-dim tensors) to compute the same output as original `dp_knapsack` This optimization is possible, because at each step we use only current and previous row ("dp profile"). We split the indices of items (each item is a pair of memory and runtime) into two, compute dp profile for each half, then compute the index of split at which the sum of runtimes from both halfs is the highest, then we use this split index to decide how much budget we give to left half and right half and we recurse on left half and right half with new memory budgets Based on benchmarks, consider if we should keep two dp knapsack implementations or one, which one should be default, do we want to make it easier to use the new one etc. In general, think about the next steps Thanks in advance for all comments Pull Request resolved: https://github.com/pytorch/pytorch/pull/160914 Approved by: https://github.com/ezyang Co-authored-by: Edward Yang --- test/functorch/test_ac_knapsack.py | 81 ++++++++++ .../_activation_checkpointing/knapsack.py | 146 ++++++++++++++++++ torch/_functorch/config.py | 5 +- torch/_functorch/partitioners.py | 3 + 4 files changed, 233 insertions(+), 2 deletions(-) diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index 751a4c4d21859..2d2899e9ca297 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -2,6 +2,10 @@ from torch._functorch._activation_checkpointing.graph_info_provider import ( GraphInfoProvider, ) +from torch._functorch._activation_checkpointing.knapsack import ( + dp_knapsack, + dp_knapsack_sliding_hirschberg, +) from torch._functorch._activation_checkpointing.knapsack_evaluator import ( KnapsackEvaluator, ) @@ -326,5 +330,82 @@ def test_get_backward_memory_from_topologically_sorted_graph(self): self.assertEqual(result_item[1], expected_result_item[1]) +class TestActivationCheckpointingKnapsack(TestCase): + def setUp(self): + # (memory, runtime, max_memory, expected_runtime, expected_saved, expected_recomputable) + self.test_cases = [ + ([2, 3, 2, 4, 1], [1, 2, 1, 3, 2], 5, 5.0, [3, 4], [2, 1, 0]), + ([1, 1, 1], [1, 2, 3], 3, 6.0, [0, 1, 2], []), + ([10, 20, 30], [1, 2, 3], 5, 0.0, [], [2, 1, 0]), + ([1, 2, 3], [10, 20, 30], 1, 10.0, [0], [2, 1]), + ([1, 1, 1], [2, 2, 2], 2, 4.0, [0, 1], [2]), + ([0, 2, 3], [5, 2, 3], 5, 10.0, [0, 1, 2], []), + ([1, 2, 3], [0, 2, 3], 3, 3.0, [2], [0, 1]), + ([100, 200, 300], [1000, 2000, 3000], 500, 5000.0, [1, 2], [0]), + ([0.5, 1.5, 2.0], [1.0, 2.0, 3.0], 2.0, 3.0, [1, 0], [2]), + ([], [], 10, 0.0, [], []), + ([1, 2, 3], [1, 2, 3], 0, 0.0, [], [2, 1, 0]), + ([0, 0, 0], [1, 2, 3], 0, 6.0, [0, 1, 2], []), + ([1, 2, 3], [0, 0, 0], 6, 0.0, [], [2, 1, 0]), + ] + + def _run_knapsack_and_check( + self, + func, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ): + result_runtime, result_saved, result_recomputable = func( + memory, runtime, max_memory + ) + self.assertEqual(result_runtime, expected_runtime) + self.assertEqual(sorted(result_saved), sorted(expected_saved)) + self.assertEqual(sorted(result_recomputable), sorted(expected_recomputable)) + + def test_dp_knapsack(self): + for i, ( + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) in enumerate(self.test_cases): + with self.subTest(f"dp_knapsack_case_{i}"): + self._run_knapsack_and_check( + dp_knapsack, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) + + def test_dp_knapsack_sliding_hirschberg(self): + for i, ( + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) in enumerate(self.test_cases): + with self.subTest(f"dp_knapsack_sliding_hirschberg_case_{i}"): + self._run_knapsack_and_check( + dp_knapsack_sliding_hirschberg, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) + + if __name__ == "__main__": run_tests() diff --git a/torch/_functorch/_activation_checkpointing/knapsack.py b/torch/_functorch/_activation_checkpointing/knapsack.py index 0a3eaa5a9344c..b2f0a124c64c1 100644 --- a/torch/_functorch/_activation_checkpointing/knapsack.py +++ b/torch/_functorch/_activation_checkpointing/knapsack.py @@ -119,3 +119,149 @@ def dp_knapsack( max_runtime = dp[n][quantized_max_memory].item() return max_runtime, saved_items, recomputable_items + + +def dp_knapsack_sliding_hirschberg( + memory: list[float], runtime: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # q_ prefix stands for quantized + q_memory = [int(round(m * S)) for m in memory] + runtimes = [float(v) for v in runtime] + + q_max_memory = int(round(max_memory * S)) + + q_memory_length = len(q_memory) + if q_memory_length == 0: + return 0.0, [], [] + + item_indices = list(range(q_memory_length)) + dp_profile_size = q_max_memory + 1 + + # Current DP profile (row) + dp_profile = torch.zeros(dp_profile_size, dtype=torch.float32, device="cpu") + # Store a candidate for next dp_profile - current dp row + item + candidate_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + left_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + right_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + + saved_items: list[int] = [] + recomputable_items: list[int] = [] + + # Explicit stack to optimize memory and avoid recursion + # Stack stores segments as (start index, end index, capacity for segment) + stack: list[tuple[int, int, int]] = [(0, q_memory_length, q_max_memory)] + + # LIFO + while stack: + start, end, capacity = stack.pop() + length = end - start + if length == 0: + continue + + # Leaf + if length == 1: + index = item_indices[start] + memory_item = q_memory[index] + runtime_item = runtimes[index] + if memory_item <= capacity and runtime_item > 0.0: + saved_items.append(index) + else: + recomputable_items.append(index) + continue + + # Split the segment into two halves + middle = start + (length // 2) + left_start, left_end = middle, end + right_start, right_end = start, middle + + # Assign items to both halves + left_items = item_indices[left_start:left_end] + right_items = item_indices[right_start:right_end] + + # Working only on items allowed by segment's capacity + capacity = capacity + 1 + dp_view = dp_profile[:capacity] + candidate_view = candidate_profile[:capacity] + left_dp_local = left_profile[:capacity] + right_dp_local = right_profile[:capacity] + + # Left part + dp_view.zero_() + for index in left_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + # Weight is 0, so add it to all capacities; a "free lunch", essentially + dp_view.add_(runtime_item) + continue + + # If item is too heavy, we skip it + if memory_item >= capacity: + continue + + # Add the current item so we can then pick the highest value + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + # Take the highest - either previous (without current) or with current + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the left profile + left_dp_local.copy_(dp_view) + + # Right part + dp_view.zero_() + for index in right_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + dp_view.add_(runtime_item) + continue + + if memory_item >= capacity: + continue + + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the reversed right profile + right_dp_local.copy_(dp_view.flip(-1)) + + # In-place compute item-wise sum of left and right to pick the split point where the sum is highest + left_dp_local.add_(right_dp_local) + + # Pick the index of highest value of a pair, which we then use as a split point + best_split = int(torch.argmax(left_dp_local).item()) + + left_capacity = best_split + right_capacity = capacity - best_split + + # Clamp (might be removed if we're 100% sure that there is no edge case that will mess up the indices math) + if left_capacity < 0: + left_capacity = 0 + if right_capacity < 0: + right_capacity = 0 + if left_capacity > q_max_memory: + left_capacity = q_max_memory + if right_capacity > q_max_memory: + right_capacity = q_max_memory + + # Push right then left, so left is processed next + stack.append((right_start, right_end, right_capacity)) + stack.append((left_start, left_end, left_capacity)) + + saved_items = sorted(saved_items) + recomputable_items = sorted(recomputable_items) + + max_runtime = sum(runtime[i] for i in saved_items) + recomputable_items.reverse() + return max_runtime, saved_items, recomputable_items diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 790cf71a83a23..42d6f308f831a 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -162,8 +162,9 @@ def remote_autograd_cache_default() -> Optional[bool]: activation_memory_budget_runtime_estimator = "flops" # This controls the solver used for the 0-1 knapsack. By default we use a -# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" -# (which has a scipy dependency). +# quantized DP solution ("dp"). The other approaches are a "greedy", a "ilp" +# (which has a scipy dependency) and "dp_knapsack_sliding_hirschberg", which +# used memory-efficient quantized DP solution activation_memory_budget_solver = "dp" # This dumps out a SVG visualization of the expected runtime vs. activation diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index e7c665d8df9d1..3b79a50ff9e21 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -47,6 +47,7 @@ from ._activation_checkpointing.graph_info_provider import GraphInfoProvider from ._activation_checkpointing.knapsack import ( dp_knapsack, + dp_knapsack_sliding_hirschberg, greedy_knapsack, ilp_knapsack, ) @@ -2365,6 +2366,8 @@ def _optimize_runtime_with_given_memory( return ilp_knapsack(memory, runtimes, max_memory) elif SOLVER == "dp": return dp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp_knapsack_sliding_hirschberg": + return dp_knapsack_sliding_hirschberg(memory, runtimes, max_memory) elif SOLVER == "dynamic_memory_budget_dp": log.warning( "dynamic_memory_budget_dp is an experimental solver. " From a2d11eb06972cbfa3fb0d23595e89d7585618e57 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 21 Nov 2025 02:32:48 +0000 Subject: [PATCH 184/230] [FlexFlash] Add wiring for backwards (#168319) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168319 Approved by: https://github.com/v0i0 --- test/inductor/test_flex_flash.py | 123 +++++++++++ torch/_inductor/kernel/flex/flex_attention.py | 13 +- .../kernel/flex/flex_flash_attention.py | 204 ++++++++++++++++-- .../flash_attention_backward.py.jinja | 28 +++ 4 files changed, 353 insertions(+), 15 deletions(-) create mode 100644 torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index 0c877ff33b5e2..50f12291a0e83 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -193,6 +193,36 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}" ) + needs_backward = any( + isinstance(t, torch.Tensor) and t.requires_grad for t in (q, k, v) + ) + if needs_backward: + grad = torch.randn_like(out_flash) + inputs = (q, k, v) + grads_ref = torch.autograd.grad(out_ref_fp32, inputs, grad) + grads_triton = torch.autograd.grad(out_triton, inputs, grad) + grads_flash = torch.autograd.grad(out_flash, inputs, grad) + + dq_atol = 2 * (grads_ref[0] + 0.3 - 0.3 - grads_ref[0]).abs().max().item() + dk_atol = 2 * (grads_ref[1] + 0.3 - 0.3 - grads_ref[1]).abs().max().item() + dv_atol = 2 * (grads_ref[2] + 0.3 - 0.3 - grads_ref[2]).abs().max().item() + + atol_pack = (dq_atol, dk_atol, dv_atol) + for grad_flash, grad_triton, grad_ref, atol in zip( + grads_flash, grads_triton, grads_ref, atol_pack + ): + assert torch.isfinite(grad_flash).all() + assert torch.isfinite(grad_triton).all() + assert torch.isfinite(grad_ref).all() + + triton_error = (grad_triton - grad_ref).abs().max().item() + flash_error = ( + (grad_flash - grad_ref.to(grad_flash.dtype)).abs().max().item() + ) + assert flash_error <= rtol * triton_error + atol, ( + f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {atol:.2e}" + ) + return out_flash, out_triton, out_ref_fp32 @@ -330,6 +360,99 @@ def score_mod_with_grad(score, b, h, q_idx, kv_idx): kernel_options={"BACKEND": "FLASH"}, ) + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_rejects_mask_mod(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) + q.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support block_sparsity yet", + ): + compiled_fn( + q, k, v, block_mask=block_mask, kernel_options={"BACKEND": "FLASH"} + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_rejects_score_mod_capture(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + bias = torch.randn(4, device=device, dtype=dtype) + + def score_mod_with_capture(score, b, h, q_idx, kv_idx): + return score + bias[h] + + q.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support score_mods in bwds yet", + ): + compiled_fn( + q, + k, + v, + score_mod=score_mod_with_capture, + kernel_options={"BACKEND": "FLASH"}, + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_rejects_score_mod(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def score_mod_twice(score, b, h, q_idx, kv_idx): + return score * 2 + + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support score_mods in bwds yet", + ): + compiled_fn( + q, + k, + v, + score_mod=score_mod_twice, + kernel_options={"BACKEND": "FLASH"}, + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_kernel_called(self, device, dtype): + q, k, v = create_test_tensors(dim=128, dtype=dtype, device=device) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + flash_vs_triton(q, k, v) + + compiled_fn = torch.compile(flex_attention) + + def run_for_profile(): + q_run, k_run, v_run = ( + t.detach().clone().requires_grad_(True) for t in (q, k, v) + ) + compiled_fn( + q_run, k_run, v_run, kernel_options={"BACKEND": "FLASH"} + ).sum().backward() + + with cuda_kernel_profiler("flash_attncuteflash_bwd") as prof_result: + run_for_profile() + + self.assertTrue( + prof_result["found"], + f"Flash attention backward kernel not found. Kernels: {prof_result['kernel_names']}", + ) + @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_with_block_mask(self, device, dtype): """Test flash attention with block mask and mask_mod.""" diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index c555f66dbf538..d36b8d56cc711 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -39,6 +39,8 @@ from .flex_decoding import _use_flex_decoding, create_flex_decoding_kernel from .flex_flash_attention import ( _use_flex_flash_attention, + _use_flex_flash_attention_backward, + create_flex_flash_attention_backward_kernel, create_flex_flash_attention_kernel, ) @@ -660,7 +662,7 @@ def flex_attention_backward(*args, **kwargs): f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" ) - kernel_options, _ = _sanitize_kernel_options_for_triton(kernel_options) + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v @@ -723,6 +725,15 @@ def flex_attention_backward(*args, **kwargs): ) freeze_irnodes(mask_graph_buffer) + if _use_flex_flash_attention_backward( + fw_graph, + mask_graph, + backend=backend, + ): + return create_flex_flash_attention_backward_kernel( + query, key, value, out, logsumexp, grad_out, scale, kernel_options + ) + # Construct layout with stride order matching K key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] key_strides = infer_dense_strides(key_size, key.get_stride()) diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 78a79f9664b68..05d1290f0ab49 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -42,6 +42,10 @@ def ensure_flash_available() -> bool: flash_attention_cutedsl_template = CuteDSLTemplate( name="flash_attention_cutedsl", source=load_flex_template("flash_attention") ) +flash_attention_backward_cutedsl_template = CuteDSLTemplate( + name="flash_attention_backward_cutedsl", + source=load_flex_template("flash_attention_backward"), +) def _fixed_indexer_cute( @@ -101,6 +105,28 @@ def cutedsl_make_indexer(self): FixedLayout.make_indexer = original_make_indexer # type: ignore[assignment] +def wrap_choice_render_with_cutedsl_indexer(choice: Any) -> None: + """ + Wrap a template choice's kernel render to apply CuteDSL indexer patching. + + See Note [CuteDSL indexer patch]: + This wrapper allows the template to construct its closures normally, then + scopes the indexer patch to the actual render call that emits the kernel. + This ensures CuteDSL templates see colexicographic indexing while preserving + the template's setup logic. + """ + original_make_kernel_render = choice.make_kernel_render + + def make_kernel_render_with_patch(*args, **kwargs): + render_kernel, render = original_make_kernel_render(*args, **kwargs) + # Let the template construct its closures, then scope the indexer patch + # to the actual render call that emits the kernel + render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) + return render_kernel, render_with_patch + + choice.make_kernel_render = make_kernel_render_with_patch + + def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int): """Check if any of the input buffers (beyond the score mod placeholders) require gradients.""" inputs = [] @@ -117,6 +143,18 @@ def requires_grad(n): return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]) +def is_trivial_score_graph(graph_module: GraphModule) -> bool: + """Backwards currently doesn't support score_mods, match against identity""" + graph = graph_module.graph + nodes = list(graph.nodes) + placeholders = [n for n in nodes if n.op == "placeholder"] + output = [n for n in nodes if n.op == "output"] + assert len(output) == 1, "Got graph w/ multiple outputs" + output_val = output[0].args[0] + # The identity graph just sends the score straight through + return output_val == placeholders[0] + + def is_trivial_mask_graph(graph_module: GraphModule) -> bool: """Mask graph is trivial when it only gates via the default full op.""" graph = graph_module.graph @@ -287,29 +325,167 @@ def create_flex_flash_attention_kernel( NEEDS_BLOCK_MASK=needs_block_mask, ) - def wrap_choice_render(choice): - # See Note [CuteDSL indexer patch] - original_make_kernel_render = choice.make_kernel_render + for choice in choices: + wrap_choice_render_with_cutedsl_indexer(choice) + + if error or not choices: + # Fallback to original implementation + raise RuntimeError(f"CuteDSL template failed: {error}") + + # No autotune for now + template_output = choices[0].output_node() + + return (template_output, lse) + + +def _can_use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, +) -> tuple[bool, str]: + if not ensure_flash_available(): + return False, "CUTE flash attention is not available" + + if not is_trivial_score_graph(fw_subgraph.graph_module): + return ( + False, + "NYI: Flex Flash Attention doesn't support score_mods in bwds yet.", + ) + + if not is_trivial_mask_graph(mask_graph.graph_module): + return False, "NYI: Flex Flash Attention doesn't support block_sparsity yet." + + return True, "" + + +def _use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], +) -> bool: + """Determine if we should use flex flash attention for the given inputs. + + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False + + can_use, reason = _can_use_flex_flash_attention_backward( + fw_subgraph, + mask_graph, + ) + + if not can_use: + raise RuntimeError( + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" + ) + + return True + + +def create_flex_flash_attention_backward_kernel( + query: TensorBox, + key: TensorBox, + value: TensorBox, + out: TensorBox, + logsumexp: TensorBox, + grad_out: TensorBox, + scale: float, + kernel_options: dict[str, Any], + # TODO: will be needed + # grad_logsumexp, + # fw_graph: SubgraphResults, + # joint_graph: SubgraphResults, + # mask_graph: SubgraphResults, + # score_mod_other_buffers: list[TensorBox], + # mask_mod_other_buffers: list[TensorBox], + # kv_num_blocks: TensorBox | None, + # kv_indices: TensorBox | None, + # full_kv_num_blocks: TensorBox | None, + # full_kv_indices: TensorBox | None, +) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox, TensorBox, tuple]: + """Create a CuteDSL flash attention backward kernel for the default mod path.""" + if not ensure_flash_available(): + raise RuntimeError("CUTE flash attention not available") + + batch_size, num_heads, seq_len_q, head_dim = query.get_size() + v_head_dim = value.get_size()[-1] + device = query.get_device() + dtype = query.get_dtype() + assert device is not None + + grad_query_strides = infer_dense_strides( + [batch_size, num_heads, seq_len_q, head_dim], query.get_stride() + ) + grad_query = empty_strided( + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=grad_query_strides, + dtype=dtype, + device=device, + ) + + grad_key_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], head_dim], key.get_stride() + ) + grad_key = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], head_dim], + stride=grad_key_strides, + dtype=dtype, + device=device, + ) - def make_kernel_render_with_patch(*args, **kwargs): - render_kernel, render = original_make_kernel_render(*args, **kwargs) + grad_value_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], v_head_dim], value.get_stride() + ) + grad_value = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], v_head_dim], + stride=grad_value_strides, + dtype=dtype, + device=device, + ) - # Let the template construct its closures, then scope the indexer patch - # to the actual render call that emits the kernel - render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) + output_layout = FixedLayout( + device=device, + dtype=dtype, + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=[sympy.sympify(s) for s in grad_query.get_stride()], + ) - return render_kernel, render_with_patch + choices: list[Any] = [] - choice.make_kernel_render = make_kernel_render_with_patch + input_nodes = [ + query, + key, + value, + out, + grad_out, + logsumexp, + grad_key, + grad_value, + ] + + error = flash_attention_backward_cutedsl_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=output_layout, + mutated_inputs=[grad_key, grad_value], + SM_SCALE=scale, + ) for choice in choices: - wrap_choice_render(choice) + wrap_choice_render_with_cutedsl_indexer(choice) if error or not choices: - # Fallback to original implementation raise RuntimeError(f"CuteDSL template failed: {error}") - # No autotune for now template_output = choices[0].output_node() - return (template_output, lse) + return (template_output, grad_key, grad_value, tuple()) diff --git a/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja b/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja new file mode 100644 index 0000000000000..2831ba6af5b60 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja @@ -0,0 +1,28 @@ +{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV")}} + from flash_attn.cute.interface import _flash_attn_bwd + + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + out_transposed = OUT.transpose(1, 2) + d_out_transposed = D_OUT.transpose(1, 2) + + dq_transposed, dk_transposed, dv_transposed = _flash_attn_bwd( + q_transposed, + k_transposed, + v_transposed, + out_transposed, + d_out_transposed, + LSE, + softmax_scale={{SM_SCALE}}, + ) + + dq = dq_transposed.transpose(1, 2) + dk = dk_transposed.transpose(1, 2) + dv = dv_transposed.transpose(1, 2) + + dq_out = {{get_output()}} + {# TODO: add out support to flash #} + dq_out.copy_(dq) + DK.copy_(dk) + DV.copy_(dv) From 81bfd503d935a633646b793077bf60c2ef7f9b73 Mon Sep 17 00:00:00 2001 From: Jiannan Wang Date: Fri, 21 Nov 2025 21:12:15 +0000 Subject: [PATCH 185/230] Add warning for clearing profiler events at the end of each cycle (#168066) Fixes #148314. This PR introduces a warning to clarify the behavior of the Profiler regarding event management. Specifically, it informs users that: - The Profiler clears events at the end of each cycle. - Only events from the current cycle will be reported by default. - To retain events across cycles, users should set `acc_events=True`. The warning is triggered only once when `self.profiler` is not `None` and `self.acc_events` is `False`, using the `warn_once` function. This change aims to improve transparency and help users avoid confusion when analyzing profiling results. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168066 Approved by: https://github.com/sraikund16 --- torch/profiler/profiler.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 056a5fcc21fdd..c52bd0f9ce2bb 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -38,6 +38,14 @@ ] PROFILER_STEP_NAME = "ProfilerStep" +_WARNINGS_SHOWN = set() + + +def _warn_once(msg, category=UserWarning, stacklevel=2): + if msg not in _WARNINGS_SHOWN: + _WARNINGS_SHOWN.add(msg) + warn(msg, category=category, stacklevel=stacklevel) + class _NumpyEncoder(json.JSONEncoder): """ @@ -205,6 +213,12 @@ def prepare_trace(self) -> None: acc_events=self.acc_events, custom_trace_id_callback=self.custom_trace_id_callback, ) + if (self.profiler is not None) and (not self.acc_events): + _warn_once( + "Warning: Profiler clears events at the end of each cycle." + "Only events from the current cycle will be reported." + "To keep events across cycles, set acc_events=True." + ) self.profiler._prepare_trace() def start_trace(self) -> None: From b1cd563cf8d63924325f9470406fdc15d35c98e0 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Fri, 21 Nov 2025 21:38:21 +0000 Subject: [PATCH 186/230] Revert #154859 (#168297) We suspect it's causing intermittent segfaults Pull Request resolved: https://github.com/pytorch/pytorch/pull/168297 Approved by: https://github.com/malfet --- test/profiler/test_execution_trace.py | 7 +++ .../standalone/execution_trace_observer.cpp | 59 +------------------ 2 files changed, 10 insertions(+), 56 deletions(-) diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index dbd5d89ad6a61..26c0ab42905de 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -2,6 +2,7 @@ import json import os +import sys import tempfile import unittest from typing import Any @@ -364,6 +365,9 @@ def test_execution_trace_env_disabled(self, device): self.assertTrue(p.execution_trace_observer is None) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", @@ -419,6 +423,9 @@ def fn(a, b, c): assert found_call_compiled_fx_graph @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 29b2b94af4472..b46e1d19bcd0e 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -112,59 +112,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT std::map> opStack; // Uses the underlying TensorImpl object pointer as the key and map to its // unique id. - std::map objectId; - - using weak_storage_ptr = c10::weak_intrusive_ptr; - std::unordered_map data_ptr_to_storage_id; - std::unordered_map - data_ptr_to_weak_storage_ptr; - - ID get_tensor_storage_ID(const c10::Storage& t_storage) { - const std::lock_guard lock(gMutex); - - const void* raw_data_ptr = nullptr; - bool should_track_liveness = false; - // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use - // a nullptr data pointer. Treat both cases as a shared cache key but avoid - // touching the weak-ref table so they can reuse the same ID without - // tripping the liveness check. - if (t_storage.unsafeGetStorageImpl()) { - raw_data_ptr = t_storage.data(); - should_track_liveness = raw_data_ptr != nullptr; - } - - auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - if (!should_track_liveness) { - if (id_iter != data_ptr_to_storage_id.end()) { - return id_iter->second; - } - ID id = storage_id_++; - data_ptr_to_storage_id.emplace(raw_data_ptr, id); - return id; - } - - auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); - if (weak_iter == data_ptr_to_weak_storage_ptr.end()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.emplace( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - if (weak_iter->second.expired()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.insert_or_assign( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end()); - return id_iter->second; - } + std::map objectId{}; // Observer run state. enum class RunState { uninitialized, disabled, enabled }; @@ -227,8 +176,6 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // 1 -> root ID // 2 ... -> regular node ID std::atomic id_{2}; - - std::atomic storage_id_{1}; }; // Using a singleton manager here to allow init and delete the observer object. @@ -499,8 +446,8 @@ convertIValue( // symbolic sizes/strides implies t->storage_offset() will fail if (tensor_impl->has_storage() && !tensor_impl->has_symbolic_sizes_strides()) { - const c10::Storage& t_storage = tensor_impl->storage(); - storage_id = ob.get_tensor_storage_ID(t_storage); + auto& t_storage = tensor_impl->storage(); + storage_id = getObjectID(ob, t_storage.data()); offset = tensor_impl->storage_offset(); numel = tensor_impl->numel(); itemsize = tensor_impl->itemsize(); From c8b265fba14432e3371756318314b96b96943ad3 Mon Sep 17 00:00:00 2001 From: Parshant Sharma Date: Fri, 21 Nov 2025 22:31:49 +0000 Subject: [PATCH 187/230] [dynamo, nested graph breaks] Fix-nested-graph-break-suppression (#167743) Fixes #167393 ### Summary: - Modified graph break deduplication to use the full chain of frame locations when nested_graph_breaks=True. ### Impact: - module: dynamo Pull Request resolved: https://github.com/pytorch/pytorch/pull/167743 Approved by: https://github.com/williamwen42 --- test/dynamo/test_error_messages.py | 145 +++++++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 44 ++++++++- 2 files changed, 186 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index c706e5f7af025..49f787bd25cd6 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1811,6 +1811,151 @@ def f3(x): """, ) + @make_logging_test(graph_breaks=True) + def test_try_block_with_graph_break_suppression(self, records): + global inner, middle_with_try, outer + + def inner(x): + result = x + 1 + torch._dynamo.graph_break() + return result + 1 + + def middle_with_try(x): + try: + return inner(x) + except Exception: + pass + return x + + def outer(x): + return middle_with_try(x) + + with torch._dynamo.config.patch(nested_graph_breaks=True, verbose=False): + torch.compile(outer, backend="eager")(torch.ones(3)) + + full_messages = [ + r for r in records if "Graph break in user code" in r.getMessage() + ] + suppressed_messages = [ + r + for r in records + if "user stack suppressed due to duplicate" in r.getMessage() + ] + + self.assertEqual( + len(full_messages), + 1, + f"Expected 1 full graph break message, got {len(full_messages)}", + ) + self.assertEqual( + len(suppressed_messages), + 1, + f"Expected at least 1 suppressed message, got {len(suppressed_messages)}", + ) + + self.assertExpectedInline( + munge_exc(full_messages[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_try_block_with_graph_break_suppression + torch.compile(outer, backend="eager")(torch.ones(3)) + File "test_error_messages.py", line N, in outer + return middle_with_try(x) + File "test_error_messages.py", line N, in middle_with_try + return inner(x) + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + + self.assertExpectedInline( + munge_exc( + suppressed_messages[0].getMessage(), suppress_suffix=True, skip=0 + ), + """\ +Graph break (user stack suppressed due to duplicate graph break) in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html""", + ) + + @make_logging_test(graph_breaks=True) + def test_nested_graph_break_different_call_sites_not_suppressed(self, records): + global inner, outer + + def inner(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + @torch.compile(backend="eager") + def outer(x): + x = inner(x + 4) + 8 + return inner(x) + 16 + + with torch._dynamo.config.patch(nested_graph_breaks=True, verbose=False): + outer(torch.ones(3)) + + self.assertEqual( + len(records), + 2, + f"Expected 2 graph break messages (one per call site), got {len(records)}", + ) + + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_nested_graph_break_different_call_sites_not_suppressed + outer(torch.ones(3)) + File "test_error_messages.py", line N, in outer + x = inner(x + 4) + 8 + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + + self.assertExpectedInline( + munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_nested_graph_break_different_call_sites_not_suppressed + outer(torch.ones(3)) + File "test_error_messages.py", line N, in outer + return inner(x) + 16 + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index dab2393a86259..f401b9d6178b9 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -4167,6 +4167,33 @@ def speculate(self) -> SpeculationEntry: self.instructions[self.instruction_pointer - 1], ) + def _make_frame_loc( + self, filename: str, lineno: Optional[int], fallback_lineno: int + ) -> tuple[str, int]: + if lineno is None or lineno < 0: + return (filename, fallback_lineno) + return (filename, lineno) + + def _get_frame_loc_chain( + self, frame_loc: tuple[str, int] + ) -> tuple[tuple[str, int], ...]: + frame_loc_chain_list: list[tuple[str, int]] = [] + + if config.nested_graph_breaks: + current_tx: Optional[InstructionTranslatorBase] = self.parent + while current_tx is not None: + parent_frame_loc = self._make_frame_loc( + current_tx.f_code.co_filename, + current_tx.lineno, + current_tx.f_code.co_firstlineno, + ) + frame_loc_chain_list.append(parent_frame_loc) + current_tx = current_tx.parent + + frame_loc_chain_list.reverse() + frame_loc_chain_list.append(frame_loc) + return tuple(frame_loc_chain_list) + def log_graph_break( self, code_options: dict[str, Any], @@ -4177,14 +4204,25 @@ def log_graph_break( user_stack = torch._guards.TracingContext.extract_stack() try: - frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + if config.nested_graph_breaks and self.parent is not None: + frame_loc = self._make_frame_loc( + self.f_code.co_filename, + self.lineno, + self.f_code.co_firstlineno, + ) + else: + frame_loc = self._make_frame_loc( + user_stack[-1].filename, + user_stack[-1].lineno, + 0, + ) except IndexError: # first instruction frame_loc = ( code_options["co_filename"], code_options["co_firstlineno"], ) - + frame_loc_chain = self._get_frame_loc_chain(frame_loc) stack_above_dynamo_formatted = "" if config.verbose: stack_above_dynamo = get_stack_above_dynamo() @@ -4229,7 +4267,7 @@ def log_graph_break( if ( graph_break_log.isEnabledFor(logging.DEBUG) and not explain - and graph_break_dup_warning_checker.add(frame_loc) + and graph_break_dup_warning_checker.add(frame_loc_chain) # type: ignore[arg-type] ): # This log line MUST contain the string "Graph break in user code", # This log line is exercised from From 38b5a5ec9cf9b1cc021fdee6f11782e1e46fcc83 Mon Sep 17 00:00:00 2001 From: Kushagra Rastogi Date: Fri, 21 Nov 2025 22:32:56 +0000 Subject: [PATCH 188/230] Narrow the return type annotation in 'VariableTracker::call_obj_hasattr' (#167983) Fixes #167982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167983 Approved by: https://github.com/guilhermeleobas --- test/dynamo/test_misc.py | 30 +++++++++++++++++++++++++ torch/_dynamo/variables/base.py | 3 ++- torch/_dynamo/variables/constant.py | 18 +++++++++++---- torch/_dynamo/variables/dicts.py | 6 ++--- torch/_dynamo/variables/functions.py | 12 +++++----- torch/_dynamo/variables/iter.py | 2 +- torch/_dynamo/variables/lists.py | 16 ++++++------- torch/_dynamo/variables/nn_module.py | 4 +++- torch/_dynamo/variables/user_defined.py | 4 +++- 9 files changed, 70 insertions(+), 25 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 6348ba5638e05..781e95e0c7c95 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9650,6 +9650,36 @@ def fn(): self.assertEqual(fn_out, compiled_out) self.assertFalse(fn_out) + def test_constant_hasattr_returns_bool(self): + """Test that hasattr on constant values properly returns boolean ConstantVariable.""" + + # Test various constant types + def fn(): + # String constant + s = "hello" + result1 = hasattr(s, "upper") # True + result2 = hasattr(s, "nonexistent") # False + + # Integer constant + i = 42 + result3 = hasattr(i, "bit_length") # True + result4 = hasattr(i, "fake_method") # False + + # Float constant + f = 3.14 + result5 = hasattr(f, "is_integer") # True + result6 = hasattr(f, "missing_attr") # False + + # Use all results to ensure they're compiled + return (result1, result2, result3, result4, result5, result6) + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + fn_out = fn() + compiled_out = compiled_fn() + self.assertEqual(fn_out, compiled_out) + self.assertEqual(fn_out, (True, False, True, False, True, False)) + def test_torch_objects_as_keys(self): remap = {torch.float16: torch.float32} diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 2d11a27bafac0..78f64d882055c 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -32,6 +32,7 @@ if TYPE_CHECKING: from ..codegen import PyCodegen from ..symbolic_convert import InstructionTranslator + from .constant import ConstantVariable class SourceType(Enum): @@ -440,7 +441,7 @@ def force_apply_to_var_sequence( for v in self.unpack_var_sequence(tx): fn(v) - def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker": + def call_obj_hasattr(self, tx: Any, name: str) -> "ConstantVariable": unimplemented( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 1e886c6ee7ad7..a8b6a38cb1e9d 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -8,7 +8,8 @@ import enum import operator -from typing import Any, Literal, Optional, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import override import torch from torch._dynamo.source import AttrSource, GetItemSource @@ -38,6 +39,14 @@ class ConstantVariable(VariableTracker): nested collections. """ + @overload + @staticmethod + def create(value: bool) -> "ConstantVariable": ... + + @overload + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: ... + @staticmethod def create(value: Any, **kwargs: Any) -> VariableTracker: """ @@ -53,10 +62,10 @@ def create(value: Any, **kwargs: Any) -> VariableTracker: # Routing for supported collection literals. if isinstance(value, set): items = [ConstantVariable.create(x) for x in value] - return variables.SetVariable(items, **kwargs) + return variables.SetVariable(items, **kwargs) # type: ignore[arg-type] elif isinstance(value, frozenset): items = [ConstantVariable.create(x) for x in value] - return variables.FrozensetVariable(items, **kwargs) + return variables.FrozensetVariable(items, **kwargs) # type: ignore[arg-type] elif isinstance(value, slice): slice_args = (value.start, value.stop, value.step) slice_args_vars = tuple(ConstantVariable.create(arg) for arg in slice_args) @@ -266,9 +275,10 @@ def call_method( ) return super().call_method(tx, name, args, kwargs) + @override def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> "ConstantVariable": result = hasattr(self.value, name) return variables.ConstantVariable.create(result) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 636875d85e54a..9b02465d5766e 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -806,7 +806,7 @@ def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTrack def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: # dict not allow setting arbitrary attributes. OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( @@ -905,7 +905,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) @@ -1432,7 +1432,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 1eaf58ee95dea..cd345759956be 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -383,7 +383,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: result = False try: @@ -547,7 +547,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: result = hasattr(self.fn, name) return variables.ConstantVariable.create(result) @@ -784,7 +784,7 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) @@ -1436,7 +1436,7 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if name == "__code__": return variables.ConstantVariable.create(hasattr(self, "code")) if name == "__defaults__": @@ -1753,7 +1753,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(self.value, name)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -2113,7 +2113,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: # functools.partial uses slots, so attributes are constant return variables.ConstantVariable.create( hasattr(functools.partial(identity), name) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 162ec02a9a9b7..c111dca9f2d68 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -262,7 +262,7 @@ def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": if name == "__iter__" or name == "__next__": return variables.ConstantVariable.create(True) return super().call_obj_hasattr(tx, name) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 2ac355bd53417..cafbea5afde1e 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -480,7 +480,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is range: return variables.ConstantVariable.create(name in range.__dict__) return super().call_obj_hasattr(tx, name) @@ -932,7 +932,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is not list: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) @@ -1089,7 +1089,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is collections.deque: return variables.ConstantVariable.create(name in collections.deque.__dict__) return super().call_obj_hasattr(tx, name) @@ -1130,7 +1130,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is not tuple: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) @@ -1292,7 +1292,7 @@ def get_item_dyn( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(torch.Size, name)) @@ -1540,7 +1540,7 @@ def check_and_create_method() -> Optional[VariableTracker]: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create( name in self.dynamic_attributes or hasattr(self.tuple_cls, name) ) @@ -1653,7 +1653,7 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(iter([]), name)) def python_type(self) -> type: @@ -1726,7 +1726,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is range_iterator: ri = iter(range(0)) return ConstantVariable(hasattr(ri, name)) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index e754699d862ad..4b5198ffe8533 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -72,6 +72,8 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): """ @@ -230,7 +232,7 @@ def unpack_var_sequence(self, tx): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": mod = tx.output.get_submodule(self.module_key) result = hasattr(mod, name) install_guard( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ec378a5512a01..fb676295535df 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -114,6 +114,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + def is_standard_setattr(val): return val in (object.__setattr__, BaseException.__setattr__) @@ -913,7 +915,7 @@ def is_standard_new(self): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": if self.source: source = AttrSource(self.source, name) install_guard(source.make_guard(GuardBuilder.HASATTR)) From d419a2fe0dcb719d0068af4ac0fb4dec796f1e77 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Fri, 21 Nov 2025 12:09:12 -0800 Subject: [PATCH 189/230] [inductor] find benchmark scripts for r2r determinism unit test (#168041) Previously the test will be skipped due to the benchmark script not found. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168041 Approved by: https://github.com/v0i0 --- test/inductor/test_deterministic.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py index 7e79100f4c053..d7e4313f5fe3b 100644 --- a/test/inductor/test_deterministic.py +++ b/test/inductor/test_deterministic.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import contextlib import os +import pathlib import subprocess import sys import tempfile @@ -23,6 +24,9 @@ ) +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent + + @instantiate_parametrized_tests class DeterministicTest(TestCase): def setUp(self) -> None: @@ -121,9 +125,6 @@ def test_run2run_determinism(self, model_name, training_or_inference, precision) the current working directory. """ - if not os.path.exists("benchmarks/dynamo/huggingface.py"): - self.skipTest("Skip due to benchmarks/dynamo/huggingface.py not found.") - def _setup_env(env): env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" # disable autotune cache env["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "0" @@ -137,7 +138,7 @@ def _setup_env(env): with tempfile.TemporaryDirectory() as tmpdir: saved_pkl = os.path.join(tmpdir, "saved.pkl") cmd = ( - f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor" + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + f" --disable-cudagraphs --save-model-outputs-to={saved_pkl}" ) @@ -153,7 +154,7 @@ def _setup_env(env): # self.assertTrue("pass" in out.stdout.decode()) cmd = ( - f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor" + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + f" --disable-cudagraphs --compare-model-outputs-with={saved_pkl}" ) From d4493c550a67de3501dfd17724e0387d2e5c91a6 Mon Sep 17 00:00:00 2001 From: tianrengao Date: Fri, 21 Nov 2025 23:07:37 +0000 Subject: [PATCH 190/230] Add dynamic config generation for custom op autotuning (#167193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Add dynamic config generation for custom op autotuning Enable shape-aware autotuning by adding config_generator parameter to register_custom_op_autotuning, allowing configs to be generated dynamically based on input properties: shapes, dtype, etc.., instead of using static pre-defined configs. User provided `config_generator` should take in a dict of fake tensors and generate CustomOpConfigs. Different input shapes often benefit from different config values (e.g., different k_splits for decompose_k based on K dimension). Static configs require manually enumerating all possibilities, which doesn't scale and isn't optimal across varying shapes. ### Proposed API Change ```python def register_custom_op_autotuning( custom_op, configs: Optional[list[CustomOpConfig]] = None, # Static configs config_generator: Optional[ Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]] ] = None,, # NEW for dynamic configs ... ) ``` `config_generator`: Provided by user. Takes a dict mapping parameter names to shapes, returns list of CustomOpConfig The config_generator is mutually exclusive with static configs. ## Example usage: User define a config generator to pass into register_custom_op_autotuning function. ```python def generate_k_split_configs(fake_tensors: dict[str, torch.Tensor]) -> list[CustomOpConfig]: """Generate k_split configs based on input matrix dimensions.""" m, k = fake_tensors["a"].shape[-2:] _, n = fake_tensors["b"].shape[-2:] # Compute valid k_splits for the given k dimension k_splits = get_k_splits(m, n, k) return [CustomOpConfig(k_splits=k) for k in k_splits] ``` #### Determining K choices based on input shapes: ```python # Small K dimension (256, 4096) × (4096, 1024) # → Generates 4 configs: k_splits = [2, 4, 8, 16] # Large K dimension (256, 65536) × (65536, 1024) # → Generates 6 configs: k_splits = [2, 4, 8, 16, 32, 64] ``` #### Pass into custom op autotuning ```python register_custom_op_autotuning( matmul_op, config_generator=generate_k_split_configs, # Dynamic generation input_gen_fns={...} ) ``` ----- ### Implementation Added _generate_dynamic_configs() function during `autotuning_lowering` to extract tensor parameters shapes(infer from IR nodes), and call user's config generator. Old static configs are preserved but mutually exclusive to dynamic configs. ### Tests Updated `test_decompose_k_custom_op_autotune_dynamic_config_for_input_shape` to use dynamic config generation with get_k_splits. Now different tensor shapes have different k candidates. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167193 Approved by: https://github.com/eellison --- test/inductor/test_custom_op_autotune.py | 121 ++++++++++++++--------- torch/_inductor/kernel/custom_op.py | 114 ++++++++++++++++++--- 2 files changed, 170 insertions(+), 65 deletions(-) diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index c148c69468902..3c50b4d881f8f 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -217,26 +217,38 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): ) def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): - """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" + """Create test inputs for decompose_k matrix multiplication. + Tensor a: Input matrix of shape (m, k) + Tensor b: Weight matrix of shape (k, n) + Tensor bias: Bias vector of shape (n,) + """ # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256 a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False) b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False) - return a, b + bias = ( + torch.randn(n, device=self.device, dtype=self.dtype, requires_grad=False) + * 0.1 + ) + return a, b, bias @skipIfXpu - def test_decompose_k_custom_op_autotune(self): - """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale). - - Validates that the custom op encapsulates the entire fused operation with parametric - tuning for k_splits values controlling how the K dimension is decomposed. + def test_decompose_k_custom_op_autotune_dynamic_config_for_input_shape(self): + """Test decompose_k autotuning with with epilogue fusion(matmul+bias+relu+scale) and + dynamic config generation based on matmul input shapes. + + Validates that the custom op encapsulates the entire fused operation (matmul + bias + + relu + scale) with parametric tuning for k_splits values controlling how the K + dimension is decomposed. The config generator receives correct parameter names and + shapes, dynamically generates different k_split configs using get_k_splits for + different input shapes, and produces correct results matching the reference implementation. """ - test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}" + test_op_name = f"test_lib::matmul_relu_epilogue_dynamic_{id(self)}" def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: - """Matrix multiply with k-way decomposition - Python implementation.""" + """Matrix multiply with k-way decomposition.""" m = a.shape[0] n = b.shape[1] k = a.shape[1] @@ -254,7 +266,7 @@ def decompose_k_implementation( return torch.sum(result, dim=0) # [m, n] @torch.library.custom_op(test_op_name, mutates_args=()) - def matmul_relu_epilogue_op( + def matmul_relu_epilogue_dynamic_op( a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion).""" @@ -264,23 +276,28 @@ def matmul_relu_epilogue_op( scaled = activated * 2.0 return scaled - @matmul_relu_epilogue_op.register_fake + @matmul_relu_epilogue_dynamic_op.register_fake def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) - # Register autotuning with different k_splits values + # Define dynamic config generator using get_k_splits + def generate_k_split_configs( + fake_tensors: dict[str, torch.Tensor], + ) -> list[CustomOpConfig]: + """Generate k_split configs based on input matrix dimensions.""" + from torch._inductor.utils import get_k_splits + + m, k = fake_tensors["a"].shape[-2:] + _, n = fake_tensors["b"].shape[-2:] + + k_splits_list = get_k_splits(m, n, k) + + return [CustomOpConfig(k_splits=k) for k in k_splits_list] + register_custom_op_autotuning( - matmul_relu_epilogue_op, - configs=[ - CustomOpConfig(k_splits=2), - CustomOpConfig(k_splits=4), - CustomOpConfig(k_splits=8), - CustomOpConfig(k_splits=16), - CustomOpConfig(k_splits=32), - CustomOpConfig(k_splits=64), - CustomOpConfig(k_splits=128), - ], - name="matmul_relu_epilogue_autotuned", + matmul_relu_epilogue_dynamic_op, + config_generator=generate_k_split_configs, + name="matmul_relu_epilogue_dynamic_autotuned", input_gen_fns={ "a": lambda fake_tensor: torch.randn_like( fake_tensor, device=self.device @@ -297,38 +314,44 @@ def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): }, ) - # Create test inputs - a, b = self._create_decompose_k_inputs() - bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1 + # Test multiple shapes to verify dynamic config generation + test_shapes = [ + (256, 16384, 1024), + (256, 65536, 1024), + ] - # Compile the model using the custom op - @torch.compile - def test_model(a, b, bias): - return matmul_relu_epilogue_op(a, b, bias) + for m, k, n in test_shapes: + # Use helper function to create test inputs + a, b, bias = self._create_decompose_k_inputs(m, k, n) - torch._dynamo.reset() + @torch.compile + def test_model(a, b, bias): + return matmul_relu_epilogue_dynamic_op(a, b, bias) - with config.patch( - max_autotune=True, - benchmark_fusion=True, - ): - compiled_result = test_model(a, b, bias) + torch._dynamo.reset() - def reference_model(a, b, bias): - matmul_result = a @ b - biased = matmul_result + bias - activated = torch.relu(biased) - scaled = activated * 2.0 - return scaled + with config.patch( + max_autotune=True, + benchmark_fusion=True, + ): + compiled_result = test_model(a, b, bias) - expected = reference_model(a, b, bias) + def reference_model(a, b, bias): + matmul_result = a @ b + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled - torch.testing.assert_close( - compiled_result, - expected, - rtol=2e-1, - atol=5e-1, - ) + expected = reference_model(a, b, bias) + + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + msg=f"Failed for shape ({m}, {k}, {n})", + ) @skipIfXpu def test_multi_parameter_tuning(self): diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 23878f757cc5e..12cc68dcb9844 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -350,9 +350,51 @@ def autotune_custom_op( return selected_result +def _generate_dynamic_configs( + tensor_inputs: list[Buffer], + config_generator: Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]], + default_impl: Callable[..., Any], + operation_name: str, +) -> list[CustomOpConfig]: + """Generate configs dynamically based on input tensors at lowering time.""" + import inspect + + sig = inspect.signature(default_impl) + param_names = list(sig.parameters.keys()) + + with V.fake_mode: + fake_tensors = [] + for inp in tensor_inputs: + raw_shape = inp.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + fake_tensor = torch.empty( + concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() + ) + fake_tensors.append(fake_tensor) + + fake_tensors_dict = dict(zip(param_names, fake_tensors)) + + configs = config_generator(fake_tensors_dict) + + if not isinstance(configs, (list, tuple)): + raise TypeError( + f"config_generator must return a list or tuple of CustomOpConfig, " + f"got {type(configs)}" + ) + if not configs: + raise ValueError(f"config_generator returned empty list for {operation_name}. ") + + return list(configs) + + def register_custom_op_autotuning( custom_op: torch._library.custom_ops.CustomOpDef, - configs: Union[list[CustomOpConfig], list[Callable[..., Any]]], + configs: Optional[Union[list[CustomOpConfig], list[Callable[..., Any]]]] = None, + config_generator: Optional[ + Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]] + ] = None, name: Optional[str] = None, input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None, ) -> None: @@ -361,11 +403,15 @@ def register_custom_op_autotuning( Args: custom_op: Custom operation (decorated function from @torch.library.custom_op) - configs: List of CustomOpConfig objects + configs: List of CustomOpConfig objects for static inputs. Mutually exclusive with config_generator. + config_generator: Dynamic config generator function that takes a dict mapping + parameter names to fake tensors, and returns list[CustomOpConfig] + based on input tensor properties. Mutually exclusive with configs. name: Operation name (default: "{op_name}_autotuned") input_gen_fns: Custom input generators for benchmarking Examples: + # Static configs @torch.library.custom_op("mylib::attention", mutates_args=()) def my_attention(query, key, value, head_dim=32): ... @@ -383,6 +429,20 @@ def my_attention(query, key, value, head_dim=32): "value": lambda fake: torch.randn_like(fake, device='cuda'), }, ) + + # Dynamic config generation based on input tensor properties + def generate_k_split_configs(fake_tensors: dict[str, torch.Tensor]) -> list[CustomOpConfig]: + # Access tensor shapes, dtypes, devices, etc. + m, k = fake_tensors["mat1"].shape + _, n = fake_tensors["mat2"].shape + k_splits = ... # compute possible k splits based on tensor properties + return [CustomOpConfig(k_splits=k) for k in k_splits] + + register_custom_op_autotuning( + matmul_decomposeK_op, + config_generator=generate_k_split_configs, + input_gen_fns={...}, + ) """ from torch._library.custom_ops import CustomOpDef @@ -392,23 +452,36 @@ def my_attention(query, key, value, head_dim=32): f"got {type(custom_op)}." ) + # Validate configs and config_generator are mutually exclusive + if configs is not None and config_generator is not None: + raise ValueError( + "Cannot specify both 'configs' and 'config_generator'. " + "Use 'config_generator' for shape-dependent configs." + ) + + if configs is None and config_generator is None: + raise ValueError("Must specify either 'configs' or 'config_generator'") + op_overload = custom_op._opoverload default_impl = custom_op._init_fn - if not isinstance(configs, (list, tuple)): - raise TypeError(f"configs must be a list or tuple, got {type(configs)}") + # Process and validate static configs at registration time + static_configs = None + if configs is not None: + if not isinstance(configs, (list, tuple)): + raise TypeError(f"configs must be a list or tuple, got {type(configs)}") - processed_configs = [] - for cfg in configs: - if isinstance(cfg, CustomOpConfig): - processed_configs.append(cfg) - else: - raise TypeError( - f"Each config must be a CustomOpConfig object, got {type(cfg)}" - ) + static_configs = [] + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + static_configs.append(cfg) + else: + raise TypeError( + f"Each config must be a CustomOpConfig object, got {type(cfg)}" + ) - if not processed_configs: - raise ValueError("At least one config must be provided") + if not static_configs: + raise ValueError("At least one config must be provided") if name is None: name = f"{op_overload._name}_autotuned" @@ -419,11 +492,20 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: # Extract tensor inputs and non-tensor parameters (runtime kwargs) tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs) - # Prepare decompositions and kwargs by merging config params with runtime kwargs + # Get configs: either generate dynamically or use static configs + if config_generator is not None: + configs_to_use = _generate_dynamic_configs( + tensor_inputs, config_generator, default_impl, name + ) + else: + assert static_configs is not None + configs_to_use = static_configs + + # Prepare decompositions and kwargs for autotuning decompositions = [] non_tensor_args = [] - for cfg in processed_configs: + for cfg in configs_to_use: decomp = cfg.get_decomposition(default_impl=default_impl) decompositions.append(decomp) From 6c8c03c96183ed565d6d9766cbd994a6c4c6196d Mon Sep 17 00:00:00 2001 From: Yavuz Yetim Date: Fri, 21 Nov 2025 23:18:37 +0000 Subject: [PATCH 191/230] Fix aot_compile typing. (#168320) Summary: Any doesn't capture tuple of different number of elements. Test Plan: Unit tests Differential Revision: D87598839 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168320 Approved by: https://github.com/cyyever, https://github.com/yangw-dev --- torch/_export/__init__.py | 2 +- torch/_inductor/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index d653db0c23a74..7a0a87dab19dc 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -79,7 +79,7 @@ def aot_compile_warning(): def aot_compile( f: Callable, - args: tuple[Any], + args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None, *, dynamic_shapes: Optional[dict[str, Any]] = None, diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 810649e7b7b25..8e6fde9280c4a 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -272,7 +272,7 @@ def aoti_load_package( def aot_compile( gm: torch.fx.GraphModule, - args: tuple[Any], + args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None, *, options: Optional[dict[str, Any]] = None, From 6d22819c04218f285feae48351c8071aa0ba4688 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sat, 22 Nov 2025 00:36:03 +0000 Subject: [PATCH 192/230] [Inductor] Properly enlarge XBLOCK/set num_warps=1 for B200 inner persistent reductions (#168335) Summary: Improve inner persistent reductions with large M for B200, slight improvement on H100 Test Plan: Tested on internal shapes, r <= 512, large M B200 LayerNorm fwd with new heuristics: 2855 GB/s -> 5429 GB/s RMSNorm average went from 5039 GB/s -> 5677 GB/s H100: LayerNorm fwd 1600 GB/s -> 2000 GB/s RMSNorm fwd parity, ~2k GB/s Differential Revision: D87603033 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168335 Approved by: https://github.com/v0i0, https://github.com/shunting314 --- torch/_inductor/runtime/triton_heuristics.py | 21 ++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 8a84d2432ceca..bff86bf3ced5e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2443,6 +2443,7 @@ def triton_config_reduction( waves_per_eu=None, dynamic_scale_rblock=True, reduction_hint=None, + min_num_warps=None, ) -> Config: """ Construct a reduction triton config with some adjustment heuristics @@ -2478,7 +2479,12 @@ def total_numel() -> int: num_warps = total_numel() // 128 max_num_warps = 16 if r <= 8192 else 32 - num_warps = _num_warps( + if min_num_warps is not None: + _num_warps_func = functools.partial(_num_warps, min_num_warps=min_num_warps) + else: + _num_warps_func = _num_warps + + num_warps = _num_warps_func( num_warps, max_num_warps=max_num_warps, register_intensive=register_intensive ) @@ -3293,9 +3299,6 @@ def _persistent_reduction_configs( ): xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) - loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get( - "num_store", 0 - ) MAX_PERSISTENT_BLOCK_NUMEL = 4096 @@ -3366,12 +3369,11 @@ def _persistent_reduction_configs( # TODO(jansel): we should be able to improve these heuristics elif not max_autotune_enabled: # Do not filter configs when tuning if reduction_hint == ReductionHint.INNER and rnumel >= 256: - if rnumel > 1024: + if rnumel > 1024 or xnumel // 8 < 128 or inductor_meta.get("RSPLIT_SIZE"): configs = configs[:1] else: - x_block = 8 - if xnumel // x_block < 128 or loads_and_stores >= 5: - x_block = 1 + num_warps, min_num_warps = 1, 1 + x_block = min(1024 // rnumel, 8) configs = [ triton_config_reduction( @@ -3379,6 +3381,9 @@ def _persistent_reduction_configs( x_block, rnumel, register_intensive=True, + num_warps=num_warps, + min_num_warps=min_num_warps, + reduction_hint=reduction_hint, ) ] From 976abd8710ccbd99c8f3e31cd50dabd15a716631 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sat, 22 Nov 2025 00:42:07 +0000 Subject: [PATCH 193/230] [Inductor] Mix Order Reduction Heuristics (#168361) Differential Revision: D87608892 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168361 Approved by: https://github.com/v0i0, https://github.com/shunting314 --- torch/_inductor/codegen/simd.py | 5 ++--- torch/_inductor/runtime/runtime_utils.py | 6 ++++++ torch/_inductor/runtime/triton_heuristics.py | 20 +++++++++++--------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 6bfe27cdc6f99..1b58503690e98 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -47,7 +47,7 @@ from ..optimize_indexing import indexing_dtype_strength_reduction from ..runtime.coordinate_descent_tuner import CoordescTuner from ..runtime.hints import DeviceProperties -from ..runtime.runtime_utils import green_text, next_power_of_2, yellow_text +from ..runtime.runtime_utils import green_text, last_power_of_2, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse from ..utils import ( cache_property_on_self, @@ -1625,7 +1625,7 @@ def _pick_split_size(): # split_size is decided based on hint numel_hint = V.graph.sizevars.size_hint(numel) - split_size = max(next_power_of_2(numel_hint // estimated_num_splits), 16) + split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16) split_size = min(split_size, 128) return split_size @@ -1678,7 +1678,6 @@ def _bench(candidate_split_size): split_size, 8, ) - # print(f"Autotuning pick split size {split_size}") kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction( kernel_features, diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 169e105d10b03..b4e66378e85ae 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -47,6 +47,12 @@ def next_power_of_2(n: int) -> int: return n +def last_power_of_2(n: int) -> int: + """Return the largest power of 2 less than or equal to n""" + next_pow2 = next_power_of_2(n) + return next_pow2 // 2 if next_pow2 > n else next_pow2 + + def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: """ Return the total number of bytes the arguments of tensor type takes. diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index bff86bf3ced5e..9a1783811d75c 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -3433,21 +3433,23 @@ def persistent_reduction( if inductor_meta.get("RSPLIT_SIZE"): new_configs = [] + rsplit_size = inductor_meta.get("RSPLIT_SIZE") + rnumel_hint = size_hints["r0_"] + min_x_block = 1 + if rnumel_hint <= 512: + min_x_block = 4 + x_block = min(max(rsplit_size // 32, min_x_block), 16) for c in configs: - c.kwargs["RSPLIT_SIZE"] = inductor_meta.get("RSPLIT_SIZE") - - c.kwargs["NUM_STAGES"] = 1 - + c.kwargs["RSPLIT_SIZE"] = rsplit_size # small XBLOCK to use less registers/smem - c.kwargs["XBLOCK"] = ( - torch._inductor.config.triton.mix_order_reduction_initial_xblock - ) + c.kwargs["XBLOCK"] = x_block - rnumel_hint = size_hints["r0_"] + num_iters = rsplit_size // x_block + c.kwargs["NUM_STAGES"] = min(max(num_iters // 4, 1), 3) if rnumel_hint <= 1024: c.num_warps //= 2 - c.num_warps = max(c.num_warps, 2) + c.num_warps = max(c.num_warps, 1) new_configs.append(c) # less warps so potentially each sm can run more thread blocks From 5e4ca87294d28228c4f4630ff7c29a963f051ba9 Mon Sep 17 00:00:00 2001 From: Yarong Mu Date: Sat, 22 Nov 2025 01:03:12 +0000 Subject: [PATCH 194/230] feat(pallas): add Pallas TPU backend (#167774) This commit introduces a new Pallas backend for TPU. This backend allows running PyTorch code on TPU using the Pallas kernel language, without a dependency on `torch_xla`. The backend is enabled by setting the `PALLAS_TARGET_TPU=1` environment variable. When this variable is set, the Pallas backend will generate code that moves data from CPU tensors to the TPU, executes the Pallas kernel on the TPU, and moves the results back to the CPU. The implementation includes: - A new `is_tpu` flag in the Pallas codegen to trigger the TPU-specific logic. - Data movement using `jax.dlpack.from_dlpack` and `jax.device_get`. - A new `has_tpu_pallas` utility function to check for the availability of a Pallas-on-TPU environment. - Validation to ensure that a TPU is available when `PALLAS_TARGET_TPU` is set. - A new test suite for the Pallas TPU backend, including a test for the validation logic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167774 Approved by: https://github.com/oulgen, https://github.com/voznesenskym, https://github.com/jansel --- scripts/install_torchinductor_tpu_deps.sh | 9 ++++ test/inductor/test_pallas.py | 26 ++++++++++- test/inductor/test_torchinductor.py | 3 +- torch/_inductor/codegen/pallas.py | 57 +++++++++++++++++++---- torch/_inductor/config.py | 7 +++ torch/testing/_internal/inductor_utils.py | 4 +- torch/utils/_pallas.py | 30 ++++++------ 7 files changed, 106 insertions(+), 30 deletions(-) create mode 100755 scripts/install_torchinductor_tpu_deps.sh diff --git a/scripts/install_torchinductor_tpu_deps.sh b/scripts/install_torchinductor_tpu_deps.sh new file mode 100755 index 0000000000000..1fe374a8b7bfb --- /dev/null +++ b/scripts/install_torchinductor_tpu_deps.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# +# Install dependencies for TorchInductor on TPU. + +# Install dependencies from requirements.txt first +pip install -r requirements.txt + +# Install JAX nightly builds and other TPU dependencies +pip install --pre -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html jax==0.8.0.dev20251013 jaxlib==0.8.0.dev20251013 libtpu==0.0.25.dev20251012+nightly tpu-info==0.6.0 setuptools==78.1.0 # @lint-ignore diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index c4aabd0375090..369013e1670b6 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -3,6 +3,7 @@ import re import sys import unittest +from unittest import mock import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools @@ -12,6 +13,7 @@ from torch._inductor.utils import run_and_get_code from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS from torch.testing._internal.inductor_utils import HAS_PALLAS +from torch.utils._pallas import has_cuda_pallas, has_jax_tpu_backend from torch.utils._triton import has_triton @@ -746,7 +748,7 @@ def fn(x): self.assertEqual(result, expected) -@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +@unittest.skipUnless(has_cuda_pallas(), "requires jax and pallas") class PallasTestsCUDA(PallasTestsMixin, TestCase): DEVICE = "cuda" @@ -756,6 +758,28 @@ class PallasTestsCPU(PallasTestsMixin, TestCase): DEVICE = "cpu" +@unittest.skipUnless(has_jax_tpu_backend(), "requires JAX TPU backend") +@config.patch({"_debug_cpu_to_tpu_pallas": True}) +class PallasTestsTPU(PallasTestsMixin, TestCase): + DEVICE = "cpu" + + @mock.patch("torch._inductor.codegen.pallas.has_tpu_pallas", return_value=False) + def test_tpu_not_available_raises_error(self, mock_has_tpu_pallas): + def fn(a, b): + return a + b + + with self.assertRaisesRegex( + RuntimeError, + ( + "PALLAS_TARGET_TPU is set, but no TPU device was found. " + "Please make sure that you have a TPU available and that JAX is configured correctly." + ), + ): + torch.compile(fn, backend="inductor", options={"cpu_backend": "pallas"})( + torch.randn(16), torch.randn(16) + ) + + if test_torchinductor.HAS_CPU and HAS_PALLAS: make_pallas(test_torchinductor.SweepInputsCpuTest) # make_pallas(test_torchinductor.CpuTests) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index a120c5b394f01..b1cea5eac77d7 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -127,13 +127,14 @@ ) from torch._inductor.utils import has_torchvision_roi_align from torch.testing._internal.common_utils import slowTest -from torch.testing._internal.inductor_utils import ( +from torch.testing._internal.inductor_utils import ( # noqa: F401 clone_preserve_strides_offset, GPU_TYPE, HAS_CPU, HAS_GPU, HAS_MPS, HAS_MULTIGPU, + HAS_TPU, IS_BIG_GPU, requires_gpu, RUN_CPU, diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 512cf89795b0d..23bf0e1bbe31a 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -7,6 +7,7 @@ import torch # noqa: TC001 from torch.utils._ordered_set import OrderedSet +from torch.utils._pallas import has_tpu_pallas from .. import config from ..runtime.runtime_utils import torch_dtype_to_jax @@ -886,6 +887,17 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove kernel_name = name or "" interpret_is_cpu = V.graph.get_current_device_or_throw().type == "cpu" + is_tpu = torch._inductor.config._debug_cpu_to_tpu_pallas + if is_tpu: + if not torch._inductor.config.pallas_take_first_jax_device_only: + raise RuntimeError( + "Pallas backend currently only supports using the first JAX device." + ) + if not has_tpu_pallas(): + raise RuntimeError( + "PALLAS_TARGET_TPU is set, but no TPU device was found. " + "Please make sure that you have a TPU available and that JAX is configured correctly." + ) interpret_literal = "True" if interpret_is_cpu else "False" # For GPU (Triton backend), import pltriton for masked loads/stores @@ -1065,19 +1077,38 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove if alias_params: code.writeline("# Convert Torch -> JAX for donated outputs") for alias_name in alias_params: - code.writeline( - f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name})" - ) + # TODO: The `jax.device_put` path is a temporary workaround for a Mosaic compiler bug + # that occurs with DLPack. Once TorchTPU provides a direct method for placing a + # `torch.Tensor` on a TPU device, this should be reverted to use the + # `jax.dlpack.from_dlpack` path. + if is_tpu: + code.writeline( + f"{alias_name}_jax = jax.device_put({alias_name}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name})" + ) code.writeline("# Convert Torch -> JAX for in-place tensors") for ptr in pointer_tail: if ptr.startswith("in_out_ptr"): - code.writeline(f"{ptr}_jax = jax.dlpack.from_dlpack({ptr})") + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline(f"{ptr}_jax = jax.dlpack.from_dlpack({ptr})") code.writeline("# Convert Torch -> JAX for inputs") for ptr in pointer_tail: if ptr.startswith("in_ptr"): - code.writeline( - f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.contiguous())" - ) + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.contiguous())" + ) code.writeline("# Prepare output metadata from PyTorch tensor") code.writeline( @@ -1116,9 +1147,15 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove ) for idx in copy_output_indices: name = output_params[idx] - code.writeline( - f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))" - ) + if is_tpu: + code.writeline( + f"res_cpu = jax.device_get(result_values[{idx}])" + ) + code.writeline(f"{name}.copy_(torch.from_dlpack(res_cpu))") + else: + code.writeline( + f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))" + ) return code.getvalue() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f3592b93469cd..645927686232b 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1211,6 +1211,13 @@ def decide_compile_threads() -> int: enable_autograd_for_aot: bool = False +_debug_cpu_to_tpu_pallas: bool = Config( + env_name_force="PALLAS_TARGET_TPU", default=False +) +pallas_take_first_jax_device_only: bool = Config( + env_name_force="PALLAS_TAKE_FIRST_JAX_DEVICE_ONLY", default=True +) + def get_worker_log_path() -> Optional[str]: log_loc = None diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 6bd34c812d641..cda1908a3a340 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -34,7 +34,7 @@ ) from torch.fx.experimental.proxy_tensor import make_fx from torch.utils._helion import has_helion -from torch.utils._pallas import has_pallas +from torch.utils._pallas import has_pallas, has_tpu_pallas from torch.utils._triton import has_triton from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( @@ -105,6 +105,8 @@ def test_cpu(): getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases ) +HAS_TPU = has_tpu_pallas() + def _check_has_dynamic_shape( self: TestCase, diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py index 2d93e7f32c58e..63ef22be49cf4 100644 --- a/torch/utils/_pallas.py +++ b/torch/utils/_pallas.py @@ -72,6 +72,18 @@ def has_jax_tpu_backend() -> bool: return False +@functools.cache +def has_tpu_pallas() -> bool: + """Checks for a full Pallas-on-TPU environment.""" + return has_pallas_package() and has_jax_tpu_backend() + + +@functools.cache +def has_cuda_pallas() -> bool: + """Checks for a full Pallas-on-CUDA environment.""" + return has_pallas_package() and torch.cuda.is_available() and has_jax_cuda_backend() + + @functools.cache def has_pallas() -> bool: """ @@ -82,20 +94,4 @@ def has_pallas() -> bool: - Pallas (jax.experimental.pallas) available - A compatible backend (CUDA or TPU) is available in both PyTorch and JAX. """ - if not has_pallas_package(): - return False - - # Check for is CUDA is available or if JAX has GPU/CUDA backend - has_cuda = torch.cuda.is_available() and has_jax_cuda_backend() - - # Check for TPU backend - has_tpu_torch = False - try: - import torch_xla.core.xla_model as xm - - has_tpu_torch = xm.xla_device_count() > 0 - except ImportError: - pass - has_tpu = has_tpu_torch and has_jax_tpu_backend() - - return has_cuda or has_tpu + return has_cuda_pallas() or has_tpu_pallas() From 57d4e492381f75d6f31592287b834a46c5d2bd04 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 21 Nov 2025 12:28:38 -0800 Subject: [PATCH 195/230] [inductor] Fix a user-defined Triton kernel output + .cpu() correctness issue (#168281) Summary: Fix https://github.com/pytorch/pytorch/issues/168181. If a buffer is marked as mutated (in this particular issue buffer as a user-defined Triton kernel output), ir.DeviceCopy should not optimize it to a target device constant. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168281 Approved by: https://github.com/eellison --- test/inductor/test_triton_kernels.py | 14 ++++++++++++++ torch/_inductor/ir.py | 12 ++++++++++++ torch/_inductor/lowering.py | 11 +---------- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index eee4dba7f2772..1f205eecec1bf 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1216,6 +1216,20 @@ def f(x, y): compiled_out = torch.compile(f)(x, y) self.assertEqual(compiled_out, eager_out) + @requires_gpu + def test_triton_kernel_to_cpu(self): + def f(x, y): + out = torch.zeros_like(x) + add_kernel[(1,)](x, y, out, 16, 16) + out_cpu = out.cpu() + 1 + return out_cpu + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + eager_out = f(x, y) + compiled_out = torch.compile(f)(x, y) + self.assertEqual(compiled_out, eager_out) + @requires_gpu def test_triton_kernel_out_of_order(self): @triton.jit diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d13182e717494..0bd97ec240dc5 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -530,6 +530,16 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: return list(sym_vars) +def try_get_name(x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, BaseView): + x = x.unwrap_view() + if isinstance(x, StorageBox): + x = x.data + return x.get_name() if isinstance(x, Buffer) else None + + class IRNode: """Base class for all intermediate representation (IR) nodes in TorchInductor. @@ -7429,6 +7439,8 @@ class DeviceCopy(ExternKernelOut): def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: if ( not x.is_extern() + # Can not apply this optimization if x has been mutated + and try_get_name(x) not in V.graph.mutated_buffers and all(r in V.graph.constants for r in x.get_read_names()) and not config.aot_inductor.use_runtime_constant_folding ): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d374be59c9446..090265d208c92 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3918,15 +3918,6 @@ def _unsafe_index_put_(self, indices, values, accumulate=False): def index_put_impl_(self, indices, values, accumulate, check, may_realize=False): if may_realize: - def try_get_name(x): - if isinstance(x, ir.TensorBox): - x = x.data - if isinstance(x, ir.BaseView): - x = x.unwrap_view() - if isinstance(x, ir.StorageBox): - x = x.data - return x.get_name() if isinstance(x, ir.Buffer) else None - def indice_slice_from_randperm(indice): # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660 # For this specific pattern, indices is unique as coming from torch.randperm. @@ -3941,7 +3932,7 @@ def indice_slice_from_randperm(indice): ) return False - if try_get_name(self) in values.get_read_names() and not all( + if ir.try_get_name(self) in values.get_read_names() and not all( indice_slice_from_randperm(indice) for indice in indices ): # Fix issue: https://github.com/pytorch/pytorch/issues/138908 From 7ec5c1684e6891fa09d6bb96206e6452ea35a71f Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 21 Nov 2025 12:28:38 -0800 Subject: [PATCH 196/230] [inductor] Reduce cold compilation time caused by duplicated user-defined Triton kernels (#168292) Summary: Similar to https://github.com/pytorch/pytorch/pull/167132, but the previous PR didn't consider user-defined Triton kernels. When cudagraphs-partition is enabled in Inductor, different partitions can use the same user-defined Triton kernels. Each user-defined Trition kernel should only be defined and compiled once. Local measure shoes this PR can reduce Qwen/Qwen3-VL-235B-A22B-Instruct's cold compilation time from 243.65s to 114.69s. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168292 Approved by: https://github.com/eellison ghstack dependencies: #168281 --- test/inductor/test_cudagraph_trees.py | 25 +++++++++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 2 ++ 2 files changed, 27 insertions(+) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 934f969543b2a..0203604b1ba03 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -4148,6 +4148,31 @@ def foo(x): "def triton_poi_fused_add_", 1, exactly=True ).run(code[0]) + @config.patch("graph_partition", True) + def test_graph_partition_user_defined_triton_kernel_reuse(self): + from torch.testing._internal.triton_utils import add_kernel + + def foo(x, y): + # partition 1 + output1 = torch.empty_like(x) + add_kernel[(4,)](x, y, output1, n_elements=128, BLOCK_SIZE=16) + output1_cpu = output1.cpu() + 1 + # partition 2 should reuse the user-defined kernel + x2 = output1_cpu.to("cuda") + output2 = torch.empty_like(x) + add_kernel[(4,)](x2, y, output2, n_elements=128, BLOCK_SIZE=16) + return output1, output2 + + compiled_foo = torch.compile(foo) + x = torch.randn(128, device="cuda") + y = torch.randn(128, device="cuda") + eager_out = foo(x, y) + compiled_out, code = run_and_get_code(compiled_foo, x, y) + self.assertEqual(eager_out, compiled_out) + FileCheck().check_count( + "async_compile.triton('add_kernel',", 1, exactly=True + ).run(code[0]) + def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0498fca739bfc..7e4aa07987224 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -3765,6 +3765,8 @@ def __init__( self.kernel_autotune_calls = root.kernel_autotune_calls # Only store kernel src to name mapping in the main graph self.src_to_kernel = root.src_to_kernel + # Same here, only define user-defined Triton kernels in the main graph + self.user_defined_kernel_cache = root.user_defined_kernel_cache def set_launcher_fn_name(self) -> None: # This sets up the name of the function containing the launcher code of From 69bcac8d4fc10f4713c404902ce622f521ec0c06 Mon Sep 17 00:00:00 2001 From: Shawn Xu Date: Sat, 22 Nov 2025 01:05:08 +0000 Subject: [PATCH 197/230] [triton] Enable Triton kernel serialization for AOTI by adding dict and int list argument types (#167866) Differential Revision: D86997070 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167866 Approved by: https://github.com/minjang --- test/export/test_export.py | 2 - torch/_dynamo/variables/higher_order_ops.py | 2 +- torch/_export/serde/export_schema.thrift | 4 +- torch/_export/serde/schema.py | 4 +- torch/_export/serde/schema.yaml | 8 +- torch/_export/serde/serialize.py | 42 +++++- torch/_higher_order_ops/hints_wrap.py | 2 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 7 +- torch/_inductor/codegen/debug_utils.py | 6 +- .../utils/generated_serialization_types.h | 123 ++++++++++++------ torch/nativert/graph/Graph.cpp | 9 ++ torch/nativert/graph/Graph.h | 1 + torch/nativert/graph/Serialization.cpp | 23 ++++ 13 files changed, 176 insertions(+), 57 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index b3bb0b48b569d..6ebed4f224643 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16142,8 +16142,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # expected 3*..., but got 8 ep.module()(torch.randn(4, 2)) - @testing.expectedFailureSerDer # T195866111 - @testing.expectedFailureSerDerNonStrict @testing.expectedFailureStrictV2 def test_hints_wrapper(self): strict = True diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 4ae8868f15e84..afb6522ac0e5c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3159,7 +3159,7 @@ def _call_function( # to (body_node, lifted_args_tuple, {}) body_node = p_args[0] lifted_args = p_args[1:] - p_args = (body_node, lifted_args, {}) + p_args = (body_node, tuple(lifted_args), {}) # add hints into p_kwargs p_kwargs = {} diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index f4a08f8739993..155f52595740c 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -167,6 +167,8 @@ union Argument { 240: list as_sym_floats; 250: OptionalTensorArgument as_optional_tensor; 260: ComplexValue as_complex; + 280: list> as_int_lists; + 290: map as_string_to_argument; } struct NamedArgument { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index a9cec8b185c58..0d95ca32e6455 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 14) +SCHEMA_VERSION = (8, 15) TREESPEC_VERSION = 1 @@ -212,6 +212,8 @@ class Argument(_Union): as_sym_floats: Annotated[list[SymFloatArgument], 240] as_optional_tensor: Annotated[OptionalTensorArgument, 250] as_complex: Annotated[ComplexValue, 260] + as_int_lists: Annotated[list[list[int]], 280] + as_string_to_argument: Annotated[dict[str, "Argument"], 290] class ArgumentKind(IntEnum): diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 951351e7786aa..6f13741416cb3 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +# checksum<> AOTInductorModelPickleData: kind: struct fields: @@ -75,6 +75,10 @@ Argument: type: OptionalTensorArgument as_complex: type: ComplexValue + as_int_lists: + type: List[List[int]] + as_string_to_argument: + type: Dict[str, Argument] ArgumentKind: kind: enum fields: @@ -551,5 +555,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 14 +- 15 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 84978f0066712..c64aaff9ae1f2 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -39,6 +39,7 @@ from torch.utils._triton import has_triton from ..utils import remove_proxy_from_state_dict +from . import schema from .schema import ( # type: ignore[attr-defined] Argument, ArgumentKind, @@ -1195,6 +1196,13 @@ def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: ) elif arg is None: return Argument.create(as_none=True) + elif isinstance(arg, dict): + serialized_dict = {} + for key, value in arg.items(): + if not isinstance(key, str): + raise SerializeError(f"Dict keys must be strings, got {type(key)}") + serialized_dict[key] = self.serialize_input(value) + return Argument.create(as_string_to_argument=serialized_dict) elif isinstance(arg, (list, tuple)): if len(arg) == 0: if arg_type is not None: @@ -1326,6 +1334,11 @@ def serialize_optional_tensor_args(a): return Argument.create( as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) ) + elif all( + isinstance(a, tuple) and all(type(x) is int for x in a) for a in arg + ): + # list of int tuples + return Argument.create(as_int_lists=[list(t) for t in arg]) else: raise SerializeError( f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" @@ -2735,6 +2748,12 @@ def deserialize_input(self, inp: Argument) -> Any: return self.deserialize_sym_argument(inp.as_sym_float) elif typ_ == "as_sym_bool": return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, dict): + if typ_ == "as_string_to_argument": + # Deserialize dict[str, Argument] recursively + return {k: self.deserialize_input(v) for k, v in value.items()} + else: + raise SerializeError(f"Unknown dict type: {typ_}") elif isinstance(value, list): if len(value) == 0: return [] @@ -2744,6 +2763,9 @@ def deserialize_input(self, inp: Argument) -> Any: elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): # convert from serialized.python.types.List to python list return list(value) + elif typ_ == "as_int_lists": + # Convert list of lists back to list of tuples for Triton grids + return [tuple(dims) for dims in value] elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"): return [self.deserialize_sym_argument(arg) for arg in value] elif typ_ == "as_optional_tensors": @@ -3239,7 +3261,18 @@ def serialize( return artifact +def _resolve_schema_cls(cls): + if isinstance(cls, str): + resolved = getattr(schema, cls, None) + if resolved is not None: + return resolved + if isinstance(cls, typing.ForwardRef): + return _resolve_schema_cls(cls.__forward_arg__) + return cls + + def _dict_to_dataclass(cls, data): + cls = _resolve_schema_cls(cls) assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." if typing.get_origin(cls) is Annotated: return _dict_to_dataclass(cls.__origin__, data) @@ -3255,12 +3288,13 @@ def _dict_to_dataclass(cls, data): _type = next(iter(data.keys())) _value = next(iter(data.values())) assert isinstance(_type, str) - field_type = cls.__annotations__[_type] + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + field_type = type_hints[_type] # pyrefly: ignore [missing-attribute] return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) elif dataclasses.is_dataclass(cls): fields = {} - type_hints = typing.get_type_hints(cls) + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) # For forward compatibility consideration, we ignore all the keys # that are not showing up in the dataclass definition. for f in dataclasses.fields(cls): @@ -3365,6 +3399,10 @@ def _get_argument(a: Argument): return a.as_custom_obj elif a.type == "as_operator": return None + elif a.type == "as_int_lists": + return None + elif a.type == "as_string_to_argument": + return None else: raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py index 3f21c518cbd74..583623393a0a1 100644 --- a/torch/_higher_order_ops/hints_wrap.py +++ b/torch/_higher_order_ops/hints_wrap.py @@ -34,7 +34,7 @@ def __call__(self, body_fn, args, kwargs, hints): backend compiler. """ if not isinstance(args, tuple): - raise RuntimeError(f"args must be a tuple, got {type(args)}") + args = tuple(args) if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): raise RuntimeError( diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 63bff112afee2..3a65d1c895d1c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -158,11 +158,12 @@ def _generate_kernel_call_helper( ) new_args = [] for idx, arg in enumerate(call_args): - if "*" in arg_types[idx]: + if isinstance(arg_types[idx], str) and "*" in arg_types[idx]: new_args.append(f"({arg_types[idx]})({arg}.data_ptr())") else: - # arg is a scalar - new_args.append(arg) + # arg is a scalar - ensure it's a string for C++ codegen + # With Triton support, arg might be a SymPy expression or other type + new_args.append(str(arg) if not isinstance(arg, str) else arg) # debug printer related logic for cpp kernel type. debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager.set_printer_args( diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index dc4349ad7bbf5..9b465e3d1ffab 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -161,7 +161,9 @@ def set_printer_args( # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls if kernel_type == "extern": args_to_print_or_save_extern = [ - arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg")) + arg + for arg in args_to_print_or_save + if isinstance(arg, str) and arg.startswith(("buf", "arg")) ] self.args_to_print_or_save = args_to_print_or_save_extern elif kernel_type == "cpp": @@ -172,7 +174,7 @@ def set_printer_args( else arg ) for arg in args_to_print_or_save - if arg.startswith(("buf", "arg")) + if isinstance(arg, str) and arg.startswith(("buf", "arg")) ] else: self.args_to_print_or_save = args_to_print_or_save diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f7abfece3bc31..706d7940ee785 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +// checksum<> // clang-format off #pragma once @@ -10,7 +10,6 @@ #include #include #include -#include #include @@ -191,7 +190,7 @@ inline std::string_view printEnum(const ArgumentKind& e) { case ArgumentKind::POSITIONAL: return "POSITIONAL"; case ArgumentKind::KEYWORD: return "KEYWORD"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -199,7 +198,7 @@ inline void parseEnum(std::string_view s, ArgumentKind& t) { if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; } if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; } if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class Layout { @@ -224,7 +223,7 @@ inline std::string_view printEnum(const Layout& e) { case Layout::_mkldnn: return "_mkldnn"; case Layout::Strided: return "Strided"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -237,7 +236,7 @@ inline void parseEnum(std::string_view s, Layout& t) { if (s == "SparseBsc") { t = Layout::SparseBsc; return; } if (s == "_mkldnn") { t = Layout::_mkldnn; return; } if (s == "Strided") { t = Layout::Strided; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class MemoryFormat { @@ -256,7 +255,7 @@ inline std::string_view printEnum(const MemoryFormat& e) { case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d"; case MemoryFormat::PreserveFormat: return "PreserveFormat"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -266,7 +265,7 @@ inline void parseEnum(std::string_view s, MemoryFormat& t) { if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; } if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; } if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class ScalarType { @@ -313,7 +312,7 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -337,7 +336,7 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -454,7 +453,7 @@ inline std::string_view printEnum(const SymExprHint::Tag& e) { case SymExprHint::Tag::AS_BOOL: return "AS_BOOL"; case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -462,7 +461,7 @@ inline void parseEnum(std::string_view s, SymExprHint::Tag& t) { if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; } if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; } if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -560,14 +559,14 @@ inline std::string_view printEnum(const SymInt::Tag& e) { case SymInt::Tag::AS_EXPR: return "AS_EXPR"; case SymInt::Tag::AS_INT: return "AS_INT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymInt::Tag& t) { if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; } if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -638,14 +637,14 @@ inline std::string_view printEnum(const SymFloat::Tag& e) { case SymFloat::Tag::AS_EXPR: return "AS_EXPR"; case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymFloat::Tag& t) { if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; } if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -716,14 +715,14 @@ inline std::string_view printEnum(const SymBool::Tag& e) { case SymBool::Tag::AS_EXPR: return "AS_EXPR"; case SymBool::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymBool::Tag& t) { if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; } if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -866,14 +865,14 @@ inline std::string_view printEnum(const SymIntArgument::Tag& e) { case SymIntArgument::Tag::AS_NAME: return "AS_NAME"; case SymIntArgument::Tag::AS_INT: return "AS_INT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) { if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; } if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -944,14 +943,14 @@ inline std::string_view printEnum(const SymFloatArgument::Tag& e) { case SymFloatArgument::Tag::AS_NAME: return "AS_NAME"; case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) { if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; } if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1022,14 +1021,14 @@ inline std::string_view printEnum(const SymBoolArgument::Tag& e) { case SymBoolArgument::Tag::AS_NAME: return "AS_NAME"; case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) { if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; } if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1136,14 +1135,14 @@ inline std::string_view printEnum(const OptionalTensorArgument::Tag& e) { case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR"; case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) { if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; } if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1233,11 +1232,11 @@ class Argument { public: enum class Tag { - AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX, AS_INT_LISTS, AS_STRING_TO_ARGUMENT }; private: - std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue> variant_; + std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue, std::vector>, std::unordered_map>> variant_; Tag tag_; public: @@ -1488,6 +1487,24 @@ class Argument { tag_ = Tag::AS_COMPLEX; } + const std::vector>& get_as_int_lists() const { + return std::get<28>(variant_); + } + + void set_as_int_lists(std::vector> def) { + variant_.emplace<28>(std::move(def)); + tag_ = Tag::AS_INT_LISTS; + } + + const std::unordered_map>& get_as_string_to_argument() const { + return std::get<29>(variant_); + } + + void set_as_string_to_argument(std::unordered_map> def) { + variant_.emplace<29>(std::move(def)); + tag_ = Tag::AS_STRING_TO_ARGUMENT; + } + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { if (nlohmann_json_t.tag_ == Tag::AS_NONE) { @@ -1598,6 +1615,14 @@ class Argument { nlohmann_json_j["as_complex"] = nlohmann_json_t.get_as_complex(); return; } + if (nlohmann_json_t.tag_ == Tag::AS_INT_LISTS) { + nlohmann_json_j["as_int_lists"] = nlohmann_json_t.get_as_int_lists(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING_TO_ARGUMENT) { + nlohmann_json_j["as_string_to_argument"] = nlohmann_json_t.get_as_string_to_argument(); + return; + } } friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { @@ -1737,6 +1762,16 @@ class Argument { nlohmann_json_t.tag_ = Tag::AS_COMPLEX; return; } + if (nlohmann_json_j.contains("as_int_lists")) { + nlohmann_json_t.variant_.emplace<28>(nlohmann_json_j.at("as_int_lists").template get>>()); + nlohmann_json_t.tag_ = Tag::AS_INT_LISTS; + return; + } + if (nlohmann_json_j.contains("as_string_to_argument")) { + nlohmann_json_t.variant_.emplace<29>(nlohmann_json_j.at("as_string_to_argument").template get>>()); + nlohmann_json_t.tag_ = Tag::AS_STRING_TO_ARGUMENT; + return; + } } }; @@ -1769,8 +1804,10 @@ inline std::string_view printEnum(const Argument::Tag& e) { case Argument::Tag::AS_SYM_FLOATS: return "AS_SYM_FLOATS"; case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR"; case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX"; + case Argument::Tag::AS_INT_LISTS: return "AS_INT_LISTS"; + case Argument::Tag::AS_STRING_TO_ARGUMENT: return "AS_STRING_TO_ARGUMENT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -1802,7 +1839,9 @@ inline void parseEnum(std::string_view s, Argument::Tag& t) { if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; } if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; } if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + if (s == "AS_INT_LISTS") { t = Argument::Tag::AS_INT_LISTS; return; } + if (s == "AS_STRING_TO_ARGUMENT") { t = Argument::Tag::AS_STRING_TO_ARGUMENT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1905,8 +1944,8 @@ class Graph { std::unordered_map sym_int_values; std::unordered_map sym_bool_values; bool is_single_tensor_return = false; - std::unordered_map custom_obj_values; - std::unordered_map sym_float_values; + std::unordered_map custom_obj_values = {}; + std::unordered_map sym_float_values = {}; public: @@ -2128,7 +2167,7 @@ inline std::string_view printEnum(const ConstantValue::Tag& e) { case ConstantValue::Tag::AS_STRING: return "AS_STRING"; case ConstantValue::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2138,7 +2177,7 @@ inline void parseEnum(std::string_view s, ConstantValue::Tag& t) { if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; } if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; } if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -2466,7 +2505,7 @@ inline std::string_view printEnum(const InputSpec::Tag& e) { case InputSpec::Tag::TOKEN: return "TOKEN"; case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2478,7 +2517,7 @@ inline void parseEnum(std::string_view s, InputSpec::Tag& t) { if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; } if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; } if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -2852,7 +2891,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) { case OutputSpec::Tag::TOKEN: return "TOKEN"; case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2865,7 +2904,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) { if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; } if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; } if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -3027,8 +3066,8 @@ class GraphModule { Graph graph; GraphSignature signature; std::vector module_call_graph; - std::unordered_map metadata; - std::unordered_map treespec_namedtuple_fields; + std::unordered_map metadata = {}; + std::unordered_map treespec_namedtuple_fields = {}; public: @@ -3109,9 +3148,9 @@ class ExportedProgram { std::unordered_map opset_version; std::unordered_map range_constraints; SchemaVersion schema_version; - std::vector verifiers; + std::vector verifiers = {}; std::string torch_version = "<=2.4"; - std::vector guards_code; + std::vector guards_code = {}; public: diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index 47d082f44332f..00acf1782c2d8 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -1034,6 +1034,15 @@ std::ostream& operator<<(std::ostream& out, const Constant& constant) { out << kDevicePrefix << '{' << arg << '}'; } else if constexpr (is_same_v>) { out << fmt::format("[{}]", fmt::join(arg, ",")); + } else if constexpr (is_same_v>>) { + out << '['; + for (const auto& [idx, inner_list] : c10::enumerate(arg)) { + if (idx > 0) { + out << ", "; + } + out << fmt::format("{}", fmt::streamed(inner_list)); + } + out << ']'; } else if constexpr (is_same_v>) { out << fmt::format(""); VLOG(0) << "Subgraph pretty print is not implemented"; diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h index bbd87a8e2014b..c713df9401884 100644 --- a/torch/nativert/graph/Graph.h +++ b/torch/nativert/graph/Graph.h @@ -97,6 +97,7 @@ using Constant = std::variant< bool, std::vector, std::vector, + std::vector>, std::unique_ptr>; c10::IValue constantToIValue(const Constant& constant); diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index 4c45edd1f5751..532e73a40bd4a 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -101,6 +101,11 @@ Value* symbolicToValue( case torch::_export::Argument::Tag::AS_SYM_FLOAT: { return graph.getValue(arg.get_as_sym_float().get_as_name()); } + case torch::_export::Argument::Tag::AS_STRING_TO_ARGUMENT: { + TORCH_CHECK( + false, + "String to argument mapping is not yet supported in symbolic context"); + } default: TORCH_CHECK( false, @@ -453,6 +458,7 @@ bool isSymbolic(const torch::_export::Argument& arg) { case torch::_export::Argument::Tag::AS_SYM_FLOAT: case torch::_export::Argument::Tag::AS_SYM_FLOATS: case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSOR: return true; default: return false; @@ -532,6 +538,23 @@ Constant constantToValue( case torch::_export::Argument::Tag::AS_SYM_FLOATS: { TORCH_CHECK(false, "SymFloats is not yet implemented"); } + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSOR: + TORCH_CHECK(false, "Optional tensor is symbolic, not constant"); + case torch::_export::Argument::Tag::AS_COMPLEX: + TORCH_CHECK(false, "Complex values are not yet supported as constants"); + case torch::_export::Argument::Tag::AS_INT_LISTS: { + std::vector> ret; + for (const auto& inner_list : jsonArg.get_as_int_lists()) { + std::vector inner_ret; + for (const auto& val : inner_list) { + inner_ret.push_back(val); + } + ret.push_back(inner_ret); + } + return ret; + } + case torch::_export::Argument::Tag::AS_STRING_TO_ARGUMENT: + return None(); default: TORCH_CHECK(false, "Got unknown json argument"); } From 24e1958fcbfe089421f5731d662b7cc766330345 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 21 Nov 2025 13:18:39 -0800 Subject: [PATCH 198/230] [dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems (#167888) Fixes https://github.com/pytorch/pytorch/issues/167789 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167888 Approved by: https://github.com/malfet, https://github.com/colesbury --- test/dynamo/test_repros.py | 91 ++++++++++++++++++++++++++++ torch/_C/_dynamo/eval_frame.pyi | 2 + torch/_dynamo/__init__.py | 24 ++++++++ torch/csrc/dynamo/eval_frame.c | 5 +- torch/csrc/dynamo/eval_frame_cpp.cpp | 61 ++++++++++++++++++- torch/csrc/dynamo/eval_frame_cpp.h | 7 ++- torch/csrc/dynamo/init.cpp | 4 ++ 7 files changed, 189 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index aab7d5268fcdc..bb1af9abc3b71 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7456,6 +7456,97 @@ def forward(self, x): msg, ) + def test_dynamo_set_recursion_limit_simple(self): + # Test that torch._dynamo.set_recursion_limit calls sys.setrecursionlimit for all supported + # Python versions + old_recursion_limit = sys.getrecursionlimit() + old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() + try: + + def fn(x, n): + if n == 0: + return x + return fn(x, n - 1) + 1 + + sys.setrecursionlimit(100) + + with self.assertRaises(RecursionError): + fn(torch.ones(3), 1000) + + opt_fn = torch.compile(fn, backend="eager", dynamic=False) + torch._dynamo.set_recursion_limit(100000) + self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000)) + finally: + if old_dynamo_recursion_limit > 0: + torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) + sys.setrecursionlimit(old_recursion_limit) + + @unittest.skipIf( + sys.version_info < (3, 12) or sys.version_info >= (3, 14), + "only 3.12, 3.13 affected by c recursion limit", + ) + def test_dynamo_set_recursion_limit(self): + old_recursion_limit = sys.getrecursionlimit() + old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() + try: + + def fn(x, n): + if n == 0: + return x + return fn(x, n - 1) + 1 + + sys.setrecursionlimit(100) + + with self.assertRaises(RecursionError): + fn(torch.ones(3), 1000) + + sys.setrecursionlimit(2000) + + fn(torch.ones(3), 1000) + opt_fn = torch.compile(fn, backend="eager", dynamic=False) + sys.setrecursionlimit(100000) + with self.assertRaises(Exception): + opt_fn(torch.ones(3), 1000) + + torch._dynamo.set_recursion_limit(100000) + self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000)) + finally: + if old_dynamo_recursion_limit > 0: + torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) + sys.setrecursionlimit(old_recursion_limit) + + @unittest.skipIf( + sys.version_info < (3, 12) or sys.version_info >= (3, 14), + "only 3.12, 3.13 affected by c recursion limit", + ) + def test_dynamo_set_recursion_limit_usage(self): + old_recursion_limit = sys.getrecursionlimit() + old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() + try: + torch._dynamo.set_recursion_limit(100) + self.assertEqual(torch._dynamo.get_recursion_limit(), 100) + + with self.assertRaisesRegex(ValueError, "recursion limit"): + torch._dynamo.set_recursion_limit(0) + + self.assertEqual(torch._dynamo.get_recursion_limit(), 100) + + torch._dynamo.set_recursion_limit(1) + sys.setrecursionlimit(100) + + @torch.compile(backend="eager", dynamic=False) + def fn(x, n): + if n == 0: + return x + return fn(x, n - 1) + 1 + + with self.assertRaisesRegex(RuntimeError, "new c_recursion limit"): + fn(torch.ones(3), 5) + finally: + if old_dynamo_recursion_limit > 0: + torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) + sys.setrecursionlimit(old_recursion_limit) + @expectedFailureDynamic def test_dynamo_default_lru_cache_behavior(self): @torch.compile(backend="eager") diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 117795db5ac3e..e81771b0cc958 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -19,6 +19,8 @@ def set_guard_complete_hook( hook: Optional[DynamoGuardCompleteHook], ) -> Optional[DynamoGuardCompleteHook]: ... def raise_sigtrap() -> None: ... +def set_c_recursion_limit(limit: int) -> None: ... +def get_c_recursion_limit() -> int: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index de097edf87752..b0b00bc6f5b89 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -105,6 +105,7 @@ "reset", "run", "error_on_graph_break", + "set_recursion_limit", "set_stance", "skip_frame", "step_unsupported", @@ -181,3 +182,26 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() + + +def get_recursion_limit() -> int: + """ + Returns the internal dynamo recursion limit set by `torch._dynamo.set_recursion_limit`. + + Returns -1 if no c recursion limit has been set. + """ + return torch._C._dynamo.eval_frame.get_c_recursion_limit() + + +def set_recursion_limit(limit: int) -> None: + """ + Sets an internal dynamo recursion limit. The limit must be >= 1. + + This is possibly needed in Python 3.12-3.13 since there is a separate C recursion limit + that is not visible at the Python level. If you are getting RecursionErrors during + Dynamo compilation and `sys.setrecursionlimit()` doesn't help, this function may alleviate + the issue. + + NOTE: this function will also call `sys.setrecursionlimit()`. + """ + torch._C._dynamo.eval_frame.set_c_recursion_limit(limit) diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index b08fffedaa014..58cb48de664d5 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -733,7 +733,10 @@ static PyMethodDef _methods[] = { {"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, - {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL}, + {"set_code_exec_strategy", + dynamo_set_code_exec_strategy, + METH_VARARGS, + NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL}, {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL}, diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index e678bc7bad04a..72465d6f4774f 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -50,6 +50,56 @@ static py::handle _callback_from_action( return callback; } +// c_recursion_remaining only defined in 3.12 and 3.13 + +static int32_t c_recursion_limit = -1; + +void dynamo_set_c_recursion_limit(int32_t limit) { + if (limit < 1) { + throw std::range_error("recursion limit must be greater or equal than 1"); + } + c_recursion_limit = limit; + // cannot fail + Py_SetRecursionLimit(limit); // also set the Python limit +} + +int32_t dynamo_get_c_recursion_limit() { + return c_recursion_limit; +} + +#if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS + +struct CRecursionLimitRAII { + PyThreadState* tstate; + int32_t old_recursion_remaining; + CRecursionLimitRAII(PyThreadState* tstate) : tstate{tstate} { + auto limit = dynamo_get_c_recursion_limit(); + auto& remaining = tstate->c_recursion_remaining; + this->old_recursion_remaining = remaining; + if (limit < 0) { + // no change to limit + return; + } + if (limit < remaining) { + PyErr_SetString( + PyExc_RuntimeError, + "new c_recursion limit is lower than thread's current c_recursion_remaining."); + } + remaining = limit; + } + ~CRecursionLimitRAII() { + this->tstate->c_recursion_remaining = this->old_recursion_remaining; + } +}; + +#else + +struct CRecursionLimitRAII { + CRecursionLimitRAII(PyThreadState* tstate) {} +}; + +#endif + // frame and callback are borrowed references. // Returns new reference. PyObject* dynamo__custom_eval_frame( @@ -258,6 +308,13 @@ PyObject* dynamo__custom_eval_frame( bool apply_to_code = false; PyObject* guarded_code = nullptr; try { + CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given + // value during compilation + // C recursion limit failure + if (PyErr_Occurred()) { + fail(); + return eval_result; + } callback_result = dynamo_call_callback( callback, frame, locals.get(), cache_entry, frame_state); new_strategy = @@ -320,7 +377,7 @@ PyObject* dynamo__custom_eval_frame( return eval_result; } -PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { +PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* args) { PyObject* code_obj = nullptr; PyObject* strategy_obj = nullptr; if (!PyArg_ParseTuple(args, "OO", &code_obj, &strategy_obj)) { @@ -344,7 +401,7 @@ PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { Py_RETURN_NONE; } -void skip_code_recursive(PyCodeObject* code) { +void dynamo_skip_code_recursive(PyCodeObject* code) { ExtraState* extra = get_extra_state(code); if (extra == nullptr) { extra = init_and_set_extra_state(code); diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index 2f3587094f763..8cc1ab7618b3d 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -16,8 +16,11 @@ PyObject* dynamo__custom_eval_frame( int throw_flag, PyObject* callback); -PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); -void skip_code_recursive(PyCodeObject* code); +PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* obj); +void dynamo_skip_code_recursive(PyCodeObject* code); + +void dynamo_set_c_recursion_limit(int32_t limit); +int32_t dynamo_get_c_recursion_limit(); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 790ff9acff3a1..3a1865cabc049 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -250,6 +251,9 @@ void initDynamoBindings(PyObject* torch) { .def_readwrite("cur_action", &FrameExecStrategy::cur_action) .def_readwrite("recursive_action", &FrameExecStrategy::recursive_action); + m.def("set_c_recursion_limit", &dynamo_set_c_recursion_limit); + m.def("get_c_recursion_limit", &dynamo_get_c_recursion_limit); + m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list); m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); From 68921acd6f4e1d80823a1a6be250dde61808a587 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 21 Nov 2025 14:54:06 -0800 Subject: [PATCH 199/230] [dynamo][guards] Log backend match recompilation reason (#168387) Before this PR, we will see an empty recompilation reason. With this PR image Fixes https://github.com/pytorch/pytorch/issues/168373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168387 Approved by: https://github.com/williamwen42 --- torch/_C/_dynamo/eval_frame.pyi | 2 ++ torch/_dynamo/convert_frame.py | 7 +++++-- torch/_dynamo/eval_frame.py | 2 +- torch/_dynamo/guards.py | 31 ++++++++++++++++++++++++++++++- torch/csrc/dynamo/init.cpp | 1 + 5 files changed, 39 insertions(+), 4 deletions(-) diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index e81771b0cc958..060bf2638e096 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,5 +1,6 @@ import enum import types +from collections.abc import Callable from typing import Optional, overload from torch._dynamo.guards import GuardManagerWrapper @@ -29,6 +30,7 @@ class _CacheEntry: compile_id: CompileId # If we run into circular issues, just use object guard_manager: GuardManagerWrapper + backend: Callable next: _CacheEntry | None class _PrecompileEntry: diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 58767245fa9a4..87dc80e99bd79 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -99,6 +99,7 @@ always_optimize_code_objects, Constraint, dynamo_tls, + innermost_fn, skip_code, TorchPatcher, ) @@ -1574,7 +1575,9 @@ def count_args(code: CodeType) -> int: # Check recompilations recompile_reason: Optional[str] = None if is_recompilation(cache_size) and frame: - reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame) + reasons = get_and_maybe_log_recompilation_reasons( + cache_entry, frame, innermost_fn(compiler_fn) + ) recompile_reason = ( "Unable to find recompilation reasons" if not reasons else reasons[0] ) @@ -1582,7 +1585,7 @@ def count_args(code: CodeType) -> int: inline_inbuilt_nn_modules_candidate = False if not config.inline_inbuilt_nn_modules and frame: inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons( - cache_entry, frame, skip_logging=True + cache_entry, frame, innermost_fn(compiler_fn), skip_logging=True ) inbuilt_nn_recompile_reason = ( None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0] diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 43bc570841239..4253fa031d2ec 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -251,7 +251,7 @@ def fail_callback( cache_entries = _debug_get_cache_entry_list(frame.f_code) if cache_entries: reasons = get_and_maybe_log_recompilation_reasons( - cache_entries[0], frame, skip_logging=True + cache_entries[0], frame, innermost_fn(callback), skip_logging=True ) if reasons: failures = textwrap.indent("\n".join(reasons), "- ") diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 77db6ec52d54d..cf621921cd59b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -3685,6 +3685,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: self.guard_manager, output_graph.local_scope, CompileContext.current_compile_id(), + backend=None, # no need to set this because we are trying to find the offending guard entry ) raise AssertionError( "Guard failed on the same frame it was created. This is a bug - please create an issue." @@ -4302,6 +4303,7 @@ def get_guard_fail_reason_helper( guard_manager: GuardManagerWrapper, f_locals: dict[str, object], compile_id: Optional[CompileId], + backend: Optional[Callable], ) -> str: """ Return the reason why `guard_manager` failed. @@ -4314,6 +4316,10 @@ def get_guard_fail_reason_helper( scope.update(guard_manager.closure_vars) reasons: list[str] = [] + cache_entry_backend = None + if guard_manager.cache_entry: + cache_entry_backend = guard_manager.cache_entry.backend + no_tensor_aliasing_check_failed = False verbose_code_parts: list[str] = [] @@ -4336,6 +4342,24 @@ def get_guard_fail_reason_helper( else: reasons = verbose_code_parts verbose_code_parts = [] + elif cache_entry_backend != backend: + # None of the guard entries failed - a backend match issue + reason = ( + "BACKEND_MATCH failure: torch.compile detected different backend callables." + " If this is unexpected, wrap your backend in functools.partial (or reuse the" + " same cached backend) to avoid creating a new backend function each time." + " More details: https://github.com/pytorch/pytorch/issues/168373" + ) + reasons.append(reason) + else: + # Unexpected recompilation - points to a bug + reason = ( + "Unexpected recompilation: runtime guards failed even though they passed" + " during recompilation-reason analysis." + " Please open an issue with a minimal repro:" + " https://github.com/pytorch/pytorch" + ) + reasons.append(reason) if no_tensor_aliasing_check_failed: reasons = recompilation_reason_for_no_tensor_aliasing_guard( @@ -4372,11 +4396,14 @@ def get_guard_fail_reason( code: types.CodeType, f_locals: dict[str, object], compile_id: CompileId, + backend: Callable, skip_logging: bool = False, ) -> str: if isinstance(guard_manager, DeletedGuardManagerWrapper): return f"{compile_id}: {guard_manager.invalidation_reason}" - reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id) + reason_str = get_guard_fail_reason_helper( + guard_manager, f_locals, compile_id, backend + ) if skip_logging: return reason_str guard_failures[orig_code_map[code]].append(reason_str) @@ -4397,6 +4424,7 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reasons( cache_entry: Optional[CacheEntry], frame: DynamoFrameType, + backend: Callable, skip_logging: bool = False, ) -> list[str]: """ @@ -4411,6 +4439,7 @@ def get_and_maybe_log_recompilation_reasons( cache_entry.code, frame.f_locals, cache_entry.compile_id, + backend, skip_logging, ) if reason: diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 3a1865cabc049..69d6e0555ceb4 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -225,6 +225,7 @@ void initDynamoBindings(PyObject* torch) { .def_readonly("code", &CacheEntry::code) .def_readonly("compile_id", &CacheEntry::compile_id) .def_readonly("trace_annotation", &CacheEntry::trace_annotation) + .def_readonly("backend", &CacheEntry::backend) .def_property_readonly("next", &CacheEntry::next) .def( "update_diff_guard_root_manager", From 95ae5a47883f3de5aa5577aa55cbcbd8ab70dda7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 21 Nov 2025 14:54:07 -0800 Subject: [PATCH 200/230] [dynamo][pytree][compile time] Specialize tree_is_leaf (#168070) Fixes https://github.com/pytorch/pytorch/issues/168373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168070 Approved by: https://github.com/XuehaiPan, https://github.com/fxdawnn, https://github.com/mlazos, https://github.com/zou3519, https://github.com/williamwen42 ghstack dependencies: #168387 --- test/dynamo/test_repros.py | 65 ++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/functions.py | 65 ++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index bb1af9abc3b71..9bb94b9a47d40 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -8334,6 +8334,71 @@ def fn(a, b): # Should compile successfully with fullgraph=True self.assertEqual(cnt.frame_count, 1) + def test_pytree_tree_is_leaf_not_traced(self): + # Test that torch.utils._pytree.tree_is_leaf is not traced into + # when is_leaf parameter is None (the common case) + from torch.utils._pytree import tree_is_leaf + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Test with various types + # Tensors are leaves + is_leaf_tensor = tree_is_leaf(x) + assert is_leaf_tensor is True + + # Lists are not leaves (they're in SUPPORTED_NODES) + is_leaf_list = tree_is_leaf([x, y]) + assert is_leaf_list is False + + # Dicts are not leaves + is_leaf_dict = tree_is_leaf({"a": x, "b": y}) + assert is_leaf_dict is False + + return x + y + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_tree_is_leaf_with_namedtuple(self): + # Test that torch.utils._pytree.tree_is_leaf handles namedtuples correctly + from collections import namedtuple + + from torch.utils._pytree import tree_is_leaf + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Namedtuples are not leaves (they're in SUPPORTED_NODES) + point = Point(a, b) + is_leaf_namedtuple = tree_is_leaf(point) + assert is_leaf_namedtuple is False + + # But individual tensors are leaves + is_leaf_tensor = tree_is_leaf(a) + assert is_leaf_tensor is True + + return a + b + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + instantiate_parametrized_tests(ReproTests) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 36093b042002e..083c8b1f93807 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -65,6 +65,7 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -380,6 +381,7 @@ "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, + "torch.utils._pytree.tree_is_leaf": PyTreeTreeIsLeafFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index ac0be3e5888be..439ce274b7ce6 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -65,6 +65,7 @@ NestedUserFunctionVariable, PolyfilledFunctionVariable, PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index cd345759956be..31ffe9813c3fd 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -66,6 +66,7 @@ DefaultsSource, GetItemSource, SkipGuardSource, + TorchSource, TypeSource, ) from ..utils import ( @@ -119,6 +120,13 @@ _spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() +@functools.lru_cache +def get_pytree_SUPPORTED_NODES_source(): + return AttrSource( + AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" + ) + + class FunctionSpec: def __init__(self, func: FunctionType): code = func.__code__ @@ -2758,3 +2766,60 @@ def call_function( type_source = AttrSource(CollectionsSource(), "namedtuple") return VariableTracker.build(tx, namedtuple, type_source) return VariableTracker.build(tx, python_type, source=type_source) + + +class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. + + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + When is_leaf is None (the common case), we can optimize by not tracing into the function. + When is_leaf is not None, we fall back to regular tracing since it requires executing user code. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # tree_is_leaf(tree, is_leaf=None) + if len(args) < 1 or len(args) > 2: + raise_type_error_exc( + tx, + f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + ) + + # Check if is_leaf parameter is provided + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) + if len(args) == 2: + is_leaf = args[1] + + if not ( + isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None + ): + return super().call_function(tx, args, kwargs) + + # Optimize the case where is_leaf is None + # return _get_node_type(tree) not in SUPPORTED_NODES + tree = args[0] + node_type_var = PyTreeGetNodeTypeFunctionVariable( + torch.utils._pytree._get_node_type + ).call_function(tx, [tree], {}) + + # If the SUPPORTED_NODES was seen earlier and mutated, there would be a + # source and that will give us the mutated SUPPORTED_NODES. + supported_nodes_var = VariableTracker.build( + tx, + torch.utils._pytree.SUPPORTED_NODES, + source=get_pytree_SUPPORTED_NODES_source(), + ) + out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) + return ConstantVariable.create(not out.value) From a9184a03c8686caf3c8105bb104a70bfe17f1f5e Mon Sep 17 00:00:00 2001 From: mori360 Date: Sat, 22 Nov 2025 03:43:54 +0000 Subject: [PATCH 201/230] [DTensor] update redistribute_cost, add disable_graph_based_transform (#166747) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes https://github.com/pytorch/pytorch/issues/157585 ### Background DTensor has 2 different utilities * redistribute_cost() API is used to compute a cost value for each possible redistribution when generating a sharding strategy. DTensor dispatch then minimizes over these cost values to select the strategy before performing a redistribution * redistribute_tensor APIs are given a src/target placement, but need to find the fastest set of operations to perform this redistribution. This logic uses different cost modeling than `redistribute_cost` The redistribute_tensor APIs use a standardized representation (TransformInfo) for capturing a sequence of collectives that transform a placement from src to dest. These TransformInfos are the same representation whether they are generated using a **greedy** algorithm or a **graph-based / min-cost algorithm**. ### This PR We propose to make redistribute_cost generate the same set of transform_infos (either greedily or using min-cost) when estimating costs. This ensures that the cost estimated during redistribute_cost will always match the actual behavior when it comes time to redistribute. The current redistribute_cost does not take device order into consideration, Thus as the case post in the issue, that the estimations for cost S(0)S(0) -> S(0)R or RS(0) are the same with consideration of S(0) -> R -- namely, redistribute_cost underestimates this case by ignoring shard ordering, while transforminfo based estimation correctly models the more expensive sequence of collectives that would actually happen. In this PR, we use _TransformInfo here with the shortest path and accumulate costs, and modify comm_byte in different collectives Add `_FORCE_MIN_COST_REDISTRIBUTION_PLAN` as global config with highest priority to disable graph based approach to gen transform info. We add a flag and a contextmanager for opting into using the min-cost transforminfo. ### Experiments We show that using TransformInfos as part of redistribute_cost slows down mm_strategy by 50%, compared to the exponential slowdown seen when using the min-cost algorithm. Here are the time cost with device dims on different cost estimator methods, 1. Non-TransformInfo, which is the current method before PR 2. gen_transform_info There are around 50% extra cost on 2 Screenshot 2025-11-18 at 4 20 10 PM With traces, we can find the difference is mainly from gen_transform_info Screenshot 2025-11-18 at 3 44 17 PM Screenshot 2025-11-18 at 3 44 11 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/166747 Approved by: https://github.com/wconstab --- .../tensor/debug/test_debug_mode.py | 16 +--- test/distributed/tensor/test_op_strategy.py | 31 +++++++ test/distributed/tensor/test_redistribute.py | 74 +++++++++++++++ torch/distributed/tensor/_collective_utils.py | 55 +++++++---- torch/distributed/tensor/_redistribute.py | 93 +++++++++++++++++-- 5 files changed, 232 insertions(+), 37 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 5d4db74b6a929..c0625d37c6dad 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -215,12 +215,8 @@ def test_debug_mode_densor_redistribution_trace(self): debug_mode.debug_string(), """\ aten::mm(dt: f32[128, 8]| S(0)[0]S(0)[1], dt: f32[8, 128]| S(1)[0]S(1)[1]) - redistribute_input(0, S(0)[0]S(0)[1] -> S(0)R) - redistribute_input(t: f32[16, 8], trace: S(0)[0]S(0)[1]->S(0)R) - _c10d_functional::all_gather_into_tensor(t: f32[16, 8], 2, 3) - _c10d_functional::wait_tensor(t: f32[32, 8]) - redistribute_input(1, S(1)[0]S(1)[1] -> RS(1)) - redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR->RS(1)) + redistribute_input(1, S(1)[0]S(1)[1] -> RR) + redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR) _c10d_functional::all_gather_into_tensor(t: f32[8, 16], 2, 3) _c10d_functional::wait_tensor(t: f32[16, 16]) aten::chunk(t: f32[16, 16], 2) @@ -229,11 +225,9 @@ def test_debug_mode_densor_redistribution_trace(self): _c10d_functional::wait_tensor(t: f32[32, 32]) aten::chunk(t: f32[32, 32], 4) aten::cat(['t: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]'], 1) - aten::chunk(t: f32[8, 128], 2, 1) - aten::clone(t: f32[8, 64]) - aten::mm(t: f32[32, 8], t: f32[8, 64]) - aten::sum(dt: f32[128, 128]| S(0)S(1)) - aten::sum(t: f32[32, 64])""", + aten::mm(t: f32[16, 8], t: f32[8, 128]) + aten::sum(dt: f32[128, 128]| S(0)[0]S(0)[1]) + aten::sum(t: f32[16, 128])""", ) def test_debug_mode_einsum(self): diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 72d95efcfa8c9..42c4ccf122fd9 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -382,6 +382,37 @@ def test_bmm_strategies(self): ) self.assertFalse(output_sharding.needs_redistribute) + def test_redistribute_cost_with_order(self): + mesh_2d = DeviceMesh( + self.device_type, torch.arange(self.world_size).reshape(2, 2) + ) + + # Source: Shard on dim 0 across all three mesh dimensions + source_placement = (Shard(0), Shard(0)) + + # Target: Replicate on first mesh dimension, shard on others + # This requires 2 allgathers, one on dim=0 and one on dim=1 + replicate_mesh_dim0 = (Replicate(), Shard(0)) + + # Target: Replicate on second mesh dimension, shard on others + # This requires 1 allgather on dim=1 + replicate_mesh_dim1 = (Shard(0), Replicate()) + + global_tensor = torch.randn(4, 4) + global_tensor_meta = extract_tensor_meta(global_tensor) + + source_spec = DTensorSpec(mesh_2d, source_placement, global_tensor_meta) + target_spec_dim0 = DTensorSpec(mesh_2d, replicate_mesh_dim0, global_tensor_meta) + target_spec_dim1 = DTensorSpec(mesh_2d, replicate_mesh_dim1, global_tensor_meta) + + # Calculate costs for allgather on each mesh dimension + cost_mesh_dim0 = redistribute_cost(source_spec, target_spec_dim0) + cost_mesh_dim1 = redistribute_cost(source_spec, target_spec_dim1) + + # Cost increases with earlier mesh dimensions due to the way + # mesh dimensions are ordered (outer to inner in device hierarchy) + self.assertGreater(cost_mesh_dim0, cost_mesh_dim1) + # -------------Test op strategy registration------------- # custom op without List[Tensor] as input diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index ec1d69e9b02e6..ebb2c5f01668f 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -21,6 +21,10 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + use_min_cost_redistribution_plan, +) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -880,6 +884,76 @@ def test_ordered_redistribute(self): ) self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) + @with_comms + def test_force_min_cost_redistribution_plan(self): + """ + Test that the disable_graph_based_transform context manager correctly controls + the redistribution algorithm selection (graph-based vs greedy). + """ + # Set deterministic seed for reproducible tensor generation + torch.manual_seed(21) + mesh = init_device_mesh(self.device_type, (2, 2, 2)) + input_data = torch.randn((8, 8, 8), device=self.device_type) + + # the redistribution path differs if we use graph-based or greedy search solution + src_placement, src_order = ( + [Shard(0), Shard(0), Shard(0)], # All mesh dims shard tensor dim 0 + ( + ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1, 2)), + ), # Device order: 0→1→2 + ) + dst_placement, dst_order = ( + [Shard(1), Shard(1), Shard(1)], # All mesh dims shard tensor dim 1 + ( + ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1, 2)), + ), # Device order: 0→1→2 + ) + + # Test both graph-based (enable_graph=True) and greedy (enable_graph=False) algorithms + for idx, enable_graph in enumerate([True, False]): + sharded_dt = _distribute_tensor( + input_data.clone(), mesh, src_placement, shard_order=src_order + ) + + with ( + use_min_cost_redistribution_plan(enabled=enable_graph), + DebugMode(record_torchfunction=False) as debug_mode, + ): + sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order) + trace_str = self._extract_redistribute_trace_from_debug_mode( + debug_mode.debug_string() + ) + + # Validate graph-based algorithm trace (idx=0, disable_graph=False) + # Graph-based uses optimal path search (Dijkstra's algorithm) + # Expected path has 6 transformations with strategic intermediate states + # Path: S(0)[0,1,2] → S(0)[0,1]S(2) → S(0)S(2)[1,0] → + # S(1)S(2)[1,0] → S(1)[0,1]S(2) → S(1)[0,1,2] + if idx == 0: + self.assertExpectedInline( + trace_str, + """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]S(2)->S(0)S(2)[1]S(2)[0]->S(1)S(2)[1]S(2)[0]->S(1)[0]S(1)[1]S(2)->S(1)[0]S(1)[1]S(1)[2]""", + ) + # Validate greedy algorithm trace (idx=1, disable_graph=True) + # Greedy uses simple heuristic approach (processes mesh dims sequentially) + # Expected path has 6 transformations but with different intermediate states + # Path: S(0)[0,1,2] → S(0)[0,1]R → S(0)RR → + # S(1)RR → S(1)[0,1]R → S(1)[0,1,2] + elif idx == 1: + self.assertExpectedInline( + trace_str, + """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]R->S(0)RR->S(1)RR->S(1)[0]S(1)[1]R->S(1)[0]S(1)[1]S(1)[2]""", + ) + expected_dt = _distribute_tensor( + input_data.clone(), mesh, dst_placement, shard_order=dst_order + ) + self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) + + # Clear the transformation cache between iterations. Without this, + # the second iteration would use cached paths from the first, + # causing the trace validation to fail because: + _gen_transform_infos.cache_clear() + @with_comms def test_generate_shard_orders(self): """Check if `generate_shard_orders` generates unique sharding combinations""" diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index dff426a6d5e5a..90f32efafd395 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -227,6 +227,7 @@ def check_tensor_meta( return None +# TODO: autoparallel depends on this function, we will keep it until we update autoparallel redistribute_cost def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: assert spec.tensor_meta is not None, "spec should have tensor meta defined!" return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) @@ -338,39 +339,61 @@ def redistribute_cost( mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) cost = 0.0 - comm_bytes_gb = ( - spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 - ) # Transformation that considered for redistribute cost: # 1. allgather 2. alltoall # 3. allreduce 4. reduce_scatter - for i, (current, target) in enumerate( - zip(current_spec.placements, target_spec.placements) - ): + from torch.distributed._functional_collectives import _are_we_tracing + from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + _gen_transform_infos_non_cached, + ) + + # No redistribution needed when placements are already identical. + # This also prevents potential failures in _gen_transform_infos for certain configurations + # (e.g., sub-meshes) where finding a transform path between identical states may error out. + # TODO(zpcore): test placements with _StridedShard. + if current_spec.placements == target_spec.placements: + return cost + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) + for transform_info in transform_infos: + assert current_spec.tensor_meta is not None, ( + "spec should have tensor meta defined!" + ) + comm_bytes_gb = ( + current_spec.tensor_meta.dtype.itemsize + * math.prod(transform_info.logical_shape) + / 1024 + / 1024 + / 1024 + ) + current = transform_info.src_dst_placements[0] + target = transform_info.src_dst_placements[1] if current == target: continue - - num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] + mesh_dim = transform_info.mesh_dim + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] if current.is_shard() and target.is_replicate(): - # allgather gives larger comm bytes - comm_bytes_gb *= num_devices_on_mesh_dim # add up allgather comm cost - cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) elif current.is_shard() and target.is_shard(): - # should be alltoall comm, since we haven't implement it yet, add penalty + # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty # to favor allgather instead - cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 + # TODO: add alltoall_cost + comm_bytes_gb /= num_devices_on_mesh_dim + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0 elif current.is_partial() and target.is_replicate(): # add up allreduce comm cost - cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) + cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) elif current.is_partial() and target.is_shard(): # add up reduce_scatter comm cost - cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) # after reduce_scatter the comm bytes for further collectives halved. comm_bytes_gb /= num_devices_on_mesh_dim elif current.is_shard() and target.is_partial(): # ban shard -> partial as it does not make sense to perform # this redistribute return float("inf") - return cost diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index a407ba6ca91df..84e58c4df169c 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -32,6 +32,72 @@ logger = logging.getLogger(__name__) +# Global configuration flag to control the redistribution planning strategy. +# When True, forces the graph-based algorithm using Dijkstra's shortest path. +# When False, prefers the greedy algorithm for faster planning. Uses the graph-based algorithm +# only when necessary to support strided-shard redistribution +_FORCE_MIN_COST_REDISTRIBUTION_PLAN: Optional[bool] = None + + +@contextlib.contextmanager +def use_min_cost_redistribution_plan(enabled: bool = True): + """ + Context manager to control the redistribution planning strategy for DTensor operations. + + This context manager allows you to choose between two algorithms for computing the + sequence of collective operations needed to redistribute a DTensor from one placement + to another: + + - **Graph-based**: Uses Dijkstra's algorithm to find the minimum-cost path + through all possible placement transformations. This approach considers the global + cost of all collective operations and finds the optimal sequence. Best for complex + redistribution patterns where reducing communication cost and memory overhead is critical. + + - **Greedy**: Uses a heuristic approach that makes locally optimal choices + at each step. This is faster to compute but may not produce the globally optimal + transformation sequence. Best for simple redistribution patterns or when planning + speed is more important than optimal communication. + + **Default Behavior (without this context manager):** + + When this context manager is NOT used, the algorithm selection follows this priority: + + 1. **Non-default shard orders** + → Always use graph-based algorithm (required for correctness) + + 2. **Explicit `use_graph_based_transform` parameter** to `_gen_transform_infos_non_cached` + → Use the specified algorithm (True = graph-based, False = greedy) + + 3. **No explicit parameter** (default case) + → Use greedy algorithm for faster planning + + **Behavior with this context manager:** + + This context manager overrides the default selection by setting the global flag + `_FORCE_MIN_COST_REDISTRIBUTION_PLAN`, which takes precedence over the explicit + `use_graph_based_transform` parameter (but not over non-default shard order requirements). + + **Cache Considerations:** + + The redistribution planner caches transform info for performance via the `@cache` + decorator on `_gen_transform_infos`. If you need to change the algorithm selection + for the same input specs, clear the cache using `_gen_transform_infos.cache_clear()` + to ensure the new setting takes effect and doesn't reuse cached results from a + previous run. + + Args: + enabled (bool): If True, forces the use of the graph-based algorithm. + If False, forces the use of the greedy algorithm. + Default: True + """ + global _FORCE_MIN_COST_REDISTRIBUTION_PLAN + old_value = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = enabled + try: + yield + finally: + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = old_value + class _TransformInfo(NamedTuple): mesh_dim: int @@ -648,22 +714,29 @@ def _gen_transform_infos_non_cached( dst_spec: DTensorSpec, use_graph_based_transform: Optional[bool] = None, ) -> list[_TransformInfo]: - transform_infos: list[_TransformInfo] = [] device_mesh = src_spec.device_mesh src_shard_order = src_spec.shard_order dst_shard_order = dst_spec.shard_order # DTensorSpec should automatically generate shard_order, and it can be () if # no shard. assert src_shard_order is not None and dst_shard_order is not None - if use_graph_based_transform is None: - if all( - DTensorSpec.is_default_device_order(order) - for order in (src_shard_order, dst_shard_order) - ): - use_graph_based_transform = False - else: - # switch to graph search algorithm if the device order is not the default - use_graph_based_transform = True + + # Determine which transform strategy to use: + # 1. Non-standard device order → always use graph-based + # 2. Global flag or explicit parameter True → use graph-based + # 3. Otherwise → use greedy + has_non_default_order = not all( + DTensorSpec.is_default_device_order(order) + for order in (src_shard_order, dst_shard_order) + ) + + if has_non_default_order is True: + use_graph_based_transform = True + elif _FORCE_MIN_COST_REDISTRIBUTION_PLAN is not None: + use_graph_based_transform = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + elif use_graph_based_transform is None: + use_graph_based_transform = False + drp = get_redistribute_planner(device_mesh, len(src_spec.shape)) if use_graph_based_transform: transform_infos = drp.generate_graph_based_transform_infos( From 4909fd89dcd0016b367e2ab4ff7f594d5e7f1b9e Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 21 Nov 2025 21:23:33 +0000 Subject: [PATCH 202/230] Move CUDAEvent to c10 (#158219) # Motivation When I refactored the caching allocator, I noticed that there are two separate pieces of code of `EventPool` : one in [aten/cuda/CachingHostAllocator.cpp](https://github.com/pytorch/pytorch/blob/0f21fa84fb605c61482e4218df89f8bb1ef70c14/aten/src/ATen/cuda/CachingHostAllocator.cpp#L23) and another in [c10/cuda/CUDACachingAllocator](https://github.com/pytorch/pytorch/blob/0f21fa84fb605c61482e4218df89f8bb1ef70c14/c10/cuda/CUDACachingAllocator.cpp#L869). I would like to refactor these so that they share a single implementation. To achieve this, I have to move `aten/cuda/CUDAEvent.h` to `c10/cuda`, which I understand this is a big change. However, I think it makes sense conceptually - `CUDAStream` and `CUDAEvent` are both fundamental CUDA abstractions, and since `CUDAStream` is already in `c10/cuda`, placing `CUDAEvent` there as well seems reasonable for consistency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158219 Approved by: https://github.com/albanD --- aten/src/ATen/cuda/CUDAEvent.h | 254 +--------------- .../hip/impl/HIPEventMasqueradingAsCUDA.h | 86 ++++++ c10/cuda/CMakeLists.txt | 1 + c10/cuda/CUDAEvent.h | 278 ++++++++++++++++++ torch/utils/hipify/cuda_to_hip_mappings.py | 11 + 5 files changed, 381 insertions(+), 249 deletions(-) create mode 100644 aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h create mode 100644 c10/cuda/CUDAEvent.h diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 7a650b9cbcf35..73340604574ad 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -3,259 +3,15 @@ #include #include #include -#include +#include #include -#include -#include - -#include - -#include -#include - -/* -* `cudaEventExternal` is a torch-specific flag that is used to -* indicate that the CUDAEvent will be used only for synchronization -* with work outside of the cuda graph, rather than creation of -* cross-stream dependencies within a cuda graph. Resources: -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e -*/ -#define cudaEventExternal 0x08 namespace at::cuda { -/* -* CUDAEvents are movable not copyable wrappers around CUDA's events. -* -* CUDAEvents are constructed lazily when first recorded unless it is -* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this -* device is acquired from the first recording stream. However, if reconstructed -* from a handle, the device should be explicitly specified; or if ipc_handle() is -* called before the event is ever recorded, it will use the current device. -* Later streams that record the event must match this device. -*/ -struct TORCH_CUDA_CPP_API CUDAEvent { - // Constructors - // Default value for `flags` is specified below - it's cudaEventDisableTiming - CUDAEvent() noexcept = default; - CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} - - CUDAEvent( - DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) { - CUDAGuard guard(device_index_); - - AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); - is_created_ = true; - } - - // Note: event destruction done on creating device to avoid creating a - // CUDA context on other devices. - ~CUDAEvent() { - try { - if (is_created_) { - CUDAGuard guard(device_index_); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast(event_)); - } - AT_CUDA_CHECK(cudaEventDestroy(event_)); - } - } catch (...) { /* No throw */ } - } - - CUDAEvent(const CUDAEvent&) = delete; - CUDAEvent& operator=(const CUDAEvent&) = delete; - - CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } - CUDAEvent& operator=(CUDAEvent&& other) noexcept { - if (this != &other) { - moveHelper(std::move(other)); - } - return *this; - } - - operator cudaEvent_t() const { return event(); } - - // Less than operator (to allow use in sets) - friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { - return left.event_ < right.event_; - } - - std::optional device() const { - if (is_created_) { - return at::Device(at::kCUDA, device_index_); - } else { - return {}; - } - } - - bool isCreated() const { return is_created_; } - DeviceIndex device_index() const {return device_index_;} - cudaEvent_t event() const { return event_; } - - // Note: cudaEventQuery can be safely called from any device - bool query() const { - if (!is_created_) { - return true; - } - - cudaError_t err = cudaEventQuery(event_); - if (err == cudaSuccess) { - return true; - } else if (err != cudaErrorNotReady) { - C10_CUDA_CHECK(err); - } else { - // ignore and clear the error if not ready - (void)cudaGetLastError(); - } - - return false; - } - - void record() { record(getCurrentCUDAStream()); } - - void recordOnce(const CUDAStream& stream) { - if (!was_recorded_) record(stream); - } - - // Note: cudaEventRecord must be called on the same device as the event. - void record(const CUDAStream& stream) { - if (!is_created_) { - createEvent(stream.device_index()); - } - - TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, - " does not match recording stream's device ", stream.device_index(), "."); - CUDAGuard guard(device_index_); - -#ifndef USE_ROCM - // it is an error to use cudaEventRecordExternal when not doing stream capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault; - AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); -#else - AT_CUDA_CHECK(cudaEventRecord(event_, stream)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_record(at::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream()) - ); - } - was_recorded_ = true; - } - - // Note: cudaStreamWaitEvent must be called on the same device as the stream. - // The event has no actual GPU resources associated with it. - void block(const CUDAStream& stream) { - if (is_created_) { - CUDAGuard guard(stream.device_index()); -#ifndef USE_ROCM - // it is an error to use cudaEventWaitExternal when not doing stream capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; - AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); -#else - AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_wait(at::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream()) - ); - } - } - } - - // Note: cudaEventElapsedTime can be safely called from any device - float elapsed_time(const CUDAEvent& other) const { - TORCH_CHECK_VALUE( - !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), - "Both events must be created with argument 'enable_timing=True'."); - TORCH_CHECK_VALUE( - is_created_ && other.isCreated(), - "Both events must be recorded before calculating elapsed time."); - TORCH_CHECK( - query() && other.query(), - "Both events must be completed before calculating elapsed time."); - - float time_ms = 0; - // We do not strictly have to set the device index to the same as our event, - // but if we don't and the current device is not initialized, it will - // create a new cuda context, which will consume a lot of memory. - CUDAGuard guard(device_index_); - // raise cudaErrorNotReady if either event is recorded but not yet completed - AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); - return time_ms; - } - - // Note: cudaEventSynchronize can be safely called from any device - void synchronize() const { - if (is_created_) { - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast(event_)); - } - AT_CUDA_CHECK(cudaEventSynchronize(event_)); - } - } - - // Note: cudaIpcGetEventHandle must be called on the same device as the event - void ipc_handle(cudaIpcEventHandle_t * handle) { - if (!is_created_) { - // this CUDAEvent object was initially constructed from flags but event_ - // is not created yet. - createEvent(getCurrentCUDAStream().device_index()); - } - CUDAGuard guard(device_index_); - AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); - } - -private: - unsigned int flags_ = cudaEventDisableTiming; - bool is_created_ = false; - bool was_recorded_ = false; - bool external_ = false; - DeviceIndex device_index_ = -1; - cudaEvent_t event_{}; - - void createEvent(DeviceIndex device_index) { - external_ = (flags_ & cudaEventExternal) != 0; -#ifdef USE_ROCM - TORCH_CHECK(!external_, "External events are disallowed in rocm"); -#endif - flags_ &= ~cudaEventExternal; - device_index_ = device_index; - CUDAGuard guard(device_index_); - AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast(event_)); - } - is_created_ = true; - } - - void moveHelper(CUDAEvent&& other) { - // Transfer ownership of all state from other to this - flags_ = other.flags_; - is_created_ = other.is_created_; - was_recorded_ = other.was_recorded_; - external_ = other.external_; - device_index_ = other.device_index_; - event_ = other.event_; - - // Reset other to a valid empty state to prevent double-free - // The moved-from object must not attempt to destroy the event - other.is_created_ = false; - other.event_ = cudaEvent_t{}; - } -}; - -// EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate -// calls. cudaEventCreate when concurrently invoked from multiple threads can be -// very expensive (especially on certain device/driver combinations). +// EventPool - Thread-safe pool of CUDA events to avoid expensive +// cudaEventCreate calls. cudaEventCreate when concurrently invoked from +// multiple threads can be very expensive (especially on certain device/driver +// combinations). using CUDAEventPtr = std::unique_ptr>; diff --git a/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h new file mode 100644 index 0000000000000..f2741a32889fb --- /dev/null +++ b/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h @@ -0,0 +1,86 @@ +#pragma once + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// See Note [Masquerading as CUDA] for motivation + +struct HIPEventMasqueradingAsCUDA { + HIPEventMasqueradingAsCUDA() noexcept = default; + HIPEventMasqueradingAsCUDA(unsigned int flags) noexcept + : event_(HIPEvent(flags)) {} + HIPEventMasqueradingAsCUDA( + DeviceIndex device_index, + const hipIpcEventHandle_t* handle) + : event_(HIPEvent(device_index, handle)) {} + + ~HIPEventMasqueradingAsCUDA() = default; + + HIPEventMasqueradingAsCUDA(const HIPEventMasqueradingAsCUDA&) = delete; + HIPEventMasqueradingAsCUDA& operator=(const HIPEventMasqueradingAsCUDA&) = delete; + HIPEventMasqueradingAsCUDA(HIPEventMasqueradingAsCUDA&& other) noexcept = default; + HIPEventMasqueradingAsCUDA& operator=(HIPEventMasqueradingAsCUDA&& other) noexcept = default; + + operator hipEvent_t() const { + return event_.event(); + } + + // Less than operator (to allow use in sets) + friend bool operator<( + const HIPEventMasqueradingAsCUDA& left, + const HIPEventMasqueradingAsCUDA& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + // Unsafely coerce HIP device into CUDA device + return Device(c10::DeviceType::CUDA, event_.device_index()); + } + bool isCreated() const { + return event_.isCreated(); + } + DeviceIndex device_index() const { + return event_.device_index(); + } + hipEvent_t event() const { + return event_.event(); + } + bool query() const { + return event_.query(); + } + void record() { + return event_.record(); + } + + void recordOnce(const HIPStreamMasqueradingAsCUDA& stream) { + event_.recordOnce(stream.hip_stream()); + } + + void record(const HIPStreamMasqueradingAsCUDA& stream) { + event_.record(stream.hip_stream()); + } + + void block(const HIPStreamMasqueradingAsCUDA& stream) { + event_.block(stream.hip_stream()); + } + + float elapsed_time(const HIPEventMasqueradingAsCUDA& other) const { + return event_.elapsed_time(other.event_); + } + + void synchronize() const { + event_.synchronize(); + } + + void ipc_handle(hipIpcEventHandle_t* handle) { + event_.ipc_handle(handle); + } + + private: + HIPEvent event_; +}; + +}} // namespace c10::hip diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 2604f677858d1..fd80c45fcc79e 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -43,6 +43,7 @@ set(C10_CUDA_HEADERS CUDACachingAllocator.h CUDADeviceAssertionHost.h CUDAException.h + CUDAEvent.h CUDAFunctions.h CUDAGuard.h CUDAMacros.h diff --git a/c10/cuda/CUDAEvent.h b/c10/cuda/CUDAEvent.h new file mode 100644 index 0000000000000..6e5205044879f --- /dev/null +++ b/c10/cuda/CUDAEvent.h @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#include +#include + +/* + * `cudaEventExternal` is a torch-specific flag that is used to + * indicate that the CUDAEvent will be used only for synchronization + * with work outside of the cuda graph, rather than creation of + * cross-stream dependencies within a cuda graph. Resources: + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e + */ +#define cudaEventExternal 0x08 + +namespace c10::cuda { + +/* + * CUDAEvents are movable not copyable wrappers around CUDA's events. + * + * CUDAEvents are constructed lazily when first recorded unless it is + * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this + * device is acquired from the first recording stream. However, if reconstructed + * from a handle, the device should be explicitly specified; or if ipc_handle() + * is called before the event is ever recorded, it will use the current device. + * Later streams that record the event must match this device. + */ +struct CUDAEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + CUDAEvent() noexcept = default; + CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} + + CUDAEvent(DeviceIndex device_index, const cudaIpcEventHandle_t* handle) + : device_index_(device_index) { + CUDAGuard guard(device_index_); + + C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~CUDAEvent() { + if (is_created_) { + CUDAGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::kCUDA, reinterpret_cast(event_)); + } + C10_CUDA_CHECK_WARN(cudaEventDestroy(event_)); + } + } + + CUDAEvent(const CUDAEvent&) = delete; + CUDAEvent& operator=(const CUDAEvent&) = delete; + + CUDAEvent(CUDAEvent&& other) noexcept { + moveHelper(std::move(other)); + } + CUDAEvent& operator=(CUDAEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator cudaEvent_t() const { + return event(); + } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + if (is_created_) { + return c10::Device(c10::kCUDA, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { + return is_created_; + } + DeviceIndex device_index() const { + return device_index_; + } + cudaEvent_t event() const { + return event_; + } + + // Note: cudaEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + cudaError_t err = cudaEventQuery(event_); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void record() { + record(getCurrentCUDAStream()); + } + + void recordOnce(const CUDAStream& stream) { + if (!was_recorded_) + record(stream); + } + + // Note: cudaEventRecord must be called on the same device as the event. + void record(const CUDAStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream + // capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != + c10::cuda::CaptureStatus::None && + external_) + ? cudaEventRecordExternal + : cudaEventRecordDefault; + C10_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else + C10_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream())); + } + was_recorded_ = true; + } + + // Note: cudaStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. + void block(const CUDAStream& stream) { + if (is_created_) { + CUDAGuard guard(stream.device_index()); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream + // capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != + c10::cuda::CaptureStatus::None && + external_) + ? cudaEventWaitExternal + : cudaEventWaitDefault; + C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream())); + } + } + } + + // Note: cudaEventElapsedTime can be safely called from any device + float elapsed_time(const CUDAEvent& other) const { + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && + !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is not initialized, it will + // create a new cuda context, which will consume a lot of memory. + CUDAGuard guard(device_index_); + // raise cudaErrorNotReady if either event is recorded but not yet completed + C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: cudaEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::kCUDA, reinterpret_cast(event_)); + } + C10_CUDA_CHECK(cudaEventSynchronize(event_)); + } + } + + // Note: cudaIpcGetEventHandle must be called on the same device as the event + void ipc_handle(cudaIpcEventHandle_t* handle) { + if (!is_created_) { + // this CUDAEvent object was initially constructed from flags but event_ + // is not created yet. + createEvent(getCurrentCUDAStream().device_index()); + } + CUDAGuard guard(device_index_); + C10_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); + } + + private: + unsigned int flags_ = cudaEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + bool external_ = false; + DeviceIndex device_index_ = -1; + cudaEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; + device_index_ = device_index; + CUDAGuard guard(device_index_); + C10_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + c10::kCUDA, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(CUDAEvent&& other) { + // Transfer ownership of all state from other to this + flags_ = other.flags_; + is_created_ = other.is_created_; + was_recorded_ = other.was_recorded_; + external_ = other.external_; + device_index_ = other.device_index_; + event_ = other.event_; + + // Reset other to a valid empty state to prevent double-free + // The moved-from object must not attempt to destroy the event + other.is_created_ = false; + other.event_ = cudaEvent_t{}; + } +}; + +} // namespace c10::cuda diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index fb7dc1c7cb7f0..18afecd18c9be 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -9231,6 +9231,8 @@ API_PYTORCH, ), ), + ("cuda::CUDAEvent", ("hip::HIPEventMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAEvent", ("HIPEventMasqueradingAsCUDA", API_PYTORCH)), ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ( @@ -9285,6 +9287,14 @@ "c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), ), + ( + "ATen/cuda/CUDAEvent.h", # To keep BC, we have to keep this mapping + ("ATen/hip/HIPEvent.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAEvent.h", + ("ATen/hip/impl/HIPEventMasqueradingAsCUDA.h", API_PYTORCH), + ), ( "c10/cuda/CUDAStream.h", ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), @@ -9425,6 +9435,7 @@ ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), + ("c10/cuda/CUDAEvent.h", ("c10/hip/HIPEvent.h", API_C10)), ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), From 9c5d972a10228338b69db1beffdb9629d844a64f Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Sat, 22 Nov 2025 05:02:59 +0000 Subject: [PATCH 203/230] [NativeRT] Fix out_t index handling in TritonKernel (#168384) Summary: I think this is a quick indexing bug fix. Differential Revision: D87673889 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168384 Approved by: https://github.com/XueningXu --- torch/nativert/kernels/TritonKernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nativert/kernels/TritonKernel.cpp b/torch/nativert/kernels/TritonKernel.cpp index 081c81f7c646b..11dd671f8fbe6 100644 --- a/torch/nativert/kernels/TritonKernel.cpp +++ b/torch/nativert/kernels/TritonKernel.cpp @@ -167,8 +167,8 @@ void TritonKernel::computeInternal(ExecutionFrame& executionFrame) const { // todo: check if this is redundant auto out_t = out.toTensorList(); - for (const auto& i : output_indices_) { - out_t[i] = input(i, executionFrame).toTensor(); + for (const auto i : c10::irange(output_indices_.size())) { + out_t[i] = input(output_indices_[i], executionFrame).toTensor(); } } From 112a4fab3c656767bf505b9ec96b2b99deb5ba87 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 21 Nov 2025 12:44:17 -0800 Subject: [PATCH 204/230] Add string support for ABI stable custom ops (#168370) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168370 Approved by: https://github.com/albanD --- .../csrc/my_string_op.cpp | 32 ++++++++++++++ .../libtorch_agnostic_2_10/ops.py | 19 ++++++++ test/cpp_extensions/test_libtorch_agnostic.py | 25 +++++++++++ torch/csrc/shim_common.cpp | 39 ++++++++++++++++ torch/csrc/stable/c/shim.h | 14 ++++++ torch/csrc/stable/library.h | 5 +++ torch/csrc/stable/stableivalue_conversions.h | 44 ++++++++++++++++++- 7 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp new file mode 100644 index 0000000000000..1b97f60882b0f --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +#include +#include + +using torch::stable::Tensor; + +std::tuple, int64_t> my_string_op(Tensor t, std::string_view accessor, std::string passthru) { + int64_t res; + if (accessor == "dim") { + res = t.dim(); + } else if (accessor == "size") { + res = t.size(0); + } else if (accessor == "stride") { + res = t.stride(0); + } else { + STD_TORCH_CHECK(false, "Unsupported accessor value: ", std::string(accessor).c_str()) + } + + auto vec = std::vector({std::string(accessor), std::to_string(res), passthru}); + return std::make_tuple(vec, res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_string_op(Tensor t, str accessor, str passthru) -> (str[], int)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_string_op", TORCH_BOX(&my_string_op)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index 102e22e668cdf..b68839dc565c7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -246,3 +246,22 @@ def my_get_curr_cuda_blas_handle() -> int: Return the current cuBlasHandle_t pointer value. """ return torch.ops.libtorch_agnostic_2_10.my_get_curr_cuda_blas_handle.default() + + +def my_string_op(t, accessor, passthru) -> tuple[list[str], int]: + """ + The purpose of this op is to test inputting and outputting strings in a + stable custom op. This particular op takes in a Tensor, a string denoting + which tensor metadata API to call, and a pass through string to return a + string list and the value of the tensor metadata. + + If accessor is "size" or "stride", query along the 0th dim. + + Args: + t: Tensor - input tensor to query + accessor: str - which property to access ("dim", "size", or "stride") + passthru: str - a string that gets returned as the last element of the list + + Returns: tuple - (list of [accessor, value, passthru] as strings, value) + """ + return torch.ops.libtorch_agnostic_2_10.my_string_op.default(t, accessor, passthru) diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 55681a45e4445..dfb9b6b37f593 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -834,6 +834,31 @@ def test_my_get_curr_cuda_blas_handle(self, device): expected = torch.cuda.current_blas_handle() self.assertEqual(res, expected) + @skipIfTorchVersionLessThan(2, 10) + def test_my_string_op(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(3, 4, 5, device=device) + + dim_vec, result_dim = libtorch_agnostic.ops.my_string_op(t, "dim", "ice") + self.assertEqual(dim_vec, ["dim", str(t.dim()), "ice"]) + self.assertEqual(result_dim, t.dim()) + + size_vec, result_size = libtorch_agnostic.ops.my_string_op( + t, "size", "cream" + ) + self.assertEqual(size_vec, ["size", str(t.size(0)), "cream"]) + self.assertEqual(result_size, t.size(0)) + + stride_vec, result_stride = libtorch_agnostic.ops.my_string_op( + t, "stride", "cake" + ) + self.assertEqual(stride_vec, ["stride", str(t.stride(0)), "cake"]) + self.assertEqual(result_stride, t.stride(0)) + + with self.assertRaisesRegex(RuntimeError, "Unsupported accessor value: "): + libtorch_agnostic.ops.my_string_op(t, "invalid", "") + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index ffbb7bb1235a7..3a437fa78229e 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -145,6 +145,10 @@ static StableIValue from_ivalue( list_pointer_to_list_handle(stableivalue_list.release()), extension_build_version); } + case c10::TypeKind::StringType: { + return torch::stable::detail::_from( + ivalue.toStringRef(), extension_build_version); + } default: { TORCH_CHECK( false, @@ -251,6 +255,10 @@ static c10::IValue to_ivalue( TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return ivalue_list; } + case c10::TypeKind::StringType: { + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); + } default: { TORCH_CHECK( false, @@ -578,3 +586,34 @@ torch_get_mutable_data_ptr(AtenTensorHandle tensor, void** ret_data_ptr) { *ret_data_ptr = t->mutable_data_ptr(); }); } + +AOTI_TORCH_EXPORT AOTITorchError +torch_new_string_handle(const char* data, size_t length, StringHandle* handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = new std::string(data, length); + *handle = reinterpret_cast(str_ptr); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError torch_delete_string(StringHandle handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + delete str_ptr; + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_length(StringHandle handle, size_t* length) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + *length = str_ptr->length(); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_c_str(StringHandle handle, const char** data) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + *data = str_ptr->c_str(); + }); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 83bdfd59096fe..202ca3ba40c05 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -103,6 +103,20 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr( const void** ret_data_ptr // returns borrowed reference ); +struct StringOpaque; +using StringHandle = StringOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError +torch_new_string_handle(const char* data, size_t length, StringHandle* handle); + +AOTI_TORCH_EXPORT AOTITorchError torch_delete_string(StringHandle handle); + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_length(StringHandle handle, size_t* length); + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_c_str(StringHandle handle, const char** data); + #ifdef USE_CUDA AOTI_TORCH_EXPORT AOTITorchError diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index dc36c4d182478..ac6d252f757a1 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -131,6 +131,11 @@ struct UnboxType> { using type = std::vector; }; +template <> +struct UnboxType { + using type = std::string; +}; + template using unbox_type_t = typename UnboxType::type; diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index ed885fbe03a12..c44e656d88e11 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -64,6 +64,9 @@ struct FromImpl { static_assert( !is_std_vector_v, "std::vector requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert( + !std::is_same_v, + "std::string requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -395,6 +398,21 @@ struct FromImpl { } }; +// Specialization for std::string, which should return a new owning reference of +// the string +template <> +struct FromImpl { + static StableIValue call( + const std::string& val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + StringHandle handle; + TORCH_ERROR_CODE_CHECK( + torch_new_string_handle(val.c_str(), val.length(), &handle)) + return from(handle); + } +}; + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 // ============================================================================= @@ -408,7 +426,6 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - static_assert(std::is_trivially_copyable_v); // Ensure 2.10+ types don't accidentally use the base case - provide clear // compile-time errors. static_assert( @@ -420,6 +437,10 @@ struct ToImpl { static_assert( !is_std_vector_v, "std::vector requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert( + !std::is_same_v, + "std::string requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. (For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the // destination. So, we can use a union to work around this lack of @@ -706,6 +727,27 @@ struct ToImpl { } }; +// Specialization for std::string +// Returns a new std::string; the string in val is deleted. +template <> +struct ToImpl { + static std::string call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + StringHandle handle = to(val); + size_t length; + TORCH_ERROR_CODE_CHECK(torch_string_length(handle, &length)); + const char* data; + TORCH_ERROR_CODE_CHECK(torch_string_c_str(handle, &data)); + auto strptr = new std::string(data, length); + + // delete the old string before returning new string + TORCH_ERROR_CODE_CHECK(torch_delete_string(handle)); + return *strptr; + } +}; + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 // ============================================================================= From 9301432fb66a6a8280a2c405e057438d6344226a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 21 Nov 2025 11:35:25 -0800 Subject: [PATCH 205/230] Fix lints with newer triton (#168340) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168340 Approved by: https://github.com/anijain2305, https://github.com/zou3519 --- torch/_higher_order_ops/triton_kernel_wrap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 0e398897a7eab..628c889f6cbc7 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -382,14 +382,14 @@ def _get_specialization(args): # type: ignore[no-untyped-def] try: # Latest versions of Triton take specialize_extra as an arg to create_specialize_impl specialize_impl = triton.runtime.jit.create_specialize_impl( - specialize_extra=backend.get_arg_specialization + specialize_extra=backend.get_arg_specialization # pyrefly: ignore [missing-attribute] ) except TypeError: # Unknown arg `specialize_extra` # Older versions of Triton take specialize_extra as an arg to specialize_impl specialize_impl = functools.partial( # pyrefly: ignore # missing-argument triton.runtime.jit.create_specialize_impl(), - specialize_extra=backend.get_arg_specialization, + specialize_extra=backend.get_arg_specialization, # pyrefly: ignore [missing-attribute] ) # create_specialize_impl is removed in https://github.com/triton-lang/triton/pull/7771 # switch to native_specialize_impl instead @@ -413,7 +413,7 @@ def _native_specialize_impl( specialize_impl = functools.partial( specialize_impl_orig, - specialize_extra=backend.get_arg_specialization, + specialize_extra=backend.get_arg_specialization, # pyrefly: ignore [missing-attribute] ) from triton._utils import find_paths_if, get_iterable_path From b565593c62975df4e5086c675d2854bd70eb7d25 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 21 Nov 2025 11:35:25 -0800 Subject: [PATCH 206/230] [dynamo] Add optree.tree_map microbenchmark (#168341) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168341 Approved by: https://github.com/anijain2305 ghstack dependencies: #168340 --- benchmarks/dynamo/microbenchmarks/.gitignore | 1 + .../dynamo/microbenchmarks/optree_tree_map.py | 121 ++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 benchmarks/dynamo/microbenchmarks/.gitignore create mode 100644 benchmarks/dynamo/microbenchmarks/optree_tree_map.py diff --git a/benchmarks/dynamo/microbenchmarks/.gitignore b/benchmarks/dynamo/microbenchmarks/.gitignore new file mode 100644 index 0000000000000..c627000badbf8 --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/.gitignore @@ -0,0 +1 @@ +*.prof diff --git a/benchmarks/dynamo/microbenchmarks/optree_tree_map.py b/benchmarks/dynamo/microbenchmarks/optree_tree_map.py new file mode 100644 index 0000000000000..6421bd900e663 --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/optree_tree_map.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +import argparse +import time +from pathlib import Path + +import optree + +import torch +import torch._dynamo +from torch._dynamo.debug_utils import profile_to_file + + +PROFILE_PATH = Path(__file__).with_name("optree_tree_map.prof") + + +def make_tensor_tree(depth: int, branching_factor: int, tensor_size: int, device: str): + """Create a moderately deep pytree populated with tensors.""" + + def _make_level(level: int): + if level == 0: + return torch.randn(tensor_size, tensor_size, device=device) + + children = [_make_level(level - 1) for _ in range(branching_factor)] + return { + "tensor": torch.randn(tensor_size, tensor_size, device=device), + "list": list(children), + "tuple": tuple(children), + } + + return _make_level(depth) + + +def add_leaf(lhs: torch.Tensor, *rest: torch.Tensor) -> torch.Tensor: + out = lhs + for other in rest: + out = out + other + return out + + +def optree_tree_map_loop(lhs, rhs, loop_iters): + tree = lhs + for _ in range(loop_iters): + tree = optree.tree_map( + add_leaf, + tree, + rhs, + namespace="torch", + ) + return tree + + +def _capture_compile_profile(args, lhs, rhs) -> None: + profile_path = Path(args.profile_out) + profile_path.parent.mkdir(parents=True, exist_ok=True) + + @profile_to_file(str(profile_path)) + def _run_compile() -> None: + torch._dynamo.reset() + compiled = torch.compile( + optree_tree_map_loop, + backend="eager", + fullgraph=True, + ) + compiled(lhs, rhs, args.loop_iters) + + print(f"Collecting compile-only cProfile at {profile_path}") + _run_compile() + + +def _parse_args(): + parser = argparse.ArgumentParser() + default_device = "cuda" if torch.cuda.is_available() else "cpu" + parser.add_argument("--device", default=default_device, help="Device to run on") + parser.add_argument( + "--loop-iters", + type=int, + default=50, + help="Number of tree_map calls per compiled invocation", + ) + parser.add_argument( + "--tree-depth", type=int, default=2, help="Depth of the constructed pytree" + ) + parser.add_argument( + "--branching-factor", + type=int, + default=2, + help="Branching factor for list/tuple nodes", + ) + parser.add_argument( + "--tensor-size", + type=int, + default=1, + help="Edge length for square tensor leaves", + ) + parser.add_argument( + "--profile-out", + default=str(PROFILE_PATH), + help="Destination .prof file for the compile-time cProfile", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + lhs = make_tensor_tree( + args.tree_depth, args.branching_factor, args.tensor_size, args.device + ) + rhs = make_tensor_tree( + args.tree_depth, args.branching_factor, args.tensor_size, args.device + ) + + t0 = time.perf_counter() + _capture_compile_profile(args, lhs, rhs) + t1 = time.perf_counter() + print(f"Took {t1 - t0:.1f}s") + + +if __name__ == "__main__": + main() From 322ad3099408adad63b9d07c914205c172ea31fb Mon Sep 17 00:00:00 2001 From: drisspg Date: Sat, 22 Nov 2025 02:13:30 +0000 Subject: [PATCH 207/230] [Flex] Fix symbolic shapes lowering (#168383) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168383 Approved by: https://github.com/laithsakka, https://github.com/bobrenjc93 ghstack dependencies: #168319 --- test/inductor/test_flex_attention.py | 45 ++++++++++++++++++++++++++++ torch/_inductor/subgraph_lowering.py | 1 - torch/fx/experimental/sym_node.py | 4 +-- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 84d179e2ca52b..c095243df7654 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -2278,6 +2278,51 @@ def test_shape(S, backend): test_shapes = [256, 255, 383, 384] _ = [test_shape(S, backend) for S in test_shapes] + @supported_platform + @skip_on_cpu + def test_mask_mod_handles_symint_addition(self, device): + dtype = torch.float16 + + def run(q, k, v): + ql = q.size(-2) + kl = k.size(-2) + frame = 32 + + def _opaque_mask(b, h, q_idx, kv_idx): + ref = ql // frame + mot = kl // frame + limit = (ref + mot) * frame + return q_idx < limit + + block_mask = create_block_mask( + _opaque_mask, + B=q.size(0), + H=q.size(1), + Q_LEN=ql, + KV_LEN=kl, + device=device, + ) + return flex_attention(q, k, v, block_mask=block_mask) + + compiled_run = torch.compile(run, fullgraph=True, dynamic=True) + + q = torch.randn(1, 2, 192, 32, device=device, dtype=dtype) + k = torch.randn(1, 2, 128, 32, device=device, dtype=dtype) + v = torch.randn(1, 2, 128, 32, device=device, dtype=dtype) + + eager_out = run(q, k, v) + compiled_out = compiled_run(q, k, v) + torch.testing.assert_close(eager_out, compiled_out, atol=1e-3, rtol=1e-3) + + # Exercise different dynamic shapes to ensure SymInt sums remain well-formed. + q2 = torch.randn(1, 2, 160, 32, device=device, dtype=dtype) + k2 = torch.randn(1, 2, 96, 32, device=device, dtype=dtype) + v2 = torch.randn(1, 2, 96, 32, device=device, dtype=dtype) + + eager_out2 = run(q2, k2, v2) + compiled_out2 = compiled_run(q2, k2, v2) + torch.testing.assert_close(eager_out2, compiled_out2, atol=1e-3, rtol=1e-3) + @supported_platform def test_multiple_score_mod_calls(self, device): query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device) diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index e0a87ebac3d87..aa1b4d2db025d 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -121,7 +121,6 @@ def call_function( raise SubgraphLoweringException( f"{target} not supported in subgraph, (missing lowering)" ) - return lowerings[target](*args, **kwargs) def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override] diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 16b975f6b069a..96b44b0aebd4d 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -959,7 +959,7 @@ def _bitwise_xor(a, b): reflectable_magic_methods = { - "add": _optimized_add, + "add": operator.add, "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, @@ -1398,7 +1398,7 @@ def binary_magic_impl(self, other): out = PythonMod(self.expr, other.expr) elif method == "add": # see Note [optimized_summation] - (optimized_summation, out) = func( + (optimized_summation, out) = _optimized_add( self.expr, other.expr, self._optimized_summation, From a9cb5bc90b59a23f38b1f9943af1616b9bb80ce4 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 22 Nov 2025 00:20:15 -0800 Subject: [PATCH 208/230] [user-streams] Move some estimator utilities outside of distributed (#168343) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168343 Approved by: https://github.com/williamwen42, https://github.com/sanketpurandare --- torch/distributed/_tools/runtime_estimator.py | 151 ++---------------- torch/utils/_runtime_estimation.py | 151 ++++++++++++++++++ 2 files changed, 160 insertions(+), 142 deletions(-) create mode 100644 torch/utils/_runtime_estimation.py diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index bee54e0454d5d..caf399cf6a802 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -1,6 +1,4 @@ # Owner(s): ["module: unknown"] -import math -import os from collections import defaultdict from typing import Any, TYPE_CHECKING from typing_extensions import Self @@ -8,150 +6,25 @@ import torch import torch.utils._pytree as pytree from torch._guards import active_fake_mode -from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._tools.mod_tracker import ModTracker from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.flop_counter import flop_registry +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + _VIEW_OPS, + get_compute_time, + get_transfer_time, +) if TYPE_CHECKING: from collections.abc import Callable - -aten = torch.ops.aten - -# This value is hard-coded here: -# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 -_PYTORCH_MIN_ALLOCATE = ( - 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 -) - -# No fall-back kernel needed/exists for view ops -_VIEW_OPS = { - aten.lift_fresh, - aten.t, - aten.transpose, - aten.view, - aten.detach, - aten._unsafe_view, - aten.split, - aten.adjoint, - aten.as_strided, - aten.diagonal, - aten.expand, - aten.expand_as, - aten.movedim, - aten.permute, - aten.select, - aten.squeeze, - aten.mT, - aten.mH, - aten.real, - aten.imag, - aten.view_as, - aten.unflatten, - aten.unfold, - aten.unbind, - aten.unsqueeze, - aten.vsplit, - aten.hsplit, - aten.split_with_sizes, - aten.swapaxes, - aten.swapdims, - aten.chunk, -} -# We can ignore benchmarking tensor create ops -_CREATE_OPS = { - aten.randint, - aten.randn, - aten.rand, - aten.randn_like, - aten.rand_like, - aten.randint_like, - aten.arange, - aten.ones_like, - aten.zeros_like, -} - -_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS - __all__ = ["RuntimeEstimator"] -def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] - """ - Estimates the compute time of an aten operator. - - Args: - func_packet: The operator overload packet. - args: The arguments to the operator. - kwargs: The keyword arguments to the operator. - out: The output of the operator. - out_dtypes: The output data types. - - Returns: - float: The estimated compute time in nanoseconds. - """ - if func_packet in flop_registry: - assert len(out_dtypes) == 1, ( - f"Only support single out dtype got {out_dtypes} for {func_packet}" - ) - dtype = out_dtypes.pop() - # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s - peak_gpu_flops = get_device_tflops(dtype) * 1e15 - # We can expect to achieve 75% of theoretical peak flops - factor = 0.75 - peak_empirical_flops = factor * peak_gpu_flops - flop_count_func = flop_registry[func_packet] - # We divide by a factor of 2 to get the MACs (multiply and accumulate) - flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 - # We multiply by 1e9 to get the time in nano seconds - compute_time = (flop_count / peak_empirical_flops) * 1e9 - return compute_time - return 0.0 - - -def get_num_bytes(t: torch.Tensor) -> int: - """ - Calculates the memory consumption of a tensor. - - Args: - t (torch.Tensor): The input tensor. - - Returns: - int: The memory consumption of the tensor in bytes. - """ - num_bytes = t.untyped_storage().nbytes() - mem_consumed = math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE - return mem_consumed - - -def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] - """ - Estimates the memory transfer time of input and output tensors. - - Args: - flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. - flat_outs (List[torch.Tensor]): The flat list of outputs. - - Returns: - float: The estimated memory transfer time in nanoseconds. - """ - gpu_memory_bandwidth = get_gpu_dram_gbps() - read_bytes = sum( - get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor) - ) - write_bytes = sum( - get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) - ) - counted_bytes = read_bytes + write_bytes - # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds - transfer_time = counted_bytes / gpu_memory_bandwidth - return transfer_time - - class RuntimeEstimator(TorchDispatchMode): """ Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. @@ -197,12 +70,6 @@ class RuntimeEstimator(TorchDispatchMode): runtime_estimator.display_modulewise_stats() """ - _float_types: set[torch.dtype] = { - torch.float16, - torch.bfloat16, - torch.float32, - torch.float64, - } _no_fallback_kernel: set[torch._ops._OpNamespace] = set() fake_mode: FakeTensorMode @@ -258,7 +125,7 @@ def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] def to_real_tensor(e): # type: ignore[no-untyped-def] if cls.fake_mode.is_our_fake(e): - if e.dtype in cls._float_types: + if e.dtype in _FLOAT_TYPES: out = torch.rand_like(e, device=e.fake_device) else: out = torch.ones_like(e, device=e.fake_device) @@ -405,7 +272,7 @@ def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: i out_dtypes = { t.dtype for t in flat_outs - if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES } args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) diff --git a/torch/utils/_runtime_estimation.py b/torch/utils/_runtime_estimation.py new file mode 100644 index 0000000000000..fcda7cceaee48 --- /dev/null +++ b/torch/utils/_runtime_estimation.py @@ -0,0 +1,151 @@ +import math +import os + +import torch +from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from torch.utils._ordered_set import OrderedSet + +from .flop_counter import flop_registry + + +aten = torch.ops.aten + +_FLOAT_TYPES = OrderedSet( + [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ] +) + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = OrderedSet( + [ + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, + ] +) +# We can ignore benchmarking tensor create ops +_CREATE_OPS = OrderedSet( + [ + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, + ] +) + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + + +def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert len(out_dtypes) == 1, ( + f"Only support single out dtype got {out_dtypes} for {func_packet}" + ) + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + +def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem_consumed + + +def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time From 1048ac941499c722407cdfab9eee64d3d5956a97 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Sat, 22 Nov 2025 14:57:36 +0000 Subject: [PATCH 209/230] Fix exit code condition for test_nan_assert (#167971) The test is skipped on a condition which needs to be used here or it will fail because the exit code is -6 not zero if the condition is not met and the test executed Fixes https://github.com/pytorch/pytorch/pull/154441 @kwen2501 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167971 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 512808757c40c..5b1b6c8925806 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -92,6 +92,9 @@ torch.version.cuda is not None or torch.version.hip is not None ) +CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( + torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 +) _start_time = time.time() _logger = logging.getLogger(__name__) @@ -345,7 +348,11 @@ def setUp(self): # These tests are expected to throw SIGABRT(6); # But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0. - TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else signal.SIGABRT + TEST_NAN_ASSERT_RETURN = ( + 0 + if (IS_SANDCASTLE and not (TEST_MULTIGPU and CUDA_12_AND_ABOVE)) + else signal.SIGABRT + ) self.special_return_code_checks = { self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN, self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN, @@ -537,10 +544,6 @@ def init_collective_task(t): # reset ENV os.environ["TORCH_NCCL_CUDA_EVENT_CACHE"] = "0" - CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( - torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 - ) - @requires_nccl() @skip_but_pass_in_sandcastle_if( # skip for cu126 as well due to https://github.com/pytorch/pytorch/issues/153479 From a3cc252e03572835c15afde54b81fc5e8616ad27 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Fri, 21 Nov 2025 22:27:30 -0800 Subject: [PATCH 210/230] [dynamo] Special case handling for tree_map (#168342) This is a ~20x speedup for benchmarks/dynamo/microbenchmarks/optree_tree_map.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/168342 Approved by: https://github.com/anijain2305 --- test/dynamo/test_higher_order_ops.py | 62 +++---- test/dynamo/test_tree_map.py | 259 +++++++++++++++++++++++++++ test/functorch/test_control_flow.py | 18 +- torch/_dynamo/variables/base.py | 80 +++++++++ torch/_dynamo/variables/constant.py | 58 ++++++ torch/_dynamo/variables/dicts.py | 54 +++++- torch/_dynamo/variables/functions.py | 45 +++++ torch/_dynamo/variables/lists.py | 44 +++++ torch/_dynamo/variables/tensor.py | 12 ++ 9 files changed, 591 insertions(+), 41 deletions(-) create mode 100644 test/dynamo/test_tree_map.py diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 4e2a292fc69d4..21398490e7b03 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -4304,15 +4304,15 @@ def forward(self, L_x_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None + child: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad); child = None set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - sin: "f32[5]" = child.sin(); child = None + sin: "f32[5]" = _wrap_for_grad.sin(); _wrap_for_grad = None primals_out: "f32[]" = sin.sum(); sin = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None @@ -4352,24 +4352,24 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_2: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad) set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - child_1: "f32[5]" = child.sin() - child_2: "f32[5]" = child.cos(); child = None + child: "f32[5]" = _wrap_for_grad.sin() + child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) - _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) + _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child, child_1], [child_2], [l_v_, l_v_], retain_graph = True, create_graph = True); child = child_1 = child_2 = l_v_ = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) """, @@ -4404,28 +4404,28 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_2: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad) set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - child_1: "f32[5]" = child.sin() - child_2: "f32[5]" = child.cos(); child = None + child: "f32[5]" = _wrap_for_grad.sin() + child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) - value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) + _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - child_4: "f32[5]" = l_v_.sin() + child_3: "f32[5]" = l_v_.sin() - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child, child_1], [child_2], [l_v_, child_3], retain_graph = True, create_graph = True); child = child_1 = child_2 = l_v_ = child_3 = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None - return (value, value_1, getitem) + return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) """, ) @@ -4458,18 +4458,18 @@ def forward(self, L_x_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + aux: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None + child: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(aux); child = None set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - sin: "f32[5]" = child.sin() + sin: "f32[5]" = aux.sin() primals_out: "f32[]" = sin.sum(); sin = None - aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None + aux_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = aux_1 = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -5014,11 +5014,11 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): aux: "f32[3, 3, 3]" = child.cos() _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None - child_2: "f32[3, 3, 3]" = _autograd_grad[0] - child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None + getitem: "f32[3, 3, 3]" = _autograd_grad[0] + getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None @@ -5058,11 +5058,11 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): aux: "f32[3, 3, 3]" = child.cos() _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None - child_2: "f32[3, 3, 3]" = _autograd_grad[0] - child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None + getitem: "f32[3, 3, 3]" = _autograd_grad[0] + getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None diff --git a/test/dynamo/test_tree_map.py b/test/dynamo/test_tree_map.py new file mode 100644 index 0000000000000..a7ade021b5acd --- /dev/null +++ b/test/dynamo/test_tree_map.py @@ -0,0 +1,259 @@ +# Owner(s): ["module: dynamo"] + +import optree + +import torch +import torch._dynamo +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.utils import _pytree as pytree + + +try: + import torch.utils._cxx_pytree as cxx_pytree +except ImportError: # pragma: no cover + cxx_pytree = None + + +def _tensor_leaf(*values): + first = values[0].clone() + for other in values[1:]: + first = first + other + return first + + +def _combine_leaves(*values): + first = values[0] + if isinstance(first, torch.Tensor): + return _tensor_leaf(*values) + if first is None: + return None + if isinstance(first, tuple): + # When tuples are marked as leaves, keep the structure from + # the leading tree so that specs remain aligned. + return first + total = first + for other in values[1:]: + total = total + other + return total + + +def _tuple_is_leaf(node): + return isinstance(node, tuple) + + +TREE_MAP_IMPLEMENTATIONS = [ + ("optree", optree.tree_map), + ("pytree_python", pytree.tree_map), +] +if cxx_pytree is not None: + TREE_MAP_IMPLEMENTATIONS.append(("pytree_cxx", cxx_pytree.tree_map)) + + +KWARG_CASES = [ + ("default", {}, None), + ("none_is_leaf", {"none_is_leaf": True}, {"optree"}), + ("is_leaf", {"is_leaf": _tuple_is_leaf}, None), + ("namespace", {"namespace": "torch"}, {"optree"}), + ( + "namespace_and_none_is_leaf", + {"namespace": "torch", "none_is_leaf": True}, + {"optree"}, + ), + ( + "namespace_none_is_leaf_predicate", + {"namespace": "torch", "none_is_leaf": True, "is_leaf": _tuple_is_leaf}, + {"optree"}, + ), +] + + +_NONE_IS_LEAF_UNSET = object() + + +def _build_tree(offset: int) -> dict[str, object]: + base = torch.arange(4, dtype=torch.float32).reshape(2, 2) + offset + nested = base + 5 + return { + "tensor": base, + "list": [ + base + 1, + { + "inner": base + 2, + "none": None, + }, + ], + "tuple": (3 + offset, (nested, None)), + "const_dict": {"leaf": base + 3}, + } + + +def _assert_trees_allclose(test_case: TestCase, ref, res) -> None: + ref_flat, ref_spec = pytree.tree_flatten(ref) + res_flat, res_spec = pytree.tree_flatten(res) + test_case.assertEqual(ref_spec, res_spec) + for expected, actual in zip(ref_flat, res_flat): + if isinstance(expected, torch.Tensor): + test_case.assertTrue(torch.allclose(expected, actual)) + else: + test_case.assertEqual(expected, actual) + + +@instantiate_parametrized_tests +class TreeMapCompileTests(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def _run_tree_map(self, tree_map_impl, kwargs): + lhs = _build_tree(0) + rhs = _build_tree(7) + + def fn(a, b): + return tree_map_impl(_combine_leaves, a, b, **kwargs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(lhs, rhs) + result = compiled(lhs, rhs) + _assert_trees_allclose(self, expected, result) + + @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize("kwargs_name,kwargs,allowed_impls", KWARG_CASES) + def test_tree_map_variants( + self, + tree_map_name: str, + tree_map_impl, + kwargs_name: str, + kwargs: dict, + allowed_impls, + ) -> None: + if tree_map_name == "pytree_cxx" and cxx_pytree is None: + self.skipTest("torch.utils._cxx_pytree is unavailable") + if allowed_impls is not None and tree_map_name not in allowed_impls: + self.skipTest("kwargs unsupported for implementation") + self._run_tree_map(tree_map_impl, kwargs) + + def test_tree_map_rejects_mismatched_container_types(self) -> None: + def fn(a, b): + return pytree.tree_map(lambda u, v: u + v, a, b) + + lhs = [torch.ones(2), torch.ones(2)] + rhs = (torch.ones(2), torch.ones(2)) + + with self.assertRaises(ValueError): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (ValueError, torch._dynamo.exc.Unsupported), + "Node type mismatch", + ): + compiled(lhs, rhs) + + def test_tree_map_is_leaf_handles_tensor_nodes(self) -> None: + def fn(tree): + return pytree.tree_map( + lambda pair: torch.stack(pair).sum(dim=0), + tree, + is_leaf=lambda node: isinstance(node, tuple), + ) + + tree = [(torch.ones(2), torch.ones(2) * 4)] + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_none_nodes_reject_mismatched_siblings(self) -> None: + def fn(a, b): + return optree.tree_map(lambda u, v: (u, v), a, b) + + lhs = {"k": None} + rhs = {"k": torch.ones(2)} + + with self.assertRaisesRegex(ValueError, "Expected None"): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (ValueError, torch._dynamo.exc.Unsupported), + r"(Expected None|expected )", + ): + compiled(lhs, rhs) + + @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + def test_tree_map_none_nodes_default_behavior( + self, tree_map_name: str, tree_map_impl + ) -> None: + if tree_map_name == "optree": + self.skipTest("optree treats None as an internal node by default") + + def fn(a, b): + return tree_map_impl(lambda u, v: (u, v), a, b) + + tree = {"k": None} + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree, tree) + result = compiled(tree, tree) + + self.assertEqual(result["k"], (None, None)) + self.assertEqual(result, expected) + + def test_constantvariable_handles_none_is_leaf_kwarg(self) -> None: + tree = {"none": None} + + def run_case(none_is_leaf_flag): + def fn(arg): + def mapper(node): + if node is None: + return "visited" + return node + + kwargs = {} + if none_is_leaf_flag is not _NONE_IS_LEAF_UNSET: + kwargs["none_is_leaf"] = none_is_leaf_flag + return optree.tree_map(mapper, arg, **kwargs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + self.assertEqual(result, expected) + return result["none"] + + self.assertEqual(run_case(_NONE_IS_LEAF_UNSET), None) + self.assertEqual(run_case(False), None) + self.assertEqual(run_case(True), "visited") + + def test_constantvariable_handles_python_and_dtype_leaves(self) -> None: + tree = { + "int": 7, + "nested": {"string": "foo", "dtype": torch.float32}, + } + + def fn(arg): + def mapper(node): + if isinstance(node, int): + return node + 1 + if isinstance(node, str): + return node.upper() + if isinstance(node, torch.dtype): + return torch.float64 + return node + + return optree.tree_map(mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + self.assertEqual(result["int"], 8) + self.assertEqual(result["nested"]["string"], "FOO") + self.assertIs(result["nested"]["dtype"], torch.float64) + self.assertEqual(result, expected) + + +if __name__ == "__main__": # pragma: no cover + run_tests() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index f83f059663149..bb228fab844fe 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4186,13 +4186,13 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs interleaved_5: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_4, 0, 0, 3); interleaved_4 = None - child_17: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None - child_18: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None - child_19: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None + flip_3: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None + flip_4: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None + flip_5: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None - movedim_3: "f32[3, 10, 2]" = torch.movedim(child_17, 0, 0); child_17 = None - movedim_4: "f32[3, 10, 2]" = torch.movedim(child_18, 0, 0); child_18 = None - movedim_5: "f32[3, 10, 2]" = torch.movedim(child_19, 0, 0); child_19 = None + movedim_3: "f32[3, 10, 2]" = torch.movedim(flip_3, 0, 0); flip_3 = None + movedim_4: "f32[3, 10, 2]" = torch.movedim(flip_4, 0, 0); flip_4 = None + movedim_5: "f32[3, 10, 2]" = torch.movedim(flip_5, 0, 0); flip_5 = None return (movedim_3, movedim_4, movedim_5) """, # noqa: B950 ) @@ -8595,7 +8595,7 @@ def forward(self, L_t_: "f32[2, 3]"): getitem_13: "Sym(u20)" = while_loop[5] getitem_14: "Sym(u21)" = while_loop[6] - child: "f32[2, 3]" = while_loop[7]; while_loop = None + getitem_7: "f32[2, 3]" = while_loop[7]; while_loop = None add: "Sym(u15 + 1)" = getitem_8 + 1 add_1: "Sym(u16 + 1)" = getitem_9 + 1 @@ -8604,7 +8604,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_4: "Sym(u19 + 1)" = getitem_12 + 1 add_5: "Sym(u20 + 1)" = getitem_13 + 1 add_6: "Sym(u21 + 1)" = getitem_14 + 1 - add_7: "f32[2, 3]" = child + 1 + add_7: "f32[2, 3]" = getitem_7 + 1 add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None @@ -8613,7 +8613,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None - add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None + add_15: "f32[2, 3]" = getitem_7 + l_t_; getitem_7 = l_t_ = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15) class cond_fn_0(torch.nn.Module): diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 78f64d882055c..617f787e43d8a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -14,6 +14,7 @@ """ import collections +import logging from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView from enum import Enum from typing import Any, NoReturn, Optional, TYPE_CHECKING @@ -33,6 +34,10 @@ from ..codegen import PyCodegen from ..symbolic_convert import InstructionTranslator from .constant import ConstantVariable + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) class SourceType(Enum): @@ -557,6 +562,81 @@ def call_method( hints=hints, ) + def call_tree_map( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Performance optimization to implement optree.tree_map faster than tracing it""" + is_leaf_var = tree_map_kwargs.get("is_leaf") + if is_leaf_var is not None and not ( + is_leaf_var.is_python_constant() + and is_leaf_var.as_python_constant() is None + ): + pred_result = is_leaf_var.call_function(tx, [self], {}) + try: + leaf_decision = pred_result.as_python_constant() + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + if leaf_decision: + return map_fn.call_function(tx, [self, *rest], {}) + + return self.call_tree_map_branch( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def call_tree_map_branch( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)""" + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _tree_map_fallback( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + tree_map_fn_copy = tree_map_fn.clone() + tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] + log.debug( + "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)", + self, + rest, + tree_map_kwargs, + ) + return tree_map_fn_copy.call_function( + tx, + [map_fn, self, *rest], + tree_map_kwargs, + ) + def set_name_hint(self, name: str) -> None: pass diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index a8b6a38cb1e9d..672fa1d804383 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -8,6 +8,7 @@ import enum import operator +from collections.abc import Sequence from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union from typing_extensions import override @@ -29,6 +30,8 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + class ConstantVariable(VariableTracker): """ @@ -275,6 +278,61 @@ def call_method( ) return super().call_method(tx, name, args, kwargs) + def call_tree_map( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is None: + none_is_leaf_var = tree_map_kwargs.get("none_is_leaf") + if none_is_leaf_var is not None: + try: + none_is_leaf = bool(none_is_leaf_var.as_python_constant()) + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + else: + tree_map_module = getattr( + getattr(tree_map_fn, "fn", None), "__module__", "" + ) + # torch.utils._pytree and torch.utils._cxx_pytree treat None as a leaf + # by default, while optree keeps it as an internal node unless + # none_is_leaf=True is provided. + none_is_leaf = not tree_map_module.startswith("optree") + if none_is_leaf: + return map_fn.call_function(tx, [self, *rest], {}) + else: + for other in rest: + if not ( + other.is_python_constant() + and other.as_python_constant() is None + ): + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + return self.clone() + if isinstance(self.value, (int, float, bool, complex, str, bytes, torch.dtype)): + return map_fn.call_function(tx, [self, *rest], {}) + return super().call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + @override def call_obj_hasattr( self, tx: "InstructionTranslator", name: str diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 9b02465d5766e..7a74f487ff96c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -23,7 +23,7 @@ import inspect import operator import types -from collections.abc import Hashable as py_Hashable +from collections.abc import Hashable as py_Hashable, Sequence from typing import Any, Optional, TYPE_CHECKING, Union from torch._subclasses.fake_tensor import is_fake @@ -51,6 +51,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + # [Adding a new supported class within the keys of ConstDictVariable] # - Add its tracker type to is_hashable @@ -316,6 +318,56 @@ def __contains__(self, vt: VariableTracker) -> bool: and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + other_dicts: list[ConstDictVariable] = [] + for candidate in rest: + candidate = candidate.realize() + if not isinstance(candidate, ConstDictVariable) or len( + candidate.items + ) != len(self.items): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_dicts.append(candidate) + + new_items_hashed = type(self.items)() + for key_tracker, value in self.items.items(): + sibling_leaves: list[VariableTracker] = [] + for candidate in other_dicts: + try: + sibling_leaves.append(candidate.items[key_tracker]) + except KeyError: + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + new_items_hashed[key_tracker] = value.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + + updated_original_items = { + key_tracker.vt: new_items_hashed[key_tracker] + for key_tracker in new_items_hashed + } + + return self.clone( + items=new_items_hashed, + original_items=updated_original_items, + should_reconstruct_all=True, + source=None, + mutation_type=ValueMutationNew(), + ) + def len(self) -> int: return sum( not isinstance(x, variables.DeletedVariable) for x in self.items.values() diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 31ffe9813c3fd..1b85235e7e6dd 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -114,6 +114,7 @@ _F = TypeVar("_F", bound=Callable[..., Any]) CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 +_SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"}) # Module-level cache keyed by the function object @@ -420,6 +421,15 @@ class UserFunctionVariable(BaseUserFunctionVariable): *BaseUserFunctionVariable._nonvar_fields, } + _TREE_MAP_MODULES = frozenset( + { + "optree", + "optree.ops", + "torch.utils._pytree", + "torch.utils._cxx_pytree", + } + ) + @classmethod def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) @@ -656,8 +666,43 @@ def call_function( ]: with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): return super().call_function(tx, args, kwargs) + + tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs) + if tree_map_result is not None: + return tree_map_result + return super().call_function(tx, args, kwargs) + def _maybe_call_tree_map_fastpath( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[VariableTracker]: + if not ( + self._is_tree_map_function() + and not ({*kwargs} - _SUPPORTED_TREE_MAP_KWARGS) + and len(args) >= 2 + ): + return None + + map_fn = args[0] + first_tree = args[1] + rest = args[2:] + return first_tree.call_tree_map( + tx, + self, + map_fn, + rest, + kwargs, + ) + + def _is_tree_map_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + class BuiltinMethodVariable(BaseUserFunctionVariable): def __init__( diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index cafbea5afde1e..05129fcf8fb45 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -140,6 +140,50 @@ def getitem_const( def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return list(self.items) + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not isinstance(self, (ListVariable, TupleVariable)): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + + other_lists: list[BaseListVariable] = [] + for candidate in rest: + if ( + not isinstance(candidate, BaseListVariable) + or len(candidate.items) != len(self.items) + or self.python_type() != candidate.python_type() + ): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_lists.append(candidate) + + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [candidate.items[idx] for candidate in other_lists] + new_items.append( + item.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + ) + + return self.clone( + items=new_items, + source=None, + mutation_type=ValueMutationNew(), + ) + def call_method( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 16fa0997c7f83..0787ef7c49b57 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -83,6 +83,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + log = logging.getLogger(__name__) @@ -612,6 +614,16 @@ def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): for i in idxes ] + def call_tree_map( + self, + tx, + tree_map_fn: "UserFunctionVariable", + map_fn, + rest, + tree_map_kwargs, + ) -> "VariableTracker": + return map_fn.call_function(tx, [self, *rest], {}) + def valid_size(self): return self._size is not None From 9fa3e6e5134e52c5d841769d960618daa722d847 Mon Sep 17 00:00:00 2001 From: linhaifeng <1371675203@qq.com> Date: Sat, 22 Nov 2025 23:50:13 +0000 Subject: [PATCH 211/230] [BugFix] Fix incorrect type hint. (#168892) tuple[int] -> tuple[int,...] 1 -> more Like shape, shape: tuple[int, ...] # [B, Hq, M, Hkv, N, D] Inspired by https://github.com/pytorch/pytorch/pull/168320 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168892 Approved by: https://github.com/cyyever, https://github.com/Skylion007 --- benchmarks/transformer/score_mod.py | 21 ++++++++------- .../test_torchinductor_strided_blocks.py | 27 ++++++++++--------- test/test_optim.py | 20 +++++++++----- .../_internal/common_methods_invocations.py | 6 ++--- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index e9af132df28a9..b120d987514e9 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -147,7 +147,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - shape: tuple[int] # [B, Hq, M, Hkv, N, D] + shape: tuple[int, ...] # [B, Hq, M, Hkv, N, D] attn_type: str dtype: torch.dtype calculate_bwd_time: bool @@ -257,7 +257,7 @@ def generate_inputs( def generate_jagged_inputs( - shape: tuple[int], + shape: tuple[int, ...], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -720,7 +720,7 @@ def print_results(results: list[Experiment], save_path: Optional[str] = None): dropout_p = 0.0 -def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None: +def generate_score_mod(attn_type: str, shape: tuple[int, ...]) -> Callable | None: B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap @@ -762,7 +762,7 @@ def score_mod_w_offset(score, b, h, m, n): prefix_length = 512 -def generate_block_mask(attn_type: str, shape: tuple[int]): +def generate_block_mask(attn_type: str, shape: tuple[int, ...]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 @@ -837,7 +837,7 @@ def decoding_w_cached_seq_len(b, h, m, n): return block_mask, mask_mod_kwargs -def get_kernel_options(attn_type: str, shape: tuple[int]): +def get_kernel_options(attn_type: str, shape: tuple[int, ...]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 kernel_opt_training_dict = { @@ -924,7 +924,7 @@ def get_backend_context(backend: str): def generate_FA_callable( - attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs + attn_type: str, shape: tuple[int, ...], dtype: torch.dtype, backend: str, **kwargs ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -983,7 +983,7 @@ def offsets_to_lengths( def generate_FD_callable( - attn_type: str, shape: tuple[int], dtype: torch.dtype + attn_type: str, shape: tuple[int, ...], dtype: torch.dtype ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -1030,7 +1030,10 @@ def flash_attn_with_kvcache_renamed(q, k, v, **kwargs): def generate_attn_mask_linear_score_mod( - shape: tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype + shape: tuple[int, ...], + block_mask: BlockMask, + score_mod: Callable, + dtype: torch.dtype, ): B, Hq, M, N = shape if block_mask is None and score_mod is None: @@ -1055,7 +1058,7 @@ def generate_attn_mask_linear_score_mod( def generate_eager_sdpa( attn_type: str, - shape: tuple[int], + shape: tuple[int, ...], dtype: torch.dtype, block_mask: BlockMask, score_mod: Callable | None = None, diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 7a9edd5570f3e..d70375ebc3345 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -246,9 +246,9 @@ def foo(x, y): ) def test_pointwise( self, - full_size: tuple[int], - view_size: tuple[int], - stride: Optional[tuple[int]], + full_size: tuple[int, ...], + view_size: tuple[int, ...], + stride: Optional[tuple[int, ...]], offset: Optional[int], require_block_ptr: bool, prefer_nd_tiling: bool, @@ -298,7 +298,7 @@ def get_input() -> torch.Tensor: ], ) def test_broadcast( - self, x_size: tuple[int], y_size: tuple[int], prefer_nd_tiling: bool + self, x_size: tuple[int, ...], y_size: tuple[int, ...], prefer_nd_tiling: bool ): """ Test that we can generate strided block pointers when inputs have different @@ -415,7 +415,7 @@ def load_args(reader): ((5, 6, 1, 1), (5, 6, 4, 3)), ], ) - def test_expand_broadcast(self, x_size: tuple[int], y_size: tuple[int]): + def test_expand_broadcast(self, x_size: tuple[int, ...], y_size: tuple[int, ...]): """ When the load and store have different shapes, we should use broadcast. """ @@ -423,7 +423,7 @@ def test_expand_broadcast(self, x_size: tuple[int], y_size: tuple[int]): def foo(x, y_size): return x.expand(y_size).clone() - def get_input(size: tuple[int]) -> torch.Tensor: + def get_input(size: tuple[int, ...]) -> torch.Tensor: device = torch.device(self.device) full = torch.randn(size).to(device) view = torch.as_strided(full, size, full.stride()) @@ -522,7 +522,7 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool): ) def test_reduction( self, - view_size: tuple[int], + view_size: tuple[int, ...], num_block_pointers: int, num_triton_kernels: int, prefer_nd_tiling: bool, @@ -574,7 +574,10 @@ def test_reduction( ], ) def test_mixed_pointwise_reduction( - self, view_size: tuple[int], num_block_pointers: int, num_triton_kernels: int + self, + view_size: tuple[int, ...], + num_block_pointers: int, + num_triton_kernels: int, ): """ Tests mixing pointwise with reduction ops. @@ -744,8 +747,8 @@ def foo(x): ) def test_nd_tiling_odd_shapes_pointwise( self, - full_size: tuple[int], - view_size: tuple[int], + full_size: tuple[int, ...], + view_size: tuple[int, ...], num_block_pointers: int, num_tiles: int, ): @@ -794,7 +797,7 @@ def get_input() -> torch.Tensor: ) def test_2d_reduction_odd_shapes( self, - view_size: tuple[int], + view_size: tuple[int, ...], num_block_pointers: int, num_triton_kernels: int, reduction_op: Callable, @@ -829,7 +832,7 @@ def test_2d_reduction_odd_shapes( ) def test_2d_welford_reduction( self, - size: tuple[int], + size: tuple[int, ...], expected_num_block_pointers: int, expected_num_triton_kernels: int, expect_fallback: bool, diff --git a/test/test_optim.py b/test/test_optim.py index de185725b5c2c..973e6d6fe6845 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1993,7 +1993,7 @@ def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info): @optims(optim_db, dtypes=[torch.float32]) def test_step_post_hook(self, device, dtype, optim_info): - def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def post_hook(opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -2025,7 +2025,7 @@ def dummy_closure(): @optims(optim_db, dtypes=[torch.float32]) def test_step_pre_hook(self, device, dtype, optim_info): - def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def pre_hook(opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -2058,19 +2058,27 @@ def dummy_closure(): @optims(optim_db, dtypes=[torch.float32]) def test_step_all_hooks(self, device, dtype, optim_info): - def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def global_pre_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(0) - def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def global_post_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(5) - def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def local_pre_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(1) - def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def local_post_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(2) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cf2fd54490591..0cf0f50c23ef5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11704,11 +11704,11 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, bias=None, eps=1e-5): +def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] -def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight, bias, eps): +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight, bias, eps): feature_size = np.prod(normalized_shape) inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] mean = inp_view.mean(axis=-1, keepdims=True) @@ -11725,7 +11725,7 @@ def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) -def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, eps=None): +def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, eps=None): if eps is None: eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps feature_size = np.prod(normalized_shape) From 2c204e6dfc98328ca64beaca4ec49810141809e5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 23 Nov 2025 00:46:16 +0000 Subject: [PATCH 212/230] Revert "[Inductor XPU GEMM] Step 2/N: Move out cutlass files from torch/_inductor/codegen/cuda (#160685)" This reverts commit 7556637e289d00f8aec58252c3b5b45dcfd6eb61. Reverted https://github.com/pytorch/pytorch/pull/160685 on behalf of https://github.com/yangw-dev due to failed internal tests test_cpu_/test_cpu#link-tree/torch/utils/_config_module.py line 371, in _config = self._config[name] KeyError: 'cuda.cutlass_dir' Diff: D87660662 ([comment](https://github.com/pytorch/pytorch/pull/160174#issuecomment-3567237578)) --- test/inductor/test_cutlass_backend.py | 27 ++++++++++--------- test/inductor/test_cutlass_evt.py | 14 +++++----- test/test_public_bindings.py | 4 +-- torch/_inductor/codegen/common.py | 2 +- .../codegen/cuda/cuda_cpp_scheduling.py | 6 +++-- .../codegen/{cutlass => cuda}/cuda_kernel.py | 6 ++--- .../{cutlass => cuda}/cuda_template.py | 2 +- .../cache.py => cuda/cutlass_cache.py} | 8 +++--- .../cutlass_lib_extensions}/__init__.py | 0 .../cutlass_mock_imports}/__init__.py | 0 .../cutlass_mock_imports/cuda/__init__.py | 0 .../cutlass_mock_imports/cuda/cuda.py | 0 .../cutlass_mock_imports/cuda/cudart.py | 0 .../cutlass_mock_imports/pydot/__init__.py | 0 .../cutlass_mock_imports/scipy/__init__.py | 0 .../cutlass_mock_imports/scipy/special.py | 0 .../cutlass_lib_extensions}/evt_extensions.py | 2 +- .../gemm_operation_extensions.py | 2 +- .../cutlass_python_evt.py} | 0 .../utils.py => cuda/cutlass_utils.py} | 6 ++--- .../{cutlass => cuda}/gemm_template.py | 14 +++++----- .../{cutlass => cuda}/serialization.py | 2 +- .../cutlass_mock_imports/__init__.py | 0 torch/_inductor/ir.py | 2 +- torch/_inductor/kernel/bmm.py | 2 +- torch/_inductor/kernel/mm.py | 2 +- torch/_inductor/select_algorithm.py | 16 +++++------ torch/_inductor/utils.py | 2 +- 28 files changed, 57 insertions(+), 62 deletions(-) rename torch/_inductor/codegen/{cutlass => cuda}/cuda_kernel.py (99%) rename torch/_inductor/codegen/{cutlass => cuda}/cuda_template.py (99%) rename torch/_inductor/codegen/{cutlass/cache.py => cuda/cutlass_cache.py} (94%) rename torch/_inductor/codegen/{cutlass => cuda/cutlass_lib_extensions}/__init__.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions/cutlass_mock_imports}/__init__.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/cuda/__init__.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/cuda/cuda.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/cuda/cudart.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/pydot/__init__.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/scipy/__init__.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/cutlass_mock_imports/scipy/special.py (100%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/evt_extensions.py (99%) rename torch/_inductor/codegen/{cutlass/lib_extensions => cuda/cutlass_lib_extensions}/gemm_operation_extensions.py (99%) rename torch/_inductor/codegen/{cutlass/python_evt.py => cuda/cutlass_python_evt.py} (100%) rename torch/_inductor/codegen/{cutlass/utils.py => cuda/cutlass_utils.py} (99%) rename torch/_inductor/codegen/{cutlass => cuda}/gemm_template.py (99%) rename torch/_inductor/codegen/{cutlass => cuda}/serialization.py (99%) delete mode 100644 torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 212795c2d4925..673d3e87d2a5f 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -14,9 +14,7 @@ from typing import Optional from torch._dynamo.exc import BackendCompilerFailed -from torch._inductor.codegen.cutlass.serialization import ( - get_cutlass_operation_serializer, -) +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer from torch._inductor.utils import clear_caches from torch.export import Dim from torch.testing._internal.logging_utils import log_settings @@ -34,8 +32,11 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._inductor import config -from torch._inductor.codegen.cutlass.cuda_kernel import CUDATemplateCaller -from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment +from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller +from torch._inductor.codegen.cuda.cutlass_utils import ( + _gen_ops_cached, + get_max_alignment, +) from torch._inductor.exc import InductorError from torch._inductor.ir import FixedLayout from torch._inductor.select_algorithm import NoValidChoicesError @@ -205,7 +206,7 @@ def run_evt_test(self, model, op, shape, num_fusions=1): def test_check_paths(self): cutlass_mock_imports_path = os.path.join( os.path.dirname(torch.__file__), - "_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports", + "_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports", ) cutlass_mock_cuda_path = os.path.join(cutlass_mock_imports_path, "cuda") cutlass_mock_pydot_path = os.path.join(cutlass_mock_imports_path, "pydot") @@ -250,7 +251,7 @@ def mm(a, b): @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_import_cutlass(self): - from torch._inductor.codegen.cutlass.utils import try_import_cutlass + from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass self.assertTrue(try_import_cutlass()) @@ -258,7 +259,7 @@ def test_import_cutlass(self): import cutlass_library # noqa: F401 def test_cutlass_key(self): - from torch._inductor.codegen.cutlass.utils import try_import_cutlass + from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass self.assertTrue(try_import_cutlass()) from torch._inductor.codecache import cutlass_key @@ -1466,7 +1467,7 @@ def test_standalone_runner(self): ): from tempfile import NamedTemporaryFile - from torch._inductor.codegen.cutlass.utils import ( + from torch._inductor.codegen.cuda.cutlass_utils import ( cuda_standalone_runner_compile_command, CUDACompileSourceCapturingContext, ) @@ -1552,7 +1553,7 @@ def mm(a, b): with ( log_settings("+inductor"), self.assertLogs( - logger="torch._inductor.codegen.cutlass", level=logging.DEBUG + logger="torch._inductor.codegen.cuda", level=logging.DEBUG ) as test_log, ): Y_compiled = torch.compile(mm, dynamic=False)(a, b) @@ -1590,7 +1591,7 @@ def forward(self, A, B): expected = model(A, B) # Track render calls - from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 @@ -1644,7 +1645,7 @@ def forward(self, a, b, c, d): d = torch.randn(64, 128).cuda().half().t() # Track render calls - from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 @@ -1705,7 +1706,7 @@ def forward(self, a, b): b = torch.randn(32, 64).cuda().half().t() # Track render calls - from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate original_render = CUTLASSGemmTemplate.render render_call_count = 0 diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index dd296b7f75ac7..862aeb5db1c88 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -5,7 +5,7 @@ import torch from torch._dynamo.test_case import TestCase -from torch._inductor.codegen.cutlass.utils import ( +from torch._inductor.codegen.cuda.cutlass_utils import ( torch_dtype_to_cutlass_type, try_import_cutlass, ) @@ -28,7 +28,7 @@ DataType = cutlass_lib.DataType from cutlass_cppgen.backend.evt.ir.tensor import Tensor as CutlassTensor - from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import ( + from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( _render_argument_type, _trace, trace, @@ -107,7 +107,7 @@ class TestCutlassEVT(TestCase): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_accumulator_return(self): - from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -164,7 +164,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_disjoint_read_indexing(self): - from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -213,7 +213,7 @@ def inner_fn_buf4(index): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_broadcasting(self): - from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -273,7 +273,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen(self): - from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cuda.cutlass_python_evt import CutlassEVTCodegen from torch._inductor.virtualized import V size = (100, 300, 200) @@ -329,7 +329,7 @@ def fn(accum, buf1, buf2): @unittest.skipIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_example_tensor_creation(self): - from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import ( + from torch._inductor.codegen.cuda.cutlass_lib_extensions.evt_extensions import ( create_example_tensors, ) from torch._inductor.virtualized import V diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index d175a205935a7..7a9f8f3aa317f 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -292,7 +292,7 @@ def onerror(modname): # do not get imported by public code. # DO NOT add public modules here. private_allowlist = { - "torch._inductor.codegen.cutlass.cuda_kernel", + "torch._inductor.codegen.cuda.cuda_kernel", # TODO(#133647): Remove the onnx._internal entries after # onnx and onnxscript are installed in CI. "torch.onnx._internal.exporter", @@ -357,7 +357,7 @@ def onerror(modname): "torch.testing._internal.distributed.rpc.rpc_test", "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", - "torch._inductor.codegen.cutlass.cuda_template", + "torch._inductor.codegen.cuda.cuda_template", "torch._inductor.codegen.cutedsl._cutedsl_utils", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.codegen.cpp_template", diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 617b0a91d67a0..8b5e68780cb28 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -2657,7 +2657,7 @@ def _bound_variable(self, name: str, *args: Any, **kwargs: Any) -> ValueRanges[A """ from ..bounds import ValueRangeAnalysis from ..select_algorithm import TritonTemplateKernel - from .cutlass.cuda_kernel import CUDATemplateKernel + from .cuda.cuda_kernel import CUDATemplateKernel if isinstance(V.kernel, TritonTemplateKernel): return ValueRanges.unknown() diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 591a95b18f252..16b09d4ba80eb 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import cast -from torch._inductor.codegen.cutlass.python_evt import ( +from torch._inductor.codegen.cuda.cutlass_python_evt import ( CutlassEVTCodegen, MockCutlassHandler, ) @@ -267,7 +267,9 @@ def _can_fuse_epilogue_impl( return False try: - from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen + from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + ) CutlassEVTCodegen.ir_to_evt_python_code( cuda_template_buffer.get_name(), diff --git a/torch/_inductor/codegen/cutlass/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py similarity index 99% rename from torch/_inductor/codegen/cutlass/cuda_kernel.py rename to torch/_inductor/codegen/cuda/cuda_kernel.py index 9622dc759a6c4..97643ef00a7bd 100644 --- a/torch/_inductor/codegen/cutlass/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -16,7 +16,7 @@ from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder from torch.utils._sympy.value_ranges import ValueRanges -from .utils import DTYPE_TO_CUTLASS_TYPE +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE if TYPE_CHECKING: @@ -47,7 +47,7 @@ if TYPE_CHECKING: - from torch._inductor.codegen.cutlass.cuda_template import CUDATemplate + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate log = logging.getLogger(__name__) @@ -424,7 +424,7 @@ def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: # Helper method, called into from CUTLASSGemmTemplate if node is None: return default_dtype - from torch._inductor.codegen.cutlass.cuda_template import CUTLASSTemplate + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] diff --git a/torch/_inductor/codegen/cutlass/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py similarity index 99% rename from torch/_inductor/codegen/cutlass/cuda_template.py rename to torch/_inductor/codegen/cuda/cuda_template.py index 384713f157062..92c86120570d6 100644 --- a/torch/_inductor/codegen/cutlass/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -20,7 +20,7 @@ from ...virtualized import V from ..common import KernelTemplate from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel -from .utils import DTYPE_TO_CUTLASS_TYPE +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE if TYPE_CHECKING: diff --git a/torch/_inductor/codegen/cutlass/cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py similarity index 94% rename from torch/_inductor/codegen/cutlass/cache.py rename to torch/_inductor/codegen/cuda/cutlass_cache.py index 9de1a6257c2d1..cad4a37902304 100644 --- a/torch/_inductor/codegen/cutlass/cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -10,11 +10,9 @@ import torch._inductor.config as config from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda import cutlass_utils, serialization from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version -from torch._inductor.codegen.cutlass import serialization, utils -from torch._inductor.codegen.cutlass.serialization import ( - get_cutlass_operation_serializer, -) +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer from torch._inductor.runtime.cache_dir_utils import cache_dir from torch._inductor.utils import clear_on_fresh_cache @@ -41,7 +39,7 @@ def get_file_hash(file_module): return hashlib.sha256(f.read()).hexdigest() serialization_hash = get_file_hash(serialization) - cutlass_utils_hash = get_file_hash(utils) + cutlass_utils_hash = get_file_hash(cutlass_utils) hash_target = "-".join( [ diff --git a/torch/_inductor/codegen/cutlass/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py similarity index 100% rename from torch/_inductor/codegen/cutlass/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cuda.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cuda.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cuda.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cudart.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/cuda/cudart.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/cuda/cudart.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/pydot/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/pydot/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/pydot/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/__init__.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/__init__.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/special.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py similarity index 100% rename from torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/scipy/special.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports/scipy/special.py diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py similarity index 99% rename from torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py index c1daa78228ba1..472438fec90e3 100644 --- a/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -10,7 +10,7 @@ ) from torch.utils._ordered_set import OrderedSet -from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass +from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py similarity index 99% rename from torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py rename to torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py index d10669d40bea0..95af1a968a97c 100644 --- a/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -1,5 +1,5 @@ # mypy: ignore-errors -from ..utils import try_import_cutlass +from ..cutlass_utils import try_import_cutlass # copied / modified from original at diff --git a/torch/_inductor/codegen/cutlass/python_evt.py b/torch/_inductor/codegen/cuda/cutlass_python_evt.py similarity index 100% rename from torch/_inductor/codegen/cutlass/python_evt.py rename to torch/_inductor/codegen/cuda/cutlass_python_evt.py diff --git a/torch/_inductor/codegen/cutlass/utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py similarity index 99% rename from torch/_inductor/codegen/cutlass/utils.py rename to torch/_inductor/codegen/cuda/cutlass_utils.py index 56e02edbb99d5..3ce3a49bb94e9 100644 --- a/torch/_inductor/codegen/cutlass/utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -23,7 +23,7 @@ from ...runtime.runtime_utils import cache_dir from ...virtualized import V from ..cpp_utils import DTYPE_TO_CPP -from ..cuda.cuda_env import get_cuda_arch, get_cuda_version +from .cuda_env import get_cuda_arch, get_cuda_version log = logging.getLogger(__name__) @@ -104,8 +104,8 @@ def path_join(path0, path1): torch_root, "_inductor", "codegen", - "cutlass", - "lib_extensions", + "cuda", + "cutlass_lib_extensions", "cutlass_mock_imports", ) diff --git a/torch/_inductor/codegen/cutlass/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py similarity index 99% rename from torch/_inductor/codegen/cutlass/gemm_template.py rename to torch/_inductor/codegen/cuda/gemm_template.py index 58f9622571dcc..9148ee7877d03 100644 --- a/torch/_inductor/codegen/cutlass/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -11,7 +11,7 @@ import torch import torch.utils._pytree as pytree from torch._inductor.autotune_process import TensorMeta -from torch._inductor.codegen.cutlass.cache import maybe_fetch_ops +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode @@ -32,11 +32,11 @@ from ...utils import is_dynamic, Placeholder from ...virtualized import V from ..common import IndentedBuffer -from . import utils as cutlass_utils +from . import cutlass_utils from .cuda_kernel import CUDATemplateKernel from .cuda_template import CUTLASSTemplate -from .python_evt import CutlassEVTCodegen, scaled_mm_evt -from .utils import ( +from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +from .cutlass_utils import ( ACCUMULATOR_DTYPES, dtype_match, torch_dtype_to_cutlass_type, @@ -1474,7 +1474,7 @@ def _render_evt( output_dtype: torch.dtype, accumulator_dtype: torch.dtype, ) -> tuple[str, str, str, EVTArgRenames]: - from .lib_extensions.evt_extensions import create_example_tensors, trace + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) output_dtype = torch_dtype_to_cutlass_type(output_dtype) @@ -1561,7 +1561,7 @@ def _define_gemm_instance( assert cutlass_utils.try_import_cutlass() import cutlass_library.library as cutlass_lib - from .lib_extensions import gemm_operation_extensions as gemm_extensions + from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] @@ -1701,8 +1701,6 @@ def clone_with_transposed_stride(node: IRNode) -> IRNode: class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): - """CUTLASS 2x GEMM Template, which is used to generate CUTLASS GEMM kernels""" - def __init__( self, input_nodes: list[Buffer], diff --git a/torch/_inductor/codegen/cutlass/serialization.py b/torch/_inductor/codegen/cuda/serialization.py similarity index 99% rename from torch/_inductor/codegen/cutlass/serialization.py rename to torch/_inductor/codegen/cuda/serialization.py index 39184e4e6e2c6..a17f04b0a1b5a 100644 --- a/torch/_inductor/codegen/cutlass/serialization.py +++ b/torch/_inductor/codegen/cuda/serialization.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Optional -from torch._inductor.codegen.cutlass.utils import try_import_cutlass +from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass class CUTLASSOperationSerializer: diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py b/torch/_inductor/codegen/cutlass/lib_extensions/cutlass_mock_imports/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0bd97ec240dc5..0f29d38cb44d0 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -119,7 +119,7 @@ from torch.fx.experimental.symbolic_shapes import SympyBoolean from torch.fx.node import Argument - from .codegen.cutlass.cuda_template import CUDATemplate + from .codegen.cuda.cuda_template import CUDATemplate from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering from .utils import IndentedBuffer diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index 7aeed4d8b92a9..a155d35b5d059 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -262,7 +262,7 @@ def _to_dtype(x): and use_cutlass_template(layout, m, n, k) and _use_cutlass_for_op(name) ): - from ..codegen.cutlass.gemm_template import CUTLASS3xGemmTemplate + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( choices, layout, kernel_inputs.nodes() diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 1dd6c2fbfcd75..5b57c458f46e6 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -20,7 +20,7 @@ from torch.torch_version import TorchVersion from .. import config as inductor_config, distributed_autotune -from ..codegen.cutlass.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.subgraph import SubgraphChoiceCaller, SubgraphTemplate diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index aab68f3d0a744..625f35ba36c06 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2707,7 +2707,7 @@ def __call__( best_config_future=None, return_choice=False, # TODO: return_choice is temporary and will be refactored soon ): - from .codegen.cutlass.cuda_kernel import CUDATemplateCaller + from .codegen.cuda.cuda_kernel import CUDATemplateCaller # Run preprocessing functions on choices for preprocessing_fn in self.preprocessing_fns: @@ -3223,7 +3223,7 @@ def wait_on_futures(): "select_algorithm_num_precompilation_exceptions" ] += 1 exceptions.append((futures[future], e)) - from torch._inductor.codegen.cutlass.cuda_kernel import ( + from torch._inductor.codegen.cuda.cuda_kernel import ( CUDATemplateCaller, ) @@ -3410,9 +3410,7 @@ def benchmark_choices( try: timing = cls.benchmark_choice(choice, autotune_args) except CUDACompileError: - from torch._inductor.codegen.cutlass.cuda_kernel import ( - CUDATemplateCaller, - ) + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller if not isinstance(choice, CUDATemplateCaller): log.exception( @@ -3423,9 +3421,7 @@ def benchmark_choices( log.warning("Not yet implemented", exc_info=True) timing = float("inf") except RuntimeError as e: - from torch._inductor.codegen.cutlass.cuda_kernel import ( - CUDATemplateCaller, - ) + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller msg = str(e) if "invalid argument" in msg: @@ -3570,7 +3566,7 @@ def prescreen_choices( return prescreen_winners # prescreen cutlass - from .codegen.cutlass.cuda_kernel import CUDATemplateCaller + from .codegen.cuda.cuda_kernel import CUDATemplateCaller candidates = [] if ( @@ -3604,7 +3600,7 @@ def prune_choices_postscreen( """ Prune the choices after prescreening. """ - from .codegen.cutlass.cuda_kernel import CUDATemplateCaller + from .codegen.cuda.cuda_kernel import CUDATemplateCaller prescreen_key = f"{name}:{inputs_key}" diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 4f3ff7e01879f..59db1aeb12325 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2025,7 +2025,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size: return False - from .codegen.cutlass.utils import try_import_cutlass + from .codegen.cuda.cutlass_utils import try_import_cutlass # Do not use cutlass template on ROCm if torch.version.hip: From 3f0d46c8b0035a9c17419601ed7412d294f2c56f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 23 Nov 2025 00:46:16 +0000 Subject: [PATCH 213/230] Revert "[Inductor XPU GEMM] Step 1/N: Refactor cutlass configuration. (#160174)" This reverts commit 008ac433b06c4177e6b6b6d2a63fc4aebbc1fe74. Reverted https://github.com/pytorch/pytorch/pull/160174 on behalf of https://github.com/yangw-dev due to failed internal tests test_cpu_/test_cpu#link-tree/torch/utils/_config_module.py line 371, in _config = self._config[name] KeyError: 'cuda.cutlass_dir' Diff: D87660662 ([comment](https://github.com/pytorch/pytorch/pull/160174#issuecomment-3567237578)) --- benchmarks/inductor_backends/cutlass.py | 2 +- test/inductor/test_cutlass_backend.py | 110 +++++++++--------- torch/_inductor/codecache.py | 30 ++--- .../codegen/cuda/cuda_cpp_scheduling.py | 2 +- torch/_inductor/codegen/cuda/cuda_template.py | 2 +- torch/_inductor/codegen/cuda/cutlass_cache.py | 2 +- torch/_inductor/codegen/cuda/cutlass_utils.py | 4 +- torch/_inductor/codegen/cuda/gemm_template.py | 24 ++-- torch/_inductor/config.py | 73 +++++------- torch/_inductor/fuzzer.py | 2 +- torch/_inductor/select_algorithm.py | 4 +- torch/_inductor/utils.py | 8 +- 12 files changed, 122 insertions(+), 141 deletions(-) diff --git a/benchmarks/inductor_backends/cutlass.py b/benchmarks/inductor_backends/cutlass.py index af06333038947..b2ed506302aec 100644 --- a/benchmarks/inductor_backends/cutlass.py +++ b/benchmarks/inductor_backends/cutlass.py @@ -125,7 +125,7 @@ def name(self) -> str: def to_options(self) -> dict[str, Any]: return { **super().to_options(), - "cutlass.cutlass_instantiation_level": self.cutlass_instantiation_level, + "cuda.cutlass_instantiation_level": self.cutlass_instantiation_level, } diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 673d3e87d2a5f..55f8dd5d24ebc 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -133,10 +133,10 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cutlass.cutlass_tma_only": True, - "cutlass.cutlass_epilogue_fusion_enabled": True, + "cuda.cutlass_tma_only": True, + "cuda.cutlass_epilogue_fusion_enabled": True, } ) @@ -144,9 +144,9 @@ def gen_args(op, shape, dtype=torch.float16): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cutlass.cutlass_tma_only": True, + "cuda.cutlass_tma_only": True, } ) @@ -234,8 +234,8 @@ def mm(a, b): "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cutlass.cutlass_backend_min_gemm_size": 100000, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_backend_min_gemm_size": 100000, + "cuda.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -287,7 +287,7 @@ def test_cutlass_backend_subproc_mm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cutlass.cutlass_max_profiling_configs": 4, + "cuda.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.mm)(a, b) @@ -324,7 +324,7 @@ def test_cutlass_backend_subproc_addmm(self, dtype): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cutlass.cutlass_max_profiling_configs": 4, + "cuda.cutlass_max_profiling_configs": 4, } ): for x_shape in x_shapes: @@ -354,7 +354,7 @@ def test_cutlass_backend_subproc_bmm(self): "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", "compile_threads": 4, - "cutlass.cutlass_max_profiling_configs": 4, + "cuda.cutlass_max_profiling_configs": 4, } ): Y_compiled = torch.compile(torch.bmm)(a, b) @@ -386,7 +386,7 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, } ): from torch._inductor.utils import run_and_get_code @@ -428,8 +428,8 @@ def forward(self, a, b, c): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 1, - "cutlass.cutlass_max_profiling_swizzle_options": [ + "cuda.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_swizzle_options": [ 1, 2, 4, @@ -505,7 +505,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -595,9 +595,9 @@ def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cutlass.cutlass_tma_only": True, + "cuda.cutlass_tma_only": True, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -677,7 +677,7 @@ def forward(self, x, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ), dynamo_config.patch({"error_on_recompile": dynamic}), @@ -746,7 +746,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): expected = [model(*input) for input in inputs] @@ -775,8 +775,8 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, - "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cuda.cutlass_max_profiling_configs": 2, + "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): for M, K, N in ( @@ -819,7 +819,7 @@ def test_streamk_with_dynamic( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): @@ -849,8 +849,8 @@ def test_streamk_with_static( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, - "cutlass.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels + "cuda.cutlass_max_profiling_configs": 1, + "cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels } ): _ = compiled_model(a, b) @@ -884,7 +884,7 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 4, + "cuda.cutlass_max_profiling_configs": 4, "cuda.version": "12.2", # required to enable the Kernels we need } ): @@ -983,7 +983,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) @@ -1002,7 +1002,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1040,7 +1040,7 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): model = MyModel() @@ -1073,8 +1073,8 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", + "cuda.cutlass_max_profiling_configs": 1, } ): model = MyModel() @@ -1117,7 +1117,7 @@ def mm(a, b): "max_autotune": True, "autotune_in_subproc": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, "autotune_local_cache": True, } ): @@ -1157,9 +1157,9 @@ def my_addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, - "cutlass.cutlass_op_allowlist_regex": "", - "cutlass.cutlass_op_denylist_regex": "pingpong", + "cuda.cutlass_max_profiling_configs": 2, + "cuda.cutlass_op_allowlist_regex": "", + "cuda.cutlass_op_denylist_regex": "pingpong", } ): with mock.patch( @@ -1202,9 +1202,9 @@ def addmm(x, a, b, alpha, beta): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, - "cutlass.cutlass_op_allowlist_regex": "pingpong", - "cutlass.cutlass_op_denylist_regex": None, + "cuda.cutlass_max_profiling_configs": 2, + "cuda.cutlass_op_allowlist_regex": "pingpong", + "cuda.cutlass_op_denylist_regex": None, } ): with mock.patch( @@ -1273,7 +1273,7 @@ def run_test(use_fast_accum): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): with mock.patch( @@ -1350,7 +1350,7 @@ def test_cutlass_backend_shape_coverage_mm( { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ), mock.patch( @@ -1461,8 +1461,8 @@ def test_standalone_runner(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, - "cutlass.generate_test_runner": True, # put standalone runner in the generated code + "cuda.cutlass_max_profiling_configs": 2, + "cuda.generate_test_runner": True, # put standalone runner in the generated code } ): from tempfile import NamedTemporaryFile @@ -1544,7 +1544,7 @@ def mm(a, b): { "max_autotune": True, "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, # needed for log searching "fx_graph_cache": False, "fx_graph_remote_cache": False, @@ -1608,8 +1608,8 @@ def counting_render(self, *args, **kwargs): "max_autotune_gemm_backends": "CUTLASS", "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cutlass.enable_caching_codegen": True, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.enable_caching_codegen": True, + "cuda.cutlass_max_profiling_configs": 2, } ): compiled_model = torch.compile(model, fullgraph=True) @@ -1660,10 +1660,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cutlass.enable_caching_codegen": True, + "cuda.enable_caching_codegen": True, } ): # Get expected results @@ -1721,10 +1721,10 @@ def counting_render(self, *args, **kwargs): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, "fx_graph_cache": False, "fx_graph_remote_cache": False, - "cutlass.enable_caching_codegen": True, + "cuda.enable_caching_codegen": True, } ): # Get expected results @@ -1752,7 +1752,7 @@ def test_cutlass_backend_matmul_same_tensor(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1771,7 +1771,7 @@ def test_cutlass_backend_matmul_nonzero_offset(self): { "max_autotune": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, - "cutlass.cutlass_max_profiling_configs": 2, + "cuda.cutlass_max_profiling_configs": 2, } ): compiled = torch.compile(torch.mm) @@ -1795,7 +1795,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1817,7 +1817,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1845,7 +1845,7 @@ def forward(self, B): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, } ): _ = torch.compile(model)(B) @@ -1871,7 +1871,7 @@ def forward(self, a, b): { "max_autotune": True, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_max_profiling_configs": 1, + "cuda.cutlass_max_profiling_configs": 1, } ): if use_aoti: @@ -1968,7 +1968,7 @@ def forward(self, a, b, extra_args): # baseline is cutlass kernel + triton # matches expected casting behavior - with config.patch({"cutlass.cutlass_epilogue_fusion_enabled": False}): + with config.patch({"cuda.cutlass_epilogue_fusion_enabled": False}): ref_result = torch.compile(model)(a, b, extra_args) self.assertEqual( @@ -2368,7 +2368,7 @@ def test_config_number_post_filtering(self) -> None: "max_autotune_gemm_backends": "CUTLASS", # needed for log searching "force_disable_caches": True, - "cutlass.cutlass_max_profiling_swizzle_options": [2], + "cuda.cutlass_max_profiling_swizzle_options": [2], } ): with mock.patch( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2542d5ecefd3f..a30644312332b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -34,7 +34,7 @@ from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType -from typing import Any, cast, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union from typing_extensions import override, Self import torch @@ -3741,7 +3741,7 @@ def _load_triton_kernel_from_source( return getattr(PyCodeCache.load(source_code), kernel_name) -def _cuda_compiler() -> Optional[str]: +def _cuda_compiler() -> str | None: if cuda_env.nvcc_exist(config.cuda.cuda_cxx): return config.cuda.cuda_cxx if config.is_fbcode(): @@ -3759,7 +3759,7 @@ def _cutlass_path() -> str: return parutil.get_dir_path("cutlass-4-headers") else: - return config.cutlass.cutlass_dir + return config.cuda.cutlass_dir def _cutlass_paths() -> list[str]: @@ -3807,7 +3807,7 @@ def cutlass_key() -> bytes: return resource_file.read().encode() combined_hash = hashlib.sha256() - build_code_hash([config.cutlass.cutlass_dir], "", combined_hash) + build_code_hash([config.cuda.cutlass_dir], "", combined_hash) return combined_hash.digest() @@ -3877,14 +3877,14 @@ def _nvcc_compiler_options() -> list[str]: "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", - config.cutlass.compile_opt_level, + config.cuda.compile_opt_level, "-std=c++17", "--expt-relaxed-constexpr", "-DNDEBUG", ] if config.is_fbcode(): options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) - if config.cutlass.enable_debug_info: + if config.cuda.enable_debug_info: options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) if config.cuda.enable_ptxas_info: options.extend( @@ -3896,7 +3896,7 @@ def _nvcc_compiler_options() -> list[str]: "--source-in-ptx", ] ) # Annotate the ptx file with source information - if config.cutlass.use_fast_math: + if config.cuda.use_fast_math: options.extend( [ "--use_fast_math", @@ -4100,7 +4100,7 @@ def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: Returns the hash key of source code, and the path to the file. """ - if config.cutlass.cutlass_hash_with_compile_cmd: + if config.cuda.cutlass_hash_with_compile_cmd: cuda_command = repr( cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) ) @@ -4151,7 +4151,7 @@ def compile( output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext error_path = binary_error_path(output_path) binary_remote_cache = cls.get_kernel_binary_remote_cache( - caching_enabled=config.cutlass.use_binary_remote_cache + caching_enabled=config.cuda.use_binary_remote_cache and not config.force_disable_caches, caching_available=config.is_fbcode(), ) @@ -4166,13 +4166,13 @@ def compile( cmd_parts, error_output = json.loads(error_json) if ( binary_remote_cache is not None - and config.cutlass.upload_to_binary_remote_cache + and config.cuda.upload_to_binary_remote_cache ): # This ensures that a local error is uploaded to the remote cache, # as we make no assumptions about the remote cache having the same # information as the local cache binary_remote_cache.put( - error_path, config.cutlass.binary_remote_cache_force_write + error_path, config.cuda.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, error_json @@ -4236,11 +4236,11 @@ def compile( # Upload to remote cache if enabled if ( binary_remote_cache is not None - and config.cutlass.upload_to_binary_remote_cache + and config.cuda.upload_to_binary_remote_cache ): # will log on errors, but not fail out binary_remote_cache.put( - output_path, config.cutlass.binary_remote_cache_force_write + output_path, config.cuda.binary_remote_cache_force_write ) cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( input_path, output_path, None @@ -4293,10 +4293,10 @@ def _record_cuda_compile_error( # Upload to remote cache directly from memory if enabled if ( binary_remote_cache is not None - and config.cutlass.upload_to_binary_remote_cache + and config.cuda.upload_to_binary_remote_cache ): binary_remote_cache.put( - error_path, config.cutlass.binary_remote_cache_force_write + error_path, config.cuda.binary_remote_cache_force_write ) diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 16b09d4ba80eb..2496860ca1f7c 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -257,7 +257,7 @@ def _can_fuse_epilogue_impl( ) return False elif ( - not config.cutlass.cutlass_epilogue_fusion_enabled + not config.cuda.cutlass_epilogue_fusion_enabled or not config.epilogue_fusion ): why("cutlass epilogue fusion is not enabled") diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 92c86120570d6..79dfa9c6c391f 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -110,7 +110,7 @@ def generate_code_and_args( args are different. """ key: Optional[str] = None - if config.cutlass.enable_caching_codegen: + if config.cuda.enable_caching_codegen: key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) if key is not None and key in self.code_cache: diff --git a/torch/_inductor/codegen/cuda/cutlass_cache.py b/torch/_inductor/codegen/cuda/cutlass_cache.py index cad4a37902304..66db98867b413 100644 --- a/torch/_inductor/codegen/cuda/cutlass_cache.py +++ b/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -75,7 +75,7 @@ def maybe_fetch_ops() -> Optional[list[Any]]: # get_cuda_version might return "12.4.0" or "12.4" # but we want to use "12.4" version: str = ".".join(get_cuda_version().split(".")[:2]) - instantiation_level: str = config.cutlass.cutlass_instantiation_level + instantiation_level: str = config.cuda.cutlass_instantiation_level # filename and filepath request_key: str = get_config_request_key(arch, version, instantiation_level) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index 3ce3a49bb94e9..fa46e8766cd58 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -98,7 +98,7 @@ def path_join(path0, path1): # contains both cutlass and cutlass_library # we need cutlass for eVT - cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python") + cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") torch_root = os.path.abspath(os.path.dirname(torch.__file__)) mock_src_path = os.path.join( torch_root, @@ -252,7 +252,7 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]: ) return {} arch = _normalize_cuda_arch(arch) - instantiation_level: str = config.cutlass.cutlass_instantiation_level + instantiation_level: str = config.cuda.cutlass_instantiation_level args = CUTLASSArgs( architectures=arch, cuda_version=version, diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 9148ee7877d03..c4b7188bd9e62 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -19,7 +19,7 @@ from torch._inductor.utils import clear_on_fresh_cache from ... import ir -from ...config import cutlass as inductor_cutlass_config +from ...config import cuda as inductor_cuda_config from ...ir import ( Buffer, ChoiceCaller, @@ -578,7 +578,7 @@ def _add_cutlass_gemm_choices( for name, op in ops: for ( swizzle - ) in inductor_cutlass_config.cutlass_max_profiling_swizzle_options: + ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: description = f"{name} swizzle={swizzle}" self.maybe_append_choice( choices, @@ -635,7 +635,7 @@ def header(self) -> IndentedBuffer: #include "cutlass/util/tensor_view_io.h" """ ) - if inductor_cutlass_config.generate_test_runner and not is_dynamic( + if inductor_cuda_config.generate_test_runner and not is_dynamic( *self.input_nodes, self.output_node ): res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) @@ -953,7 +953,7 @@ def filter_op( ) return None - if inductor_cutlass_config.cutlass_tma_only and not self._has_tma_epilogue(op): + if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): return None # Set epilogue. @@ -975,16 +975,14 @@ def filter_op( return None # Apply regex filters at the end when configuration name doesn't change anymore - if inductor_cutlass_config.cutlass_op_allowlist_regex: + if inductor_cuda_config.cutlass_op_allowlist_regex: if not re.search( - inductor_cutlass_config.cutlass_op_allowlist_regex, - op.configuration_name(), + inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name() ): return None - if inductor_cutlass_config.cutlass_op_denylist_regex is not None: + if inductor_cuda_config.cutlass_op_denylist_regex is not None: if re.search( - inductor_cutlass_config.cutlass_op_denylist_regex, - op.configuration_name(), + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() ): return None @@ -1037,7 +1035,7 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: time.time() - start_time, ) sorted_res = sorted(res.items()) - ret_res = sorted_res[: inductor_cutlass_config.cutlass_max_profiling_configs] + ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] if len(self.filtered_ops_cache) < 50: self.filtered_ops_cache[self.cache_key] = ret_res else: @@ -1279,9 +1277,7 @@ def render( # type: ignore[override] } options.update(dict(zip(extra_names, extra_inputs))) res = self._template_from_string(self._get_template()).render(**options) - if inductor_cutlass_config.generate_test_runner and not is_dynamic( - X, W, Y, Bias - ): + if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): test_runner_code = self._template_from_string( GEMM_STANDALONE_RUNNER_TEMPLATE ).render(**options) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 645927686232b..7048990692da0 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1829,13 +1829,28 @@ class aot_inductor_mode: compile_standalone: bool = False -class cutlass: - """ - Config specific to cutlass backend. - """ +class cuda: + """Settings for cuda backend, today this consists of cutlass""" + + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. + version: Optional[str] = None + + # Optimization level for the host compiler. compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1" + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + # Whether to enable debug info, e.g. line number, cutlass debug info. enable_debug_info = False @@ -1847,10 +1862,7 @@ class cutlass: cutlass_dir = os.path.realpath( os.environ.get( "TORCHINDUCTOR_CUTLASS_DIR", - os.path.join( - os.path.dirname(torch.__file__), - "../third_party/cutlass/", - ), + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), ) ) @@ -1870,6 +1882,14 @@ class cutlass: # Whether to only use TMA-compatible kernels in CUTLASS cutlass_tma_only = False + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. cutlass_backend_min_gemm_size: int = 1 @@ -1939,41 +1959,6 @@ class cutlass: enable_caching_codegen: bool = True -class cuda(cutlass): - # CUDA arch to use for CUDA template kernel compilation. - # e.g. "70", "75", "80", "90", etc. - # When arch is None, Inductor uses torch.cuda.get_device_capability(0). - arch: Optional[str] = None - - # CUDA version to use for CUDA template kernel compilation. - # e.g. "11.4", "12.1", etc. - # When version is None, Inductor uses torch.version.cuda. - version: Optional[str] = None - - # Path to CUDA NVCC. - # NVCC search order: - # 1) cuda_cxx set in this config - # 2) CUDACXX environment variable - # 3) CUDA_HOME environment variable - # 4) default system search PATH. - cuda_cxx: Optional[str] = None - - # Whether to enable device LTO (link-time-optimization). - enable_cuda_lto = False - - # Whether to keep intermediate files dring compilation. - enable_ptxas_info = False - - -class xpu(cutlass): - # Xe arch to use for SYCL template kernel compilation. - # eg. 12, 20, which corresponding to Xe12(PVC) and Xe20 (BMG) - arch: Optional[str] = None - # oneAPI version to use for SYCL template kernel compilation. - # e.g. "20250201". - version: Optional[str] = None - - class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. # If empty, the `native` arch is used @@ -2182,7 +2167,7 @@ class trace: # trace functions are not relevant to config caching "trace", # uses absolute path - "cutlass.cutlass_dir", + "cuda.cutlass_dir", # not relevant "worker_start_method", "compile_threads", diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 2d288e683be5a..152dce2026766 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -480,7 +480,7 @@ def keys(self) -> KeysView[ComboType]: "aot_inductor.presets": DEFAULT, # Typing "cuda.arch": DEFAULT, # Out of Scope "cuda.version": DEFAULT, # Out of Scope - "cutlass.cutlass_dir": DEFAULT, # Out of Scope + "cuda.cutlass_dir": DEFAULT, # Out of Scope "cuda.cuda_cxx": DEFAULT, # Out of Scope "rocm.arch": DEFAULT, # Out of Scope "rocm.ck_supported_arch": DEFAULT, # Out of Scope diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 625f35ba36c06..28cdfbf0cc7ea 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3570,8 +3570,8 @@ def prescreen_choices( candidates = [] if ( - config.cutlass.cutlass_prescreening - and len(config.cutlass.cutlass_max_profiling_swizzle_options) > 1 + config.cuda.cutlass_prescreening + and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1 ): candidates.extend( [ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 59db1aeb12325..f029a2e73f038 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2023,7 +2023,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) - if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size: + if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: return False from .codegen.cuda.cutlass_utils import try_import_cutlass @@ -2044,9 +2044,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " - "_inductor.config.cutlass.cutlass_dir %s is set correctly. " + "_inductor.config.cuda.cutlass_dir %s is set correctly. " "Skipping CUTLASS backend for now.", - config.cutlass.cutlass_dir, + config.cuda.cutlass_dir, ) return False return res @@ -2054,7 +2054,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: def _use_cutlass_for_op(op_name: str) -> bool: """Check if CUTLASS should be used for the given operation.""" - enabled_ops = config.cutlass.cutlass_enabled_ops.upper() + enabled_ops = config.cuda.cutlass_enabled_ops.upper() if enabled_ops == "ALL": return True return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] From 1f34961aa9828dc442468da7afcd83587b84d594 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 23 Nov 2025 00:52:42 +0000 Subject: [PATCH 214/230] Revert "[inductor] Use custom triton kernel subclass when available (#167456)" This reverts commit 4ee6b3d60c85d847212901248fc7d99ee81a5899. Reverted https://github.com/pytorch/pytorch/pull/167456 on behalf of https://github.com/yangw-dev due to failed internal test Diff D87660150 , errorl ModuleNotFoundError: No module named 'extension_backends' ([comment](https://github.com/pytorch/pytorch/pull/167456#issuecomment-3567247490)) --- .../triton/extension_triton_heuristics.py | 41 ----- .../inductor/test_triton_extension_backend.py | 170 +++--------------- torch/_inductor/codegen/simd.py | 6 - torch/_inductor/codegen/triton.py | 62 +++---- .../_inductor/codegen/triton_combo_kernel.py | 16 +- torch/_inductor/codegen/wrapper.py | 18 +- torch/_inductor/runtime/triton_heuristics.py | 6 +- torch/_inductor/select_algorithm.py | 3 +- 8 files changed, 62 insertions(+), 260 deletions(-) delete mode 100644 test/inductor/extension_backends/triton/extension_triton_heuristics.py diff --git a/test/inductor/extension_backends/triton/extension_triton_heuristics.py b/test/inductor/extension_backends/triton/extension_triton_heuristics.py deleted file mode 100644 index bfe558ae1708a..0000000000000 --- a/test/inductor/extension_backends/triton/extension_triton_heuristics.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Any - -from torch._inductor.runtime import triton_heuristics -from torch._inductor.runtime.triton_heuristics import user_autotune # noqa: F401 - - -EXTENSION_TRITON_META_FIELD = "extension_custom_field" - - -class ExtensionCachingAutotuner(triton_heuristics.CachingAutotuner): - def _create_compile_meta( - self, - cfg: triton_heuristics.Config, - ) -> dict[str, Any]: - assert EXTENSION_TRITON_META_FIELD in self.triton_meta - compile_meta = super()._create_compile_meta(cfg) - assert EXTENSION_TRITON_META_FIELD in compile_meta - return compile_meta - - -def pointwise( - size_hints, - triton_meta, - tile_hint=None, - filename=None, - min_elem_per_thread=0, - inductor_meta=None, -): - """ - Construct @triton.heuristics() based on size_hints. - """ - configs = [triton_heuristics.Config({"XBLOCK": 32})] - return triton_heuristics.cached_autotune( - size_hints, - configs, - triton_meta=triton_meta, - inductor_meta=inductor_meta, - heuristic_type=triton_heuristics.HeuristicType.POINTWISE, - filename=filename, - caching_autotuner_cls=ExtensionCachingAutotuner, - ) diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index ae9afeec0637e..37b32404508bb 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -1,15 +1,12 @@ # Owner(s): ["module: inductor"] -import functools import random import string +import sys import unittest -from pathlib import Path -from typing import Any, Optional import torch import torch._dynamo import torch.utils.cpp_extension -from torch._inductor import config try: @@ -21,9 +18,6 @@ ExtensionScheduling, ExtensionWrapperCodegen, ) - from extension_backends.triton.extension_triton_heuristics import ( - EXTENSION_TRITON_META_FIELD, - ) except ImportError: from .extension_backends.triton.device_interface import DeviceInterface from .extension_backends.triton.extension_codegen_backend import ( @@ -31,27 +25,18 @@ ExtensionScheduling, ExtensionWrapperCodegen, ) - from .extension_backends.triton.extension_triton_heuristics import ( - EXTENSION_TRITON_META_FIELD, - ) -import torch._inductor.lowering as inductor_lowering from torch._C import FileCheck from torch._dynamo import device_interface -from torch._inductor import codegen, ir, metrics -from torch._inductor.codegen import common +from torch._inductor import metrics from torch._inductor.codegen.common import ( get_scheduling_for_device, get_wrapper_codegen_for_device, - IndentedBuffer, register_backend_for_device, register_device_op_overrides, ) -from torch._inductor.codegen.wrapper import PythonWrapperCodegen -from torch._inductor.utils import get_triton_code, run_and_get_triton_code +from torch._inductor.utils import get_triton_code from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU_AND_TRITON -from torch.testing._internal.triton_utils import requires_cuda_and_triton try: @@ -59,9 +44,18 @@ except ImportError: from test_extension_backend import BaseExtensionBackendTests -if HAS_GPU_AND_TRITON: - import triton - import triton.language as tl +try: + try: + from . import test_torchinductor + except ImportError: + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library +except unittest.SkipTest: + if __name__ == "__main__": + sys.exit(0) + raise + + +TestCase = test_torchinductor.TestCase def mock_triton_hash_with_backend(*args, **kwargs): @@ -71,33 +65,14 @@ def mock_triton_hash_with_backend(*args, **kwargs): @unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") +@test_torchinductor.skip_if_cpp_wrapper( + "Not possible to fix until CppWrapperCpu supports triton for CPU" +) class TritonExtensionBackendTests(BaseExtensionBackendTests): """ Test creating a backend for inductor with Triton scheduling. """ - @classmethod - def setUpClass(cls): - super().setUpClass() - if config.cpp_wrapper: - raise unittest.SkipTest( - "Not possible to fix until CppWrapperCpu supports triton for CPU" - ) - - # Store the default backends and reset later - common.init_backend_registration() - - default_backend_patch = unittest.mock.patch.dict(inductor_lowering.lowerings) - default_backend_patch.start() - cls._default_backend_patch = default_backend_patch - - @classmethod - def tearDownClass(cls): - super().tearDownClass() - - # Restore the default backend. - cls._default_backend_patch.stop() - def test_open_device_registration(self): torch._register_device_module("privateuseone", self.module) register_backend_for_device( @@ -137,115 +112,10 @@ def foo(x): "tl_math.sin" ).check("device_str='privateuseone'").run(code) - def _register_custom_backend_with_heuristics(self, device): - class ExtensionTritonKernel(codegen.triton.TritonKernel): - @classmethod - @functools.lru_cache(None) - def gen_common_triton_imports(cls) -> str: - default_imports = super().gen_common_triton_imports() - custom_imports = IndentedBuffer() - custom_imports.splice(default_imports) - path_to_ext_heuristics = ( - Path(__file__).parent / "extension_backends" / "triton" - ) - - custom_imports.splice(f""" - import sys - sys.path.append("{path_to_ext_heuristics}") - import extension_triton_heuristics as triton_heuristics - """) - return custom_imports - - @classmethod - def triton_meta_common(cls) -> dict[str, Any]: - triton_meta = super().triton_meta_common() - triton_meta[EXTENSION_TRITON_META_FIELD] = True - return triton_meta - - class ExtensionTritonScheduling(codegen.triton.TritonScheduling): - kernel_type = ExtensionTritonKernel - - class ExtensionPythonWrapperCodegen(PythonWrapperCodegen): - @classmethod - def _get_triton_info_kernel_cls(cls) -> type[codegen.triton.TritonKernel]: - return ExtensionTritonKernel - - @staticmethod - def create( - is_subgraph: bool, - subgraph_name: Optional[str], - parent_wrapper: Optional[PythonWrapperCodegen], - partition_signatures: Optional[ir.GraphPartitionSignature] = None, - ): - if is_subgraph: - assert subgraph_name is not None - assert parent_wrapper is not None - return PythonWrapperCodegen.create( - subgraph_name, parent_wrapper, partition_signatures - ) - return ExtensionPythonWrapperCodegen() - - register_backend_for_device( - device, ExtensionTritonScheduling, ExtensionPythonWrapperCodegen - ) - - @requires_cuda_and_triton - def test_codegen_with_custom_heuristics_module(self): - self._register_custom_backend_with_heuristics(GPU_TYPE) - - def add(x, y): - return x + y - - x = torch.zeros((32,), device=GPU_TYPE) - y = x - compiled_add = torch.compile(add) - - code = run_and_get_triton_code(compiled_add, x, y) - FileCheck().check("import extension_triton_heuristics").check( - f"{EXTENSION_TRITON_META_FIELD}" - ).check("@triton.jit").run(code) - - @requires_cuda_and_triton - def test_codegen_with_custom_heuristics_module_udtk(self): - self._register_custom_backend_with_heuristics(GPU_TYPE) - - @triton.jit - def add_kernel( - in_ptr0, - in_ptr1, - out_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(in_ptr0 + offsets, mask=mask) - y = tl.load(in_ptr1 + offsets, mask=mask) - output = x + y - tl.store(out_ptr + offsets, output, mask=mask) - - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - output = torch.empty_like(x) - n_elements = output.numel() - - def grid(meta): - return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16) - return output - - args = [torch.randn(32, device=GPU_TYPE) for _ in range(2)] - code = run_and_get_triton_code(torch.compile(add), *args) - - FileCheck().check("import extension_triton_heuristics").check( - "@triton.jit" - ).run(code) - if __name__ == "__main__": from torch._inductor.test_case import run_tests + from torch.testing._internal.inductor_utils import HAS_CPU - if (HAS_CPU or HAS_GPU_AND_TRITON) and not IS_MACOS: + if HAS_CPU and not IS_MACOS: run_tests() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 1b58503690e98..cf0e5bf849106 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2273,12 +2273,8 @@ def generate_combo_kernel_code( mixed_sizes: bool, only_gen_src_code: bool = False, ) -> list[tuple[str, Any, Any]]: - from .triton import TritonKernel from .triton_combo_kernel import ComboKernel - # This is currently the only type supported by this method - assert issubclass(self.kernel_type, TritonKernel) - fused_node_lists = [node.get_nodes() for node in subkernel_nodes] subkernel_map, node_schedule_map = {}, {} for pn, nodes in zip(subkernel_nodes, fused_node_lists): @@ -2290,7 +2286,6 @@ def generate_combo_kernel_code( tiling, features=SIMDKernelFeatures(node_schedule, numel, rnumel), optimize_mask=not mixed_sizes, - triton_kernel_cls=self.kernel_type, ) partitions = ComboKernel.horizontal_partition( @@ -2310,7 +2305,6 @@ def generate_combo_kernel_code( if len(node_group) == 0: continue kernel = ComboKernel( - triton_kernel_cls=self.kernel_type, enable_autotune=enable_autotune, mixed_sizes=mixed_sizes, ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index f590fe57de609..9b718f0c780c1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -178,6 +178,28 @@ class defined. return "" +@lru_cache(None) +def gen_common_triton_imports() -> str: + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + """ + ) + return imports.getvalue() + + class TritonSymbols: """ Stores sympy.Symbol instances and constants associated with triton codegen. @@ -4849,37 +4871,8 @@ def _get_heuristic(self): return "reduction" return "pointwise" - @classmethod - @lru_cache(None) - def gen_common_triton_imports(cls) -> str: - imports = IndentedBuffer() - imports.splice( - """ - import triton - import triton.language as tl - """ - ) - if attr_desc := gen_attr_descriptor_import(): - imports.writeline(attr_desc) - - imports.splice( - """ - from torch._inductor.runtime import triton_helpers, triton_heuristics - from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math - from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties - """ - ) - return imports.getvalue() - - @classmethod - def triton_meta_common(cls): - triton_meta = {"enable_fp_fusion": not config.emulate_precision_casts} - if enable_pdl_codegen(): - triton_meta["launch_pdl"] = True - return triton_meta - - @classmethod - def inductor_meta_common(cls): + @staticmethod + def inductor_meta_common(): inductor_meta = { "backend_hash": torch.utils._triton.triton_hash_with_backend(), "assert_indirect_indexing": config.assert_indirect_indexing, @@ -4957,7 +4950,7 @@ def codegen_kernel(self, name=None) -> str: size_hints[prefix] = size_hint if name is None: - code.splice(self.gen_common_triton_imports()) + code.splice(gen_common_triton_imports()) device_type = V.graph.get_current_device_or_throw().type if device_type == "cpu": code.splice("triton_helpers.set_driver_to_cpu()") @@ -5060,7 +5053,6 @@ def add_constexpr_arg(arg_name): torch._inductor.config.triton.native_matmul and ("tl.dot" in str(self.body) or "tl.dot" in str(self.compute)) ), - **self.triton_meta_common(), } # Skip memory optimization for forward of the training loop where we expect @@ -5167,6 +5159,9 @@ def add_constexpr_arg(arg_name): triton_meta["configs"] = [config_of(signature)] + if enable_pdl_codegen(): + triton_meta["launch_pdl"] = True + # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. @@ -5174,6 +5169,7 @@ def add_constexpr_arg(arg_name): # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 for arg_num in equal_1_arg_indices(signature): # type: ignore[index] triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + triton_meta["enable_fp_fusion"] = not config.emulate_precision_casts self.triton_meta = triton_meta diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 010de72f1606a..41b12d05cd32e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -35,7 +35,7 @@ ) from .simd import prefix_is_reduction, SIMDScheduling from .simd_kernel_features import SIMDKernelFeatures -from .triton import TritonKernel +from .triton import gen_common_triton_imports, TritonKernel from .triton_utils import config_of, signature_to_meta @@ -355,13 +355,9 @@ def codegen_pid_range( code.splice(f"pid_offset = pid // {num_kernels}") def __init__( - self, - triton_kernel_cls: type[TritonKernel], - enable_autotune: bool = False, - mixed_sizes: bool = False, + self, enable_autotune: bool = False, mixed_sizes: bool = False ) -> None: super().__init__() - self.triton_kernel_cls = triton_kernel_cls self.sub_kernels: list[TritonKernel] = [] self.iter_vars_count = itertools.count() self.grids: list[list[int]] = [] @@ -395,13 +391,12 @@ def create_triton_kernel( tiling: dict[str, sympy.Expr], features: SIMDKernelFeatures, optimize_mask: bool, - triton_kernel_cls: type[TritonKernel], ) -> TritonKernel: """ Only allow optimize_mask=True when 1) sequential dispatch is used, 2) numels except x dimension are the same for each sub kernel. """ - return triton_kernel_cls( + return TritonKernel( tiling, features=features, pid_cache={"tl.program_id(0)": "pid_offset"}, @@ -620,13 +615,12 @@ def jit_line( mutated_args = self.get_mutated_args_sub_kernels() dispatch = self.dispatch_class assert dispatch is not None - inductor_meta = { "grid_type": dispatch.grid_expr.__name__, "combo_grid_meta": self.combo_grid_meta(), "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, - **self.triton_kernel_cls.inductor_meta_common(), + **TritonKernel.inductor_meta_common(), } sub_kernel = selected_kernel @@ -774,7 +768,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: ) code = IndentedBuffer() - code.splice(self.triton_kernel_cls.gen_common_triton_imports()) + code.splice(gen_common_triton_imports()) if config.benchmark_combo_kernel: code.splice(self.imports_for_benchmark_kernel()) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 7e4aa07987224..0eab3cac9b4a7 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -1025,7 +1025,7 @@ class PythonWrapperCodegen(CodeGen): Generate outer wrapper in Python that calls the kernels. """ - supports_caching: bool = True # Whether the output code is cacheable. + supports_caching = True # Whether the output code is cacheable. def __init__(self): super().__init__() @@ -2280,16 +2280,6 @@ def _define_kernel_helper( def define_subgraph_launcher_fn(self, name: str, subgraph_code): self.subgraph_definitions.splice(subgraph_code.value) - @classmethod - def _get_triton_info_kernel_cls(cls): - # Other inductor triton backends may subclass from - # the `TritonKernel` class. An override of this method - # allows them to set which subclass to use to get information - # such as common triton imports or inductor metadata - from .triton import TritonKernel - - return TritonKernel - def define_user_defined_triton_kernel( self, kernel, @@ -2311,6 +2301,7 @@ def define_user_defined_triton_kernel( TensorArg, TMADescriptorArg, ) + from .triton import gen_common_triton_imports, TritonKernel original_name = kernel.__name__ signature: list[KernelArgType] = [] @@ -2523,10 +2514,9 @@ def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") inductor_meta["kernel_name"] = name - triton_info_kernel_cls = self._get_triton_info_kernel_cls() - inductor_meta.update(triton_info_kernel_cls.inductor_meta_common()) + inductor_meta.update(TritonKernel.inductor_meta_common()) - compile_wrapper.splice(triton_info_kernel_cls.gen_common_triton_imports()) + compile_wrapper.splice(gen_common_triton_imports()) compile_wrapper.splice( f""" @triton_heuristics.user_autotune( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 9a1783811d75c..175bf76bfc740 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2132,8 +2132,6 @@ def cached_autotune( filename=None, inductor_meta=None, custom_kernel=False, - caching_autotuner_cls: type[CachingAutotuner] = CachingAutotuner, - debug_autotuner_cls: type[DebugAutotuner] = DebugAutotuner, ): """ A copy of triton.autotune that calls our subclass. Our subclass @@ -2170,7 +2168,7 @@ def decorator(fn): tconfig.kwargs.pop("XBLOCK") if inductor_meta.get("profile_bandwidth"): - return debug_autotuner_cls( + return DebugAutotuner( fn, triton_meta=triton_meta, inductor_meta=inductor_meta, @@ -2189,7 +2187,7 @@ def decorator(fn): filename=filename, with_bandwidth_info=True, ) - return caching_autotuner_cls( + return CachingAutotuner( fn, triton_meta=triton_meta, inductor_meta=inductor_meta, diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 28cdfbf0cc7ea..493ca1179fad8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -61,6 +61,7 @@ from .codegen.simd_kernel_features import SIMDKernelFeatures from .codegen.subgraph import SubgraphChoiceCaller from .codegen.triton import ( + gen_common_triton_imports, texpr, TMACompatibilityChecker, TritonKernel, @@ -741,7 +742,7 @@ def hook(): # python_argdefs() cannot be run until after the rest of the template lazily adds more args arg_defs, *_ = self.args.python_argdefs() code = IndentedBuffer() - code.splice(self.gen_common_triton_imports()) + code.splice(gen_common_triton_imports()) code.splice(self.jit_lines()) code.writeline( f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):" From 4fd97b460cdf7aeebe9f886d1bb55486dfce5006 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 23 Nov 2025 00:56:12 +0000 Subject: [PATCH 215/230] Revert "[dynamo] add torch._dynamo.set_recursion_limit to fix 3.12/3.13 RecursionError problems (#167888)" This reverts commit 24e1958fcbfe089421f5731d662b7cc766330345. Reverted https://github.com/pytorch/pytorch/pull/167888 on behalf of https://github.com/yangw-dev due to failed interal test Tracing payload for Mock should not be called: pt2_compile_chromium_events Fatal Python error: Segmentation fault, please remerge after fixing it ([comment](https://github.com/pytorch/pytorch/pull/167888#issuecomment-3567252753)) --- test/dynamo/test_repros.py | 91 ---------------------------- torch/_C/_dynamo/eval_frame.pyi | 2 - torch/_dynamo/__init__.py | 24 -------- torch/csrc/dynamo/eval_frame.c | 5 +- torch/csrc/dynamo/eval_frame_cpp.cpp | 61 +------------------ torch/csrc/dynamo/eval_frame_cpp.h | 7 +-- torch/csrc/dynamo/init.cpp | 4 -- 7 files changed, 5 insertions(+), 189 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 9bb94b9a47d40..24b8f4c48aa32 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -7456,97 +7456,6 @@ def forward(self, x): msg, ) - def test_dynamo_set_recursion_limit_simple(self): - # Test that torch._dynamo.set_recursion_limit calls sys.setrecursionlimit for all supported - # Python versions - old_recursion_limit = sys.getrecursionlimit() - old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() - try: - - def fn(x, n): - if n == 0: - return x - return fn(x, n - 1) + 1 - - sys.setrecursionlimit(100) - - with self.assertRaises(RecursionError): - fn(torch.ones(3), 1000) - - opt_fn = torch.compile(fn, backend="eager", dynamic=False) - torch._dynamo.set_recursion_limit(100000) - self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000)) - finally: - if old_dynamo_recursion_limit > 0: - torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) - sys.setrecursionlimit(old_recursion_limit) - - @unittest.skipIf( - sys.version_info < (3, 12) or sys.version_info >= (3, 14), - "only 3.12, 3.13 affected by c recursion limit", - ) - def test_dynamo_set_recursion_limit(self): - old_recursion_limit = sys.getrecursionlimit() - old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() - try: - - def fn(x, n): - if n == 0: - return x - return fn(x, n - 1) + 1 - - sys.setrecursionlimit(100) - - with self.assertRaises(RecursionError): - fn(torch.ones(3), 1000) - - sys.setrecursionlimit(2000) - - fn(torch.ones(3), 1000) - opt_fn = torch.compile(fn, backend="eager", dynamic=False) - sys.setrecursionlimit(100000) - with self.assertRaises(Exception): - opt_fn(torch.ones(3), 1000) - - torch._dynamo.set_recursion_limit(100000) - self.assertEqual(fn(torch.ones(3), 1000), opt_fn(torch.ones(3), 1000)) - finally: - if old_dynamo_recursion_limit > 0: - torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) - sys.setrecursionlimit(old_recursion_limit) - - @unittest.skipIf( - sys.version_info < (3, 12) or sys.version_info >= (3, 14), - "only 3.12, 3.13 affected by c recursion limit", - ) - def test_dynamo_set_recursion_limit_usage(self): - old_recursion_limit = sys.getrecursionlimit() - old_dynamo_recursion_limit = torch._dynamo.get_recursion_limit() - try: - torch._dynamo.set_recursion_limit(100) - self.assertEqual(torch._dynamo.get_recursion_limit(), 100) - - with self.assertRaisesRegex(ValueError, "recursion limit"): - torch._dynamo.set_recursion_limit(0) - - self.assertEqual(torch._dynamo.get_recursion_limit(), 100) - - torch._dynamo.set_recursion_limit(1) - sys.setrecursionlimit(100) - - @torch.compile(backend="eager", dynamic=False) - def fn(x, n): - if n == 0: - return x - return fn(x, n - 1) + 1 - - with self.assertRaisesRegex(RuntimeError, "new c_recursion limit"): - fn(torch.ones(3), 5) - finally: - if old_dynamo_recursion_limit > 0: - torch._dynamo.set_recursion_limit(old_dynamo_recursion_limit) - sys.setrecursionlimit(old_recursion_limit) - @expectedFailureDynamic def test_dynamo_default_lru_cache_behavior(self): @torch.compile(backend="eager") diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 060bf2638e096..3c3a18ed4e063 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -20,8 +20,6 @@ def set_guard_complete_hook( hook: Optional[DynamoGuardCompleteHook], ) -> Optional[DynamoGuardCompleteHook]: ... def raise_sigtrap() -> None: ... -def set_c_recursion_limit(limit: int) -> None: ... -def get_c_recursion_limit() -> int: ... class _CacheEntry: def check_fn(self, *args: object, **kwargs: object) -> bool: ... diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index b0b00bc6f5b89..de097edf87752 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -105,7 +105,6 @@ "reset", "run", "error_on_graph_break", - "set_recursion_limit", "set_stance", "skip_frame", "step_unsupported", @@ -182,26 +181,3 @@ def reset_code_caches() -> None: if code: reset_code(code) code_context.clear() - - -def get_recursion_limit() -> int: - """ - Returns the internal dynamo recursion limit set by `torch._dynamo.set_recursion_limit`. - - Returns -1 if no c recursion limit has been set. - """ - return torch._C._dynamo.eval_frame.get_c_recursion_limit() - - -def set_recursion_limit(limit: int) -> None: - """ - Sets an internal dynamo recursion limit. The limit must be >= 1. - - This is possibly needed in Python 3.12-3.13 since there is a separate C recursion limit - that is not visible at the Python level. If you are getting RecursionErrors during - Dynamo compilation and `sys.setrecursionlimit()` doesn't help, this function may alleviate - the issue. - - NOTE: this function will also call `sys.setrecursionlimit()`. - """ - torch._C._dynamo.eval_frame.set_c_recursion_limit(limit) diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 58cb48de664d5..b08fffedaa014 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -733,10 +733,7 @@ static PyMethodDef _methods[] = { {"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, - {"set_code_exec_strategy", - dynamo_set_code_exec_strategy, - METH_VARARGS, - NULL}, + {"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL}, {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL}, diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index 72465d6f4774f..e678bc7bad04a 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -50,56 +50,6 @@ static py::handle _callback_from_action( return callback; } -// c_recursion_remaining only defined in 3.12 and 3.13 - -static int32_t c_recursion_limit = -1; - -void dynamo_set_c_recursion_limit(int32_t limit) { - if (limit < 1) { - throw std::range_error("recursion limit must be greater or equal than 1"); - } - c_recursion_limit = limit; - // cannot fail - Py_SetRecursionLimit(limit); // also set the Python limit -} - -int32_t dynamo_get_c_recursion_limit() { - return c_recursion_limit; -} - -#if IS_PYTHON_3_12_PLUS && !IS_PYTHON_3_14_PLUS - -struct CRecursionLimitRAII { - PyThreadState* tstate; - int32_t old_recursion_remaining; - CRecursionLimitRAII(PyThreadState* tstate) : tstate{tstate} { - auto limit = dynamo_get_c_recursion_limit(); - auto& remaining = tstate->c_recursion_remaining; - this->old_recursion_remaining = remaining; - if (limit < 0) { - // no change to limit - return; - } - if (limit < remaining) { - PyErr_SetString( - PyExc_RuntimeError, - "new c_recursion limit is lower than thread's current c_recursion_remaining."); - } - remaining = limit; - } - ~CRecursionLimitRAII() { - this->tstate->c_recursion_remaining = this->old_recursion_remaining; - } -}; - -#else - -struct CRecursionLimitRAII { - CRecursionLimitRAII(PyThreadState* tstate) {} -}; - -#endif - // frame and callback are borrowed references. // Returns new reference. PyObject* dynamo__custom_eval_frame( @@ -308,13 +258,6 @@ PyObject* dynamo__custom_eval_frame( bool apply_to_code = false; PyObject* guarded_code = nullptr; try { - CRecursionLimitRAII tmp(tstate); // increase C recursion limit to the given - // value during compilation - // C recursion limit failure - if (PyErr_Occurred()) { - fail(); - return eval_result; - } callback_result = dynamo_call_callback( callback, frame, locals.get(), cache_entry, frame_state); new_strategy = @@ -377,7 +320,7 @@ PyObject* dynamo__custom_eval_frame( return eval_result; } -PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* args) { +PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { PyObject* code_obj = nullptr; PyObject* strategy_obj = nullptr; if (!PyArg_ParseTuple(args, "OO", &code_obj, &strategy_obj)) { @@ -401,7 +344,7 @@ PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* args) { Py_RETURN_NONE; } -void dynamo_skip_code_recursive(PyCodeObject* code) { +void skip_code_recursive(PyCodeObject* code) { ExtraState* extra = get_extra_state(code); if (extra == nullptr) { extra = init_and_set_extra_state(code); diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index 8cc1ab7618b3d..2f3587094f763 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -16,11 +16,8 @@ PyObject* dynamo__custom_eval_frame( int throw_flag, PyObject* callback); -PyObject* dynamo_set_code_exec_strategy(PyObject* dummy, PyObject* obj); -void dynamo_skip_code_recursive(PyCodeObject* code); - -void dynamo_set_c_recursion_limit(int32_t limit); -int32_t dynamo_get_c_recursion_limit(); +PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); +void skip_code_recursive(PyCodeObject* code); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 69d6e0555ceb4..9ed9a465642c3 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -252,9 +251,6 @@ void initDynamoBindings(PyObject* torch) { .def_readwrite("cur_action", &FrameExecStrategy::cur_action) .def_readwrite("recursive_action", &FrameExecStrategy::recursive_action); - m.def("set_c_recursion_limit", &dynamo_set_c_recursion_limit); - m.def("get_c_recursion_limit", &dynamo_get_c_recursion_limit); - m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list); m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); From a1ab3a0ee4cd35e74cdd610fbb55f50062a74916 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sun, 23 Nov 2025 04:46:50 +0000 Subject: [PATCH 216/230] [audio hash update] update the pinned audio hash (#168315) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168315 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 616dfd88ce812..b65b6a7f117ef 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -ee1a1350eb37804b94334768f328144f058f14e9 +32ce8c011855adb15438ddc9bf6c139d23f8cee5 From 19c34dd14669be332401ac974a648e233ca7e33e Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 22 Nov 2025 15:35:36 -0800 Subject: [PATCH 217/230] [dynamo] Special case handling for tree_map_only (#168365) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168365 Approved by: https://github.com/anijain2305 --- test/dynamo/test_tree_map.py | 88 ++++++++++++++ torch/_dynamo/variables/functions.py | 164 +++++++++++++++++++++++++-- 2 files changed, 244 insertions(+), 8 deletions(-) diff --git a/test/dynamo/test_tree_map.py b/test/dynamo/test_tree_map.py index a7ade021b5acd..0e18d69129d56 100644 --- a/test/dynamo/test_tree_map.py +++ b/test/dynamo/test_tree_map.py @@ -168,6 +168,94 @@ def fn(tree): result = compiled(tree) _assert_trees_allclose(self, expected, result) + def test_tree_map_only_applies_to_tensor_nodes(self) -> None: + tree = {"tensor": torch.ones(2), "int": 3} + + def mapper(node): + if not isinstance(node, torch.Tensor): + raise AssertionError("mapper should only see tensors") + return node + 2 + + def fn(arg): + return pytree.tree_map_only(torch.Tensor, mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_only_multiple_trees_falls_back(self) -> None: + lhs = {"a": torch.ones(2), "b": torch.ones(2) * 2} + rhs = {"a": torch.ones(2) * 3, "b": torch.ones(2) * 4} + + def fn(a, b): + return pytree.tree_map_only(torch.Tensor, lambda x, y: x + y, a, b) + + with self.assertRaisesRegex(TypeError, "callable"): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (TypeError, torch._dynamo.exc.Unsupported), + r"(callable|Unsupported function call)", + ): + compiled(lhs, rhs) + + def test_tree_map_only_handles_multiple_types(self) -> None: + tree = {"int": 7, "tuple": (1, 2), "tensor": torch.ones(2)} + + def mapper(node): + if isinstance(node, int): + return node + 1 + if isinstance(node, tuple): + return tuple(val + 10 for val in node) + raise AssertionError("unexpected node passed to mapper") + + def fn(arg): + return pytree.tree_map_only((int, tuple), mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_is_leaf_non_constant_fallback(self) -> None: + tree = {"a": torch.arange(2.0), "b": torch.arange(2.0) + 1} + + def is_leaf(node): + if isinstance(node, torch.Tensor): + # Depends on runtime tensor value; cannot be folded to a constant. + return (node.sum() > 1).item() + return False + + def mapper(node): + return node * 2 if isinstance(node, torch.Tensor) else node + + def fn(arg): + return pytree.tree_map(mapper, arg, is_leaf=is_leaf) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_only_predicate_selector_skips_fastpath(self) -> None: + tree = {"keep": torch.ones(2), "other": (1, 2)} + + def selector(node): + return isinstance(node, torch.Tensor) and node.shape == (2,) + + def mapper(node): + return node + 5 if isinstance(node, torch.Tensor) else node + + def fn(arg): + return pytree.tree_map_only(selector, mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + def test_tree_map_none_nodes_reject_mismatched_siblings(self) -> None: def fn(a, b): return optree.tree_map(lambda u, v: (u, v), a, b) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 1b85235e7e6dd..deee9bcec42de 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -115,6 +115,7 @@ CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 _SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"}) +_TREE_MAP_ONLY_SUPPORTED_KWARGS = frozenset({"is_leaf"}) # Module-level cache keyed by the function object @@ -679,22 +680,31 @@ def _maybe_call_tree_map_fastpath( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> Optional[VariableTracker]: + rewrite = self._rewrite_tree_map_only_call(tx, args, kwargs) + if rewrite is not None: + tree_map_fn, tree_map_args, tree_map_kwargs = rewrite + else: + tree_map_fn = self + tree_map_args = args + tree_map_kwargs = kwargs + if not ( - self._is_tree_map_function() - and not ({*kwargs} - _SUPPORTED_TREE_MAP_KWARGS) - and len(args) >= 2 + isinstance(tree_map_fn, UserFunctionVariable) + and tree_map_fn._is_tree_map_function() + and not ({*tree_map_kwargs} - _SUPPORTED_TREE_MAP_KWARGS) + and len(tree_map_args) >= 2 ): return None - map_fn = args[0] - first_tree = args[1] - rest = args[2:] + map_fn = tree_map_args[0] + first_tree = tree_map_args[1] + rest = tree_map_args[2:] return first_tree.call_tree_map( tx, - self, + tree_map_fn, map_fn, rest, - kwargs, + tree_map_kwargs, ) def _is_tree_map_function(self) -> bool: @@ -703,6 +713,144 @@ def _is_tree_map_function(self) -> bool: and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES ) + def _is_tree_map_only_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map_only" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _rewrite_tree_map_only_call( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[ + tuple[ + "UserFunctionVariable", + Sequence[VariableTracker], + dict[str, VariableTracker], + ] + ]: + if not self._is_tree_map_only_function(): + return None + + if len(args) != 3: + return None + if {*kwargs} - _TREE_MAP_ONLY_SUPPORTED_KWARGS: + return None + + type_selector, map_fn, tree_arg = args + allowed_types = self._extract_tree_map_only_types(type_selector) + if allowed_types is None: + return None + + tree_map_callable = self._lookup_tree_map_function() + if tree_map_callable is None: + return None + + wrapped_map_fn = TreeMapOnlyFunctionVariable( + allowed_types, + map_fn, + source=getattr(map_fn, "source", None), + ) + tree_map_variable = variables.UserFunctionVariable(tree_map_callable) + return tree_map_variable, [wrapped_map_fn, tree_arg], dict(kwargs) + + def _lookup_tree_map_function(self) -> Optional[types.FunctionType]: + module_name = getattr(self.fn, "__module__", None) + if not module_name: + return None + module = sys.modules.get(module_name) + if module is None: + return None + tree_map = getattr(module, "tree_map", None) + if isinstance(tree_map, types.FunctionType): + return tree_map + return None + + def _extract_tree_map_only_types( + self, selector: VariableTracker + ) -> Optional[tuple[type, ...]]: + if not selector.is_python_constant(): + return None + try: + raw_value = selector.as_python_constant() + except NotImplementedError: + return None + + flattened = self._flatten_type_spec(raw_value) + if not flattened: + return None + if not all(isinstance(typ, type) for typ in flattened): + return None + return tuple(dict.fromkeys(flattened)) + + def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: + if isinstance(value, type): + return [value] + if isinstance(value, tuple): + collected: list[type] = [] + for entry in value: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + union_type = getattr(types, "UnionType", None) + if union_type is not None and isinstance(value, union_type): + collected = [] + for entry in value.__args__: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + return None + + +class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "allowed_types", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + allowed_types: tuple[type, ...], + map_fn: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.allowed_types = allowed_types + self.map_fn = map_fn + + def python_type(self) -> type: + return FunctionType + + def _matches_allowed_type(self, node: VariableTracker) -> bool: + try: + node_type = node.python_type() + except NotImplementedError: + return False + return any(issubclass(node_type, allowed) for allowed in self.allowed_types) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not args: + return self.map_fn.call_function(tx, args, kwargs) + leaf = args[0] + if self._matches_allowed_type(leaf): + return self.map_fn.call_function(tx, args, kwargs) + if len(args) != 1 or kwargs: + # Defer to the original map function so we fall back to normal + # tracing instead of triggering a graph break. + return self.map_fn.call_function(tx, args, kwargs) + return leaf + class BuiltinMethodVariable(BaseUserFunctionVariable): def __init__( From d3f61c1b1da0530839547a69602f83f9797ee15a Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sat, 22 Nov 2025 15:35:37 -0800 Subject: [PATCH 218/230] [dynamo] Fix local test failures for dynamo/test_repros.py (#168893) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168893 Approved by: https://github.com/anijain2305 ghstack dependencies: #168365 --- test/dynamo/test_repros.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 24b8f4c48aa32..8eefbefe9237f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -968,6 +968,15 @@ class LRUCacheWarningTests(LoggingTestCase): @requires_cuda @make_logging_test(dynamo=logging.DEBUG) def test_lru_cache_warning_issued_during_tracing(self, records): + prev_default = torch._C._get_default_device() + + def _restore_default_device(): + if prev_default == "cpu": + torch.set_default_device(None) + else: + torch.set_default_device(prev_default) + + self.addCleanup(_restore_default_device) torch.set_default_device("cuda") @torch.compile(backend="eager") From cb3754f39fce0ce5162fc25f3c7cd0ab451ecd79 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 20 Nov 2025 10:05:35 -0800 Subject: [PATCH 219/230] [DTensor] Refactor strategy/rule registration into dedicated module (#168221) To avoid circular import issues: - utils.py used to include registration functions which import/depend on DTensor.sharding_propagator - I plan to use other utils from utils.py inside sharding_propagator.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/168221 Approved by: https://github.com/albanD --- test/distributed/tensor/test_op_strategy.py | 6 +- torch/distributed/tensor/_ops/_conv_ops.py | 2 +- .../distributed/tensor/_ops/_embedding_ops.py | 6 +- torch/distributed/tensor/_ops/_math_ops.py | 2 +- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- .../distributed/tensor/_ops/_pointwise_ops.py | 2 +- torch/distributed/tensor/_ops/_random_ops.py | 3 +- torch/distributed/tensor/_ops/_tensor_ops.py | 6 +- torch/distributed/tensor/_ops/_view_ops.py | 2 +- torch/distributed/tensor/_ops/registration.py | 83 +++++++++++++++++++ torch/distributed/tensor/_ops/utils.py | 76 +---------------- 11 files changed, 99 insertions(+), 91 deletions(-) create mode 100644 torch/distributed/tensor/_ops/registration.py diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 42c4ccf122fd9..e1d3f96e9e5f4 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -30,10 +30,8 @@ EinsumDims, gen_einsum_strategies, ) -from torch.distributed.tensor._ops.utils import ( - register_op_strategy, - replicate_op_strategy, -) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import replicate_op_strategy from torch.distributed.tensor.debug import ( _clear_fast_path_sharding_prop_cache, _clear_python_sharding_prop_cache, diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index df9b81ac5df6e..1f456d505c127 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -4,7 +4,7 @@ import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import OpSchema, OutputSharding -from torch.distributed.tensor._ops.utils import register_prop_rule +from torch.distributed.tensor._ops.registration import register_prop_rule aten = torch.ops.aten diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 41272c0f31a92..b7c4abf353be5 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -10,10 +10,8 @@ PlacementList, StrategyType, ) -from torch.distributed.tensor._ops.utils import ( - expand_to_full_mesh_op_strategy, - register_op_strategy, -) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy from torch.distributed.tensor.placement_types import ( MaskPartial, Partial, diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 7ccca23c0dab5..ac0180f07d05e 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -17,6 +17,7 @@ RuntimeSchemaInfo, TupleStrategy, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( as_list, expand_to_full_mesh_op_strategy, @@ -25,7 +26,6 @@ is_tensor_evenly_shardable_on_dim, normalize_dim, normalize_dims, - register_op_strategy, ) from torch.distributed.tensor._utils import normalize_to_torch_size from torch.distributed.tensor.placement_types import ( diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 5ccf3c37c7855..ecd7938d75e2e 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -15,6 +15,7 @@ RuntimeSchemaInfo, ) from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -22,7 +23,6 @@ is_tensor_shardable, map_placements_after_broadcast, prod, - register_op_strategy, ) from torch.distributed.tensor._utils import ( compute_local_shape_and_global_offset, diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 53b759e993c0d..011a1ec667fb4 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -12,12 +12,12 @@ StrategyType, TupleStrategy, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, - register_op_strategy, ) from torch.distributed.tensor.placement_types import ( Partial, diff --git a/torch/distributed/tensor/_ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py index 9db9b85e58d2d..dd4cf8fec226a 100644 --- a/torch/distributed/tensor/_ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -6,7 +6,8 @@ OpStrategy, StrategyType, ) -from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import is_tensor_partial aten = torch.ops.aten diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fe20e41f59285..cb336486785af 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -18,6 +18,10 @@ ) from torch.distributed.tensor._ops._common_rules import pointwise_rule from torch.distributed.tensor._ops._embedding_ops import MaskPartial +from torch.distributed.tensor._ops.registration import ( + register_op_strategy, + register_prop_rule, +) from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -25,8 +29,6 @@ is_tensor_evenly_shardable, is_tensor_partial, normalize_dim, - register_op_strategy, - register_prop_rule, shift_shard_dims_after_insert, shift_shard_dims_after_remove, ) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 2d9e33402c607..6c8954729b976 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -15,12 +15,12 @@ RuntimeSchemaInfo, StrategyType, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, - register_op_strategy, ) from torch.distributed.tensor.placement_types import ( _StridedShard, diff --git a/torch/distributed/tensor/_ops/registration.py b/torch/distributed/tensor/_ops/registration.py new file mode 100644 index 0000000000000..3864d8971069e --- /dev/null +++ b/torch/distributed/tensor/_ops/registration.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable +from typing import Optional, TypeAlias, TypeVar, Union + +import torch +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, + StrategyType, +) + + +# convenient wrapper to register sharding propagation rules +def register_prop_rule( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[ + [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] +]: + def wrapper( + impl: Callable[[OpSchema], OutputSharding], + ) -> Callable[[OpSchema], OutputSharding]: + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +# Note: +# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy, +# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat +# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like +# MyStrategyType(StrategyType). +_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema) +_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType) +_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT] + + +def register_op_strategy( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]: + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. + arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc: + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index a9f42b53fca6e..f09a888734807 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -4,20 +4,17 @@ import itertools import operator from collections.abc import Callable, Iterable, Sequence -from typing import cast, Optional, TypeAlias, TypeVar, Union +from typing import cast, Optional, Union import torch from torch._prims_common import DimsSequenceType, DimsType -from torch.distributed.tensor._api import DTensor from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, OpSpec, OpStrategy, - OutputSharding, PlacementList, - RuntimeSchemaInfo, StrategyType, ) from torch.distributed.tensor.device_mesh import DeviceMesh @@ -29,77 +26,6 @@ ) -# convenient wrapper to register sharding propagation rules -def register_prop_rule( - op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> Callable[ - [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] -]: - def wrapper( - impl: Callable[[OpSchema], OutputSharding], - ) -> Callable[[OpSchema], OutputSharding]: - overloads = op if isinstance(op, list) else [op] - for overload in overloads: - DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( - overload, impl, schema_info - ) - return impl - - return wrapper - - -# Note: -# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy, -# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat -# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like -# MyStrategyType(StrategyType). -_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema) -_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType) -_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT] - - -def register_op_strategy( - op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]: - # For every ATen op that accepts any args in this list, - # the arg itself can impact the strides (and potentially the sharding strategy) - # of the output tensor. - # thus, we will detect ATen schemas with any of these args and ensure - # that they get specialized here. - arg_names_that_require_specializing_cache_strategy = [ - "memory_format", - ] - - def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc: - if isinstance(op, list): - overloads = op - else: - overloads = [op] - - for overload in overloads: - curr_schema_info = None - if schema_info is None: - specialized_args = [ - a.name - for a in overload._schema.arguments - if a.name in arg_names_that_require_specializing_cache_strategy - ] - if any(specialized_args): - curr_schema_info = RuntimeSchemaInfo( - static_kwargkey=specialized_args - ) - else: - curr_schema_info = schema_info - DTensor._op_dispatcher.sharding_propagator.register_op_strategy( - overload, impl, curr_schema_info - ) - return impl - - return wrapper - - def replicate_op_strategy(op_schema: OpSchema) -> StrategyType: """ Fallback strategy all use Replication() From c740e857bcdf22b43092aaebf14b63983b262812 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 23 Nov 2025 19:04:11 +0000 Subject: [PATCH 220/230] [BE] Delete `missing_vXXX_neon headers (#168909) This was a workaround for gcc-8 on ARM, introduced by https://github.com/pytorch/pytorch/pull/44199, which is no longer relevant as CentOS-8 is past its EOL Was reminded about it while looking at https://github.com/pytorch/pytorch/issues/168907 Pull Request resolved: https://github.com/pytorch/pytorch/pull/168909 Approved by: https://github.com/Skylion007, https://github.com/nimeduhansaka --- CMakeLists.txt | 35 -- .../ATen/cpu/vec/vec256/missing_vld1_neon.h | 1 - .../ATen/cpu/vec/vec256/missing_vst1_neon.h | 1 - torch/headeronly/cpu/vec/intrinsics.h | 5 - .../cpu/vec/vec256/missing_vld1_neon.h | 396 ------------------ .../cpu/vec/vec256/missing_vst1_neon.h | 7 - 6 files changed, 445 deletions(-) delete mode 100644 aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h delete mode 100644 aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h delete mode 100644 torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h delete mode 100644 torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 877ed9fafd3b1..b2d9ed4188071 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1217,41 +1217,6 @@ else() append_cxx_flag_if_supported("/wd4273" CMAKE_CXX_FLAGS) endif() -if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") - include(CheckCSourceCompiles) - check_c_source_compiles( - "#include -int main() { - float a[] = {1.0, 1.0}; - float32x4x2_t v; - v.val[0] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); - v.val[1] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); - vst1q_f32_x2(a, v); - return 0; -}" - HAS_VST1) - - if(NOT HAS_VST1) - string(APPEND CMAKE_CXX_FLAGS " -DMISSING_ARM_VST1") - endif() -endif() - -if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") - include(CheckCSourceCompiles) - check_c_source_compiles( - "#include -int main() { - float a[] = {1.0, 1.0}; - vld1q_f32_x2(a); - return 0; -}" - HAS_VLD1) - - if(NOT HAS_VLD1) - string(APPEND CMAKE_CXX_FLAGS " -DMISSING_ARM_VLD1") - endif() -endif() - # Add code coverage flags to supported compilers if(USE_CPP_CODE_COVERAGE) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h deleted file mode 100644 index aa40000b6ccdb..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h deleted file mode 100644 index b3d721531d246..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/torch/headeronly/cpu/vec/intrinsics.h b/torch/headeronly/cpu/vec/intrinsics.h index 4342005e30f49..3cf427dae64bc 100644 --- a/torch/headeronly/cpu/vec/intrinsics.h +++ b/torch/headeronly/cpu/vec/intrinsics.h @@ -29,11 +29,6 @@ /* GCC-compatible compiler, targeting ARM with SVE */ #include #endif -#if defined(MISSING_ARM_VLD1) -#include -#elif defined(MISSING_ARM_VST1) -#include -#endif #elif defined(__GNUC__) && defined(__IWMMXT__) /* GCC-compatible compiler, targeting ARM with WMMX */ #include diff --git a/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h b/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h deleted file mode 100644 index b78841ead92e9..0000000000000 --- a/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h +++ /dev/null @@ -1,396 +0,0 @@ -/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */ - -__extension__ extern __inline uint8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u8_x2(const uint8_t* __a) { - uint8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s8_x2(const int8_t* __a) { - int8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u16_x2(const uint16_t* __a) { - uint16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s16_x2(const int16_t* __a) { - int16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u32_x2(const uint32_t* __a) { - uint32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s32_x2(const int32_t* __a) { - int32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u64_x2(const uint64_t* __a) { - uint64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s64_x2(const int64_t* __a) { - int64x1x2_t ret; - __builtin_aarch64_simd_oi __o; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f16_x2(const float16_t* __a) { - float16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f32_x2(const float32_t* __a) { - float32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f64_x2(const float64_t* __a) { - float64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p8_x2(const poly8_t* __a) { - poly8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p16_x2(const poly16_t* __a) { - poly16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p64_x2(const poly64_t* __a) { - poly64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u8_x2(const uint8_t* __a) { - uint8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s8_x2(const int8_t* __a) { - int8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u16_x2(const uint16_t* __a) { - uint16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s16_x2(const int16_t* __a) { - int16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u32_x2(const uint32_t* __a) { - uint32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s32_x2(const int32_t* __a) { - int32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u64_x2(const uint64_t* __a) { - uint64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s64_x2(const int64_t* __a) { - int64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f16_x2(const float16_t* __a) { - float16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f32_x2(const float32_t* __a) { - float32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f64_x2(const float64_t* __a) { - float64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p8_x2(const poly8_t* __a) { - poly8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p16_x2(const poly16_t* __a) { - poly16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p64_x2(const poly64_t* __a) { - poly64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -/* vst1x2 */ - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s64_x2(int64_t* __a, int64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f64_x2(float64_t* __a, float64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s8_x2(int8_t* __a, int8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s16_x2(int16_t* __a, int16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s32_x2(int32_t* __a, int32x2x2_t val) { - asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) { - asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f16_x2(float16_t* __a, float16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f32_x2(float32_t* __a, float32x2x2_t val) { - asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s8_x2(int8_t* __a, int8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s16_x2(int16_t* __a, int16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s32_x2(int32_t* __a, int32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s64_x2(int64_t* __a, int64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f16_x2(float16_t* __a, float16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f64_x2(float64_t* __a, float64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} diff --git a/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h b/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h deleted file mode 100644 index 93f1110d808c6..0000000000000 --- a/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h +++ /dev/null @@ -1,7 +0,0 @@ -/* Workaround for missing vst1q_f32_x2 in gcc-8. */ - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} From c9c8a8567dc06a4c6b78fc427cfc4692b38e8ced Mon Sep 17 00:00:00 2001 From: jainapurva Date: Sun, 23 Nov 2025 20:13:33 +0000 Subject: [PATCH 221/230] Add optimizer tests in operator microbenchmarks (#168101) This PR adds comprehensive benchmarks for PyTorch optimizers to measure optimizer.step() performance across different parameter configurations. ### Optimizers benchmarked: - AdamW - Adam - SGD (with momentum=0.9) - RMSprop - Adagrad ### Test configurations: - num_params: [1, 10, 100] - param_size: [100K, 1M, 10M] Pull Request resolved: https://github.com/pytorch/pytorch/pull/168101 Approved by: https://github.com/slayton58 --- .ci/pytorch/test.sh | 6 +- .../operator_benchmark/benchmark_core.py | 19 ++- .../operator_benchmark/benchmark_pytorch.py | 38 ++++++ .../operator_benchmark/pt/addmm_test.py | 40 ++++++ benchmarks/operator_benchmark/pt/bmm_test.py | 33 +++++ benchmarks/operator_benchmark/pt/conv_test.py | 126 ++++++++++++++++++ .../operator_benchmark/pt/matmul_test.py | 16 +++ benchmarks/operator_benchmark/pt/mm_test.py | 14 ++ .../operator_benchmark/pt/optimizer_test.py | 65 +++++++++ 9 files changed, 354 insertions(+), 3 deletions(-) create mode 100644 benchmarks/operator_benchmark/pt/optimizer_test.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 01075259e9fe9..7e25c8c6d199c 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1763,12 +1763,14 @@ test_operator_microbenchmark() { mkdir -p "$TEST_REPORTS_DIR" TEST_DIR=$(pwd) + test_inductor_set_cpu_affinity + cd benchmarks/operator_benchmark/pt_extension - python -m pip install . + python -m pip install . -v --no-build-isolation cd "${TEST_DIR}"/benchmarks/operator_benchmark - for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do + for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer; do $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ --benchmark-name "PyTorch operator microbenchmark" --use-compile diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 4c131843b372b..7a8f0988a1fbf 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -266,7 +266,13 @@ def _print_perf_result(self, results, test_case): print( f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}" ) - print(f"Peak Memory (KB) : {results['peak_memory']}\n") + print(f"Peak Memory (KB) : {results['peak_memory']}") + # Calculate and print memory bandwidth if operator provides memory traffic + if results.get("memory_bandwidth_gb_s") is not None: + print( + f"Memory Bandwidth (GB/s) : {results['memory_bandwidth_gb_s']:.2f}" + ) + print() def _perf_result_to_dict(self, results, test_case): """This function is the parallel of _print_perf_result, which instead of @@ -711,6 +717,17 @@ def run(self): result_dict = dict() result_dict["reported_run_time_us"] = [r[0] for r in results] result_dict["peak_memory"] = results[0][1] + + # Calculate memory bandwidth if operator provides memory traffic + memory_traffic_bytes = test_case.op_bench.get_memory_traffic_bytes() + if memory_traffic_bytes is not None: + execution_time_s = result_dict["reported_run_time_us"][0] / 1e6 + result_dict["memory_bandwidth_gb_s"] = ( + memory_traffic_bytes / execution_time_s / 1e9 + ) + else: + result_dict["memory_bandwidth_gb_s"] = None + self._print_perf_result(results=result_dict, test_case=test_case) # output results to csv diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index fa022417da451..ad9180a53da04 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -118,6 +118,44 @@ def test_name(self, **kargs): name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "") return name + def get_memory_traffic_bytes(self): + """Return the number of bytes read/written by this operator. + + Override this method in subclasses for operations with non-standard memory patterns + (e.g., matmul which is compute-bound rather than memory-bound). + + The framework will use this value along with execution time to compute + and report memory bandwidth in GB/s. + + Default implementation assumes a pointwise-like operation: + - Reads: all input tensors + - Writes: output tensor (estimated as size of largest input) + + This default works correctly for: + - Element-wise operations (add, mul, relu, etc.) + - Activations (gelu, sigmoid, etc.) + - Optimizers (SGD, Adam, etc.) + - Reductions (sum, mean, etc. - may underestimate writes) + + Returns: + int or None: Total bytes transferred (reads + writes), or None if not applicable + """ + if not hasattr(self, "inputs") or not self.inputs: + return None + + input_tensors = [v for v in self.inputs.values() if isinstance(v, torch.Tensor)] + if not input_tensors: + return None + + # Calculate total bytes read from all inputs + bytes_read = sum(t.numel() * t.element_size() for t in input_tensors) + + # Estimate output size as the largest input (common for pointwise ops) + largest_input = max(input_tensors, key=lambda t: t.numel()) + bytes_written = largest_input.numel() * largest_input.element_size() + + return bytes_read + bytes_written + class PyTorchOperatorTestCase: """This class includes all the information needed to benchmark an operator. diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py index 3e94a9cd7f3dc..5d9cd14ec696e 100644 --- a/benchmarks/operator_benchmark/pt/addmm_test.py +++ b/benchmarks/operator_benchmark/pt/addmm_test.py @@ -52,6 +52,26 @@ def init(self, M, N, K, device, dtype): def forward(self, input_one, mat1, mat2): return torch.addmm(input_one, mat1, mat2) + def get_memory_traffic_bytes(self): + """Override for addmm: input + (mat1 @ mat2) -> (M, K) + addmm computes: input_one (M, K) + mat1 (M, N) @ mat2 (N, K) + Memory traffic: read(M*K + M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + mat1 = self.inputs["mat1"] + mat2 = self.inputs["mat2"] + + M, K = input_one.shape + M_check, N = mat1.shape + N_check, K_check = mat2.shape + assert M == M_check and K == K_check and N == N_check, ( + "Matrix dimensions must match" + ) + + bytes_per_element = input_one.element_size() + total_elements = M * K + M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark) op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark) @@ -84,6 +104,26 @@ def init(self, B, M, N, K, device, dtype): def forward(self, input_one, batch1, batch2): return torch.addbmm(input_one, batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for addbmm: input + sum(batch1[i] @ batch2[i]) -> (M, N) + addbmm computes: input_one (M, N) + sum over batch of batch1 (B, M, K) @ batch2 (B, K, N) + Memory traffic: read(M*N + B*M*K + B*K*N) + write(M*N) + """ + input_one = self.inputs["input_one"] + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + + M, N = input_one.shape + B, M_check, K = batch1.shape + B_check, K_check, N_check = batch2.shape + assert M == M_check and N == N_check and B == B_check and K == K_check, ( + "Dimensions must match" + ) + + bytes_per_element = input_one.element_size() + total_elements = M * N + B * M * K + B * K * N + M * N + return total_elements * bytes_per_element + addbmm_long_configs = op_bench.cross_product_configs( B=[8, 32], diff --git a/benchmarks/operator_benchmark/pt/bmm_test.py b/benchmarks/operator_benchmark/pt/bmm_test.py index f867f6ac09f8d..234bff20fb499 100644 --- a/benchmarks/operator_benchmark/pt/bmm_test.py +++ b/benchmarks/operator_benchmark/pt/bmm_test.py @@ -52,6 +52,20 @@ def init(self, B, M, N, K, device, dtype, op_func): def forward(self, batch1, batch2): return self.op_func(batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for bmm: (B, M, N) @ (B, N, K) -> (B, M, K) + Memory traffic: read(B*M*N + B*N*K) + write(B*M*K) + """ + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + B, M, N = batch1.shape + B_check, N_check, K = batch2.shape + assert B == B_check and N == N_check, "Batch dimensions must match for bmm" + + bytes_per_element = batch1.element_size() + total_elements = B * (M * N + N * K + M * K) + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( batched_binary_ops, @@ -90,6 +104,25 @@ def init(self, B, M, N, K, device, dtype, op_func): def forward(self, input_, batch1, batch2): return self.op_func(input_, batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for baddbmm: input + (batch1 @ batch2) -> (B, M, K) + Memory traffic: read(B*M*K + B*M*N + B*N*K) + write(B*M*K) + """ + input_ = self.inputs["input_"] + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + B, M, K = input_.shape + B_check1, M_check, N = batch1.shape + B_check2, N_check, K_check = batch2.shape + assert B == B_check1 == B_check2, "Batch dimensions must match" + assert M == M_check and K == K_check and N == N_check, ( + "Matrix dimensions must match" + ) + + bytes_per_element = input_.element_size() + total_elements = B * (M * K + M * N + N * K + M * K) + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( batched_ternary_ops, diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index f972db3f1693e..c7fa3f5c2381d 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -22,6 +22,24 @@ def init(self, IC, OC, kernel, stride, N, L, device): def forward(self, input): return self.conv1d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv1d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv1d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × L + input_elements = input_tensor.numel() + # Weight: OC × IC × kernel + weight_elements = self.conv1d.weight.numel() + # Output: N × OC × L_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): @@ -34,6 +52,24 @@ def init(self, IC, OC, kernel, stride, N, L, device): def forward(self, input): return self.convtranspose1d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose1d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose1d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × L + input_elements = input_tensor.numel() + # Weight: IC × OC × kernel + weight_elements = self.convtranspose1d.weight.numel() + # Output: N × OC × L_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test( configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark @@ -67,6 +103,24 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): def forward(self, input): return self.conv2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv2d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: OC × (IC/G) × kernel × kernel + weight_elements = self.conv2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): @@ -79,6 +133,24 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): def forward(self, input): return self.convtranspose2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose2d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: IC × (OC/G) × kernel × kernel + weight_elements = self.convtranspose2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class Conv2dPointwiseBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, stride, N, H, W, G, pad, device): @@ -92,6 +164,24 @@ def init(self, IC, OC, stride, N, H, W, G, pad, device): def forward(self, input): return self.conv2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv2dPointwise: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: OC × (IC/G) × 1 × 1 + weight_elements = self.conv2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test( configs.conv_2d_configs_short + configs.conv_2d_configs_long, Conv2dBenchmark @@ -134,6 +224,24 @@ def init(self, IC, OC, kernel, stride, N, D, H, W, device): def forward(self, input): return self.conv3d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv3d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv3d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × D × H × W + input_elements = input_tensor.numel() + # Weight: OC × IC × kernel × kernel × kernel + weight_elements = self.conv3d.weight.numel() + # Output: N × OC × D_out × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): @@ -146,6 +254,24 @@ def init(self, IC, OC, kernel, stride, N, D, H, W, device): def forward(self, input): return self.convtranspose3d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose3d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose3d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × D × H × W + input_elements = input_tensor.numel() + # Weight: IC × OC × kernel × kernel × kernel + weight_elements = self.convtranspose3d.weight.numel() + # Output: N × OC × D_out × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark) op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index d0c58aa16e8f3..4bde44d60f381 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -59,6 +59,22 @@ def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float): def forward(self, input_one, input_two): return torch.matmul(input_one, input_two) + def get_memory_traffic_bytes(self): + """Override for matmul: (M, N) @ (N, K) -> (M, K) + Memory traffic: read(M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + input_two = self.inputs["input_two"] + + # input_one and input_two are properly shaped for matmul regardless of transpose + M, N = input_one.shape + N_check, K = input_two.shape + assert N == N_check, "Matrix dimensions must match for matmul" + + bytes_per_element = input_one.element_size() + total_elements = M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark) diff --git a/benchmarks/operator_benchmark/pt/mm_test.py b/benchmarks/operator_benchmark/pt/mm_test.py index f9e0743ba7125..07e0b596960fe 100644 --- a/benchmarks/operator_benchmark/pt/mm_test.py +++ b/benchmarks/operator_benchmark/pt/mm_test.py @@ -47,6 +47,20 @@ def init(self, M, N, K, device, dtype, op_func): def forward(self, input_one, input_two): return self.op_func(input_one, input_two) + def get_memory_traffic_bytes(self): + """Override for matmul: (M, N) @ (N, K) -> (M, K) + Memory traffic: read(M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + input_two = self.inputs["input_two"] + M, N = input_one.shape + N_check, K = input_two.shape + assert N == N_check, "Matrix dimensions must match for matmul" + + bytes_per_element = input_one.element_size() + total_elements = M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark diff --git a/benchmarks/operator_benchmark/pt/optimizer_test.py b/benchmarks/operator_benchmark/pt/optimizer_test.py new file mode 100644 index 0000000000000..53bab9773def4 --- /dev/null +++ b/benchmarks/operator_benchmark/pt/optimizer_test.py @@ -0,0 +1,65 @@ +import operator_benchmark as op_bench + +import torch +import torch.optim as optim + + +"""Microbenchmarks for optimizer operators.""" + + +optimizer_list = op_bench.op_list( + attr_names=["op_name", "op_func"], + attrs=[ + ["adamw", optim.AdamW], + ["adam", optim.Adam], + ["sgd", optim.SGD], + ["rmsprop", optim.RMSprop], + ["adagrad", optim.Adagrad], + ], +) + +optimizer_configs_long = op_bench.cross_product_configs( + shape=[(100000,), (1000000,), (10000000,)], + device=["cuda"], + tags=["long"], +) + + +class OptimizerBenchmark(op_bench.TorchBenchmarkBase): + def init(self, op_func, device, shape): + self.op_func = op_func + self.param = torch.randn( + shape, device=device, requires_grad=True, dtype=torch.float32 + ) + self.param.grad = torch.randn(shape, device=device) + + kwargs = {"momentum": 0.9} if op_func == optim.SGD else {} + self.optimizer = op_func([self.param], lr=0.001, **kwargs) + + self.inputs = {"dummy": self.param} + + def forward(self, dummy): + self.optimizer.step() + return self.param + + def get_memory_traffic_bytes(self): + # Memory traffic calculation for bandwidth + total_elements = self.param.numel() + bytes_per_element = self.param.element_size() + # SGD w/ momentum: read(param, grad, momentum) + write(param, momentum) = 5x + # Adam/AdamW: read(param, grad, exp_avg, exp_avg_sq) + write(param, exp_avg, exp_avg_sq) = 7x + # Adagrad/RMSprop: read(param, grad, state) + write(param, state) = 5x + if self.op_func in (optim.Adam, optim.AdamW): + memory_multiplier = 7 + else: + memory_multiplier = 5 + return total_elements * bytes_per_element * memory_multiplier + + +op_bench.generate_pt_tests_from_op_list( + optimizer_list, optimizer_configs_long, OptimizerBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() From 9a38bb8622e5427e28b655df89b81293f63ecaac Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Mon, 24 Nov 2025 00:03:19 +0000 Subject: [PATCH 222/230] [CUDA] Fix truncated error messages in cudaMallocAsync Allocator (#168369) Previously, these error messages would get truncated when they were hit on device 0 because device is a "char" (actually, an int8_t) and therefore '0' is interpreted as the null byte to terminate a string. Essentially, it is the same issue as https://github.com/pytorch/pytorch/pull/123984. There's something strange in the TORCH_CHECK_WITH macro that is causing this. I don't feel like figuring out those obscure macro details right now, though. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168369 Approved by: https://github.com/eqy --- c10/cuda/CUDAMallocAsyncAllocator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 674eb00035c50..48bf95bb976d8 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -320,7 +320,7 @@ void mallocAsync( TORCH_INTERNAL_ASSERT( 0 <= device && device < device_count, "Invalid device index ", - device, + static_cast(device), ": did you call init?"); // If stream is a null (default) stream, @@ -370,7 +370,7 @@ void mallocAsync( OutOfMemoryError, false, "Allocation on device ", - device, + static_cast(device), " would exceed allowed memory. (out of memory)", "\nCurrently allocated : ", format_size(pytorch_used_bytes[device]), From dbe61249eac16d2a658da2896c5c35dfe1d398bf Mon Sep 17 00:00:00 2001 From: Luyao Ren Date: Mon, 24 Nov 2025 03:02:06 +0000 Subject: [PATCH 223/230] [tutorial] typo fix, update torch.compiler_cudagraph_trees.md (#167713) Fix a API typo in the cuda graph tutorial. The API given in cuda graph tutorial is wrong. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167713 Approved by: https://github.com/jerryzh168 --- docs/source/torch.compiler_cudagraph_trees.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/torch.compiler_cudagraph_trees.md b/docs/source/torch.compiler_cudagraph_trees.md index eb137625ea746..f220086f82dc2 100644 --- a/docs/source/torch.compiler_cudagraph_trees.md +++ b/docs/source/torch.compiler_cudagraph_trees.md @@ -319,7 +319,7 @@ Trees, we don’t want to add unintended dependencies between iterations that wo to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration on each invocation for torch.compile, and in training we do the same so long as there is not a pending backward that has not been invoked. If those heuristics are wrong, you can mark the start of a new iteration with -[torch.compiler.mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone +[torch.compiler.cudagraph_mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone tensors of a prior iteration (outside of torch.compile) before you begin the next run. ### Comparisons From 7833690a37737e9284d1b87d5e0d8db23e9167ac Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Mon, 24 Nov 2025 04:26:20 +0000 Subject: [PATCH 224/230] Removed deprecated `split_cat_fx_passes` (#167738) ## Remove deprecated `split_cat_fx_passes` First of a couple of small PRs that remove deprecated and unused code. Remove the deprecated `split_cat_fx_passes` configuration variable from inductor config and clean up associated test patches. ### Changes - Remove `split_cat_fx_passes` from `torch/_inductor/config.py` - Remove `@patch.object(config, "split_cat_fx_passes", False)` decorators from tests in `test/inductor/test_perf.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167738 Approved by: https://github.com/jansel, https://github.com/eellison, https://github.com/cyyever --- test/inductor/test_perf.py | 2 -- torch/_inductor/config.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 8a48bee86ba4e..5ad37c10b2c1a 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -278,7 +278,6 @@ def f(a, b): inp = (T(10, 10), T(10, 10)) self.assertExpectedInline(count_numel(f, *inp), """680""") - @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", @@ -300,7 +299,6 @@ def f(*inputs): inp = (T(10, 10) for _ in range(16)) self.assertExpectedInline(count_numel(f, *inp), """6400""") - @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 7048990692da0..45fa2d74acaed 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -303,9 +303,6 @@ def prologue_fusion_enabled() -> bool: ] ] = None -# Deprecated -split_cat_fx_passes = True - # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. efficient_conv_bn_eval_fx_passes = False From c91c92f3c76466731554170f806a8047ac0a5643 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 24 Nov 2025 05:18:24 +0000 Subject: [PATCH 225/230] Replace thrust::tie with structure binding (#168943) This PR removes unnecessary thrust::tie. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168943 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/group_norm_kernel.cu | 4 +--- aten/src/ATen/native/cuda/layer_norm_kernel.cu | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index cd48c16a32eb9..e16a09521754b 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -63,9 +63,7 @@ __global__ void RowwiseMomentsCUDAKernel( val_shared_ptr); } if (threadIdx.x == 0) { - T_ACC m1; - T_ACC m2; - thrust::tie(m2, m1) = welford_op.project(val); + auto [m2, m1] = welford_op.project(val); mean[i] = m1; rstd[i] = c10::cuda::compat::rsqrt(m2 + static_cast(eps)); } diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 730a7ea910961..1667265aef97a 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -86,9 +86,7 @@ __global__ void RowwiseMomentsCUDAKernel( val_shared_ptr); if (threadIdx.x == 0) { - T_ACC m1; - T_ACC m2; - thrust::tie(m2, m1) = welford_op.project(val); + auto [m2, m1] = welford_op.project(val); if constexpr (!rms_norm){ mean[i] = m1; rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); From 265397e178dab071294f6a10e35226fe333b2983 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Mon, 24 Nov 2025 06:35:42 +0000 Subject: [PATCH 226/230] Remove unnecessary uses of thrust::tuple (#168936) This PR removes unnecessary uses of thrust::tuple before moving to CCCL. Pull Request resolved: https://github.com/pytorch/pytorch/pull/168936 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/ActivationEluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationGeluKernel.cu | 1 - aten/src/ATen/native/cuda/ActivationGluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardswishKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationMishKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSiluKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu | 2 -- aten/src/ATen/native/cuda/ActivationThresholdKernel.cu | 2 -- aten/src/ATen/native/cuda/Loops.cuh | 2 +- aten/src/ATen/native/cuda/group_norm_kernel.cu | 1 - aten/src/ATen/native/cuda/layer_norm_kernel.cu | 3 +-- 17 files changed, 2 insertions(+), 31 deletions(-) diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu index 5ad1f806f9ba5..9fc29aa5539b5 100644 --- a/aten/src/ATen/native/cuda/ActivationEluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu index cd5a0ae85e61c..87781c44e3348 100644 --- a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -5,7 +5,6 @@ #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu index e28a6d61ea152..8a782a129c9fb 100644 --- a/aten/src/ATen/native/cuda/ActivationGluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu index 2a0be3f5d27bf..f0968b957aa6d 100644 --- a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu index fcacef37ceaf0..813a8c07ccfac 100644 --- a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu index 1642d0909f7f0..651cdef82543b 100644 --- a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu index a18072f7a27bc..85aa7ccd22a9e 100644 --- a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu index 72130739898fe..340a6f97d00de 100644 --- a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu index 9a1d672428b48..2175920917852 100644 --- a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu index 0db0e96bb180a..25ba9810e37cf 100644 --- a/aten/src/ATen/native/cuda/ActivationMishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu index f7ddfd8502a18..ebdfe245b6166 100644 --- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu index 64ffc21123707..65f4f3679f862 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu index 0c2dc63dbcf45..712c86e0e5216 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu index 2d1cb4a47d7d8..430f9cbfa78bb 100644 --- a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index a80c51fa6a9cb..e739d7d2ecee2 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -282,7 +282,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); - constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_outputs = std::tuple_size::value; constexpr int num_inputs = traits::arity; constexpr int ntensors = num_outputs + num_inputs; diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index e16a09521754b..d144a9954ed33 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -3,7 +3,6 @@ #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 1667265aef97a..937008f1e83bd 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1,10 +1,9 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include -#include - #include #include #include From f1c49c9372b9af1063b98a70c8528969e68ba04d Mon Sep 17 00:00:00 2001 From: Kushagra Rastogi Date: Mon, 24 Nov 2025 09:05:47 +0000 Subject: [PATCH 227/230] Checking if the input is finite before calculation in lowering of pow func (#167723) Fixes #167197 The inductor backend is trying to convert the float infinity value to an integer in pow lowering (possibly for indexing, iteration counts, or type conversions). Python/C++ cannot convert float('inf') to an integer, causing the overflow error Pull Request resolved: https://github.com/pytorch/pytorch/pull/167723 Approved by: https://github.com/shunting314 --- test/inductor/test_torchinductor.py | 35 +++++++++++++++++++++++++++++ torch/_inductor/lowering.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b1cea5eac77d7..3bc1dba12acd8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5528,6 +5528,32 @@ def fn(x): check_lowp=not is_halide_backend(self.device), # misaligned addr fp16 ) + def test_lp_pool1d_with_inf_norm(self): + # https://github.com/pytorch/pytorch/issues/167197 + # Test that LPPool1d works with infinity norm (should behave like max pooling) + def fn(x): + return torch.nn.functional.lp_pool1d( + x, norm_type=float("inf"), kernel_size=2, stride=2 + ) + + self.common( + fn, + (torch.randn(3, 4, 8),), + ) + + def test_lp_pool2d_with_inf_norm(self): + # https://github.com/pytorch/pytorch/issues/167197 + # Test that LPPool2d works with infinity norm (should behave like max pooling) + def fn(x): + return torch.nn.functional.lp_pool2d( + x, norm_type=float("inf"), kernel_size=2, stride=2 + ) + + self.common( + fn, + (torch.randn(3, 4, 8, 8),), + ) + @tf32_on_and_off(0.006) @skip_if_gpu_halide # slow def test_alexnet_prefix(self): @@ -6307,6 +6333,15 @@ def fn(x): x = torch.randn([16, 16], device=self.device) self.assertEqual(cfn(x), fn(x)) + def test_pow_infinite(self): + def fn(a, b): + return torch.pow(a, b) + + opt = torch.compile(fn, backend="inductor") + a = torch.randn((3, 4, 8), device=self.device) + b = float("inf") + self.assertTrue(same(opt(a, b), fn(a, b))) + def test_glu(self): def fn(x): return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 090265d208c92..d9890f1958edd 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -6361,7 +6361,7 @@ def pow_native(a, b): @register_lowering(aten.pow, broadcast=True) def pow(a, b): - if isinstance(b, float) and b == int(b): + if isinstance(b, float) and math.isfinite(b) and b == int(b): return pow(a, int(b)) elif isinstance(b, float) and b == 0.5: return sqrt(a) From 1aaedbcfdd5c0615a882eebba3f51b2409162142 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 21 Nov 2025 16:22:11 -0800 Subject: [PATCH 228/230] [dynamo][hops] Add xfail tests for side effects (#168394) Pull Request resolved: https://github.com/pytorch/pytorch/pull/168394 Approved by: https://github.com/jansel --- test/dynamo/test_autograd_function.py | 38 +++++++++++++++++++ test/higher_order_ops/test_invoke_subgraph.py | 29 ++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 326a1e627b3f4..f2a99dd18e2b1 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -2,6 +2,7 @@ # flake8: noqa: B950 import copy import math +import unittest from dataclasses import dataclass import torch @@ -1543,6 +1544,43 @@ def f(x, y): loss.backward() self.assertEqual(x + y, z) + @unittest.expectedFailure + def test_nonlocal_list_mutation_in_autograd_function(self): + """Test that nonlocal list mutation in autograd.Function forward is handled correctly.""" + + class SimpleAutogradFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, z): + # Simple computation + o = torch.matmul(x, x) @ x + out = x.sin() + # Mutate the nonlocal list + z.append(out) + return torch.cos(torch.sin(o)), torch.sin(x) + + @staticmethod + def backward(ctx, grad_output1, grad_output2): + # Simple backward + return grad_output1 + grad_output2, None + + def fn(x): + z = [] + + outs = SimpleAutogradFunc.apply(x, z) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] + + return out1 + out2, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 329f20f81cdb5..00cb0e7b8b21a 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -910,6 +910,35 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): """, ) + @unittest.expectedFailure + def test_nonlocal_list_mutation_hidden(self): + """Test that nonlocal list mutation inside nested_compile_region is handled correctly.""" + + @nested_compile_region + def gn(x, z): + o = torch.matmul(x, x) @ x + out = x.sin() + z.append(out) + return torch.cos(torch.sin(o)), torch.sin(x) + + def fn(x): + z = [] + + outs = gn(x, z) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] + + return out1 + out2, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + @inductor_config.patch("fx_graph_cache", False) def test_view_to_reshape(self): @nested_compile_region From 5ff187d18b4056e0f696bed54232257133a162ef Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 24 Nov 2025 14:14:21 +0000 Subject: [PATCH 229/230] [Intel GPU] Update Intel Triton commit pin (#166436) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166436 Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/jansel --- .ci/docker/ci_commit_pins/triton-xpu.txt | 2 +- test/inductor/test_analysis.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index b03606f6defc1..b660ae2dd3eaa 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -1b0418a9a454b2b93ab8d71f40e59d2297157fae +aa01f5c2cd4db2b7bfa53ea98a1a8dfbd6d77c92 diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py index a6946cb7b31a7..147760fe4df67 100644 --- a/test/inductor/test_analysis.py +++ b/test/inductor/test_analysis.py @@ -25,6 +25,7 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, + skipIfXpu, TEST_WITH_SLOW, TestCase, ) @@ -402,6 +403,9 @@ def verify_triton(comp): (not torch.xpu.is_available()) and (not SM80OrLater), "Requires XPU or CUDA SM80", ) + @skipIfXpu( + msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5491" + ) @skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU") @dtypes(torch.float, torch.float16) @parametrize( From 654c5fba3e6cd8d5d171828b0aaef868d93d42ad Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 24 Nov 2025 14:24:17 +0000 Subject: [PATCH 230/230] Revert "bucketing compile time improve (#168122)" This reverts commit 1328a02d2eb26ee67346048ee242327ea90d6315. Reverted https://github.com/pytorch/pytorch/pull/168122 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/168122#issuecomment-3571037473)) --- .../test_overlap_bucketing_unit.py | 48 ++++++++++++-- .../fx_passes/overlap_manual_scheduling.py | 1 + .../fx_passes/overlap_preserving_bucketer.py | 65 ++++--------------- .../_inductor/fx_passes/overlap_scheduling.py | 1 + 4 files changed, 56 insertions(+), 59 deletions(-) diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index 8dd937a31c240..c0c4c31cc1a81 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -93,6 +93,28 @@ def build_collective_info(graph, hiding_annotations): return collective_info +def compute_ancestors(graph): + """Compute ancestor sets for all nodes in the graph.""" + node_ancestors = {} + + for node in graph.nodes: + ancestors = OrderedSet() + stack = list(node.all_input_nodes) + visited = set() + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + ancestors.add(current) + stack.extend(current.all_input_nodes) + + node_ancestors[node] = ancestors + + return node_ancestors + + @requires_accelerator_dist_backend() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @instantiate_parametrized_tests @@ -168,8 +190,9 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -180,6 +203,7 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -254,8 +278,9 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -266,6 +291,7 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -355,8 +381,9 @@ def func(a, b, c): if final_mm_hidden: hiding_annotations[rs] = mm2 - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing logic to find buckets (without applying them, which would require process groups) @@ -367,6 +394,7 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) @@ -439,6 +467,7 @@ def func(a, b): # Build collective info collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -449,6 +478,7 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -520,8 +550,9 @@ def func(a, b): ag2: mm2, # mm2 hides ag2 } - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing with multidtype mode @@ -532,6 +563,7 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, bucket_mode="custom_ops_multidtype", ) @@ -603,8 +635,9 @@ def func(a, b): ag2: [mm2, mm3], # ag2 is hidden by mm2 and mm3 } - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Verify hiding_nodes are correctly set @@ -623,6 +656,7 @@ def func(a, b): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) bucketer.bucket_collectives() @@ -695,8 +729,9 @@ def func(a, b, c): ag3: mm, } - # Build collective info and scheduled + # Build collective info and ancestors collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) scheduled = OrderedSet(traced.graph.nodes) # Run bucketing @@ -707,6 +742,7 @@ def func(a, b, c): bucketer = OverlapPreservingBucketer( traced.graph, collective_info, + node_ancestors, scheduled, ) bucketer.bucket_collectives() diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index d2c8b588d2011..c8af70dc598f4 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -182,6 +182,7 @@ def __init__( self.bucketer = ManualOverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, + node_ancestors=self.node_ancestors, node_users=self.node_users, scheduled=OrderedSet(self.graph.nodes), ) diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index eb239a3a219a6..b5ef930b8fa8f 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -1,4 +1,3 @@ -import itertools import logging from collections import defaultdict from dataclasses import dataclass @@ -131,6 +130,7 @@ def __init__( self, graph: fx.Graph, collective_info: dict[fx.Node, CollectiveInfo], + node_ancestors: dict[fx.Node, OrderedSet[fx.Node]], scheduled: OrderedSet[fx.Node], max_bucket_memory_gb: float = 1.0, max_coll_distance: int = 1000, @@ -139,45 +139,18 @@ def __init__( ): self.graph = graph self.collective_info = collective_info + self.node_ancestors = node_ancestors self.scheduled = scheduled self.max_bucket_memory_gb = max_bucket_memory_gb self.node_idx = {n: i for i, n in enumerate(scheduled)} + self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) self.max_coll_distance = max_coll_distance self.insert_overlap_deps = insert_overlap_deps self.bucket_mode = bucket_mode self.node_to_event: dict[fx.Node, PGEvent] = {} - - # Compute ancestors including original graph edges and hiding interval dependencies - self.node_ancestors = self._compute_node_ancestors() - self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) - - # Build timelines and add constraints to aug_graph self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() - self._add_hiding_interval_constraints() - - def _compute_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: - """ - Compute ancestor sets for all nodes including: - 1. Original graph edges - 2. Hiding interval deps: collective_start -> hiding_node -> wait - """ - augmented_inputs: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) - for start, info in self.collective_info.items(): - if info.is_exposed: - continue - for hiding_node in info.hiding_nodes: - augmented_inputs[hiding_node].add(start) - augmented_inputs[info.wait_node].add(hiding_node) - node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) - for node in self.scheduled: - for input_node in itertools.chain( - augmented_inputs[node], node.all_input_nodes - ): - node_ancestors[node].add(input_node) - node_ancestors[node] |= node_ancestors[input_node] - - return node_ancestors + self._add_hiding_interval_constraints() def build_timelines(self) -> dict[str, Optional[PGEvent]]: "Construct each process groups ordered series of event" @@ -364,30 +337,21 @@ def _find_buckets( ) processed.add(start_node) - # Greedy optimization: stop after consecutive failures - consecutive_failures = 0 - max_consecutive_failures = 20 - # Check candidates in sorted order, break when beyond max distance for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: + if candidate in processed: + continue + candidate_bytes = self.collective_info[candidate].size_bytes # proxy on memory use, if we see a too large bucket, # dont look for another, later bucket if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: break - if candidate in processed: - continue - if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) bucket_info.total_bytes += candidate_bytes processed.add(candidate) - consecutive_failures = 0 # Reset on success - else: - consecutive_failures += 1 - if consecutive_failures >= max_consecutive_failures: - break if len(bucket_info.collectives) > 1: buckets.append(bucket_info) @@ -692,28 +656,23 @@ def _has_ancestor_conflicts( candidate_wait = candidate_info.wait_node for coll in bucket_info.collectives: - if ( - coll in self.node_ancestors[candidate] - or candidate in self.node_ancestors[coll] - ): + # Check if collectives are ancestors of each other + if self._ancestor_dep(coll, candidate): return True # Check if waits are ancestors of each other coll_wait = self.collective_info[coll].wait_node - if ( - coll_wait in self.node_ancestors[candidate_wait] - or candidate_wait in self.node_ancestors[coll_wait] - ): + if self._ancestor_dep(candidate_wait, coll_wait): return True # Check if existing hiding node conflicts with candidate wait for old_hiding_node in self.collective_info[coll].hiding_nodes: - if candidate_wait in self.node_ancestors[old_hiding_node]: + if self._ancestor_dep(old_hiding_node, candidate_wait): return True # Check if candidate hiding node conflicts with existing wait for new_hiding_node in candidate_info.hiding_nodes: - if coll_wait in self.node_ancestors[new_hiding_node]: + if self._ancestor_dep(new_hiding_node, coll_wait): return True return False diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 14555c84b43ce..436a3ab0db81b 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -1125,6 +1125,7 @@ def _bucket_collectives(self) -> None: bucketer = OverlapPreservingBucketer( graph=self.graph, collective_info=self.collective_info, + node_ancestors=self.node_ancestors, scheduled=self.scheduled, max_bucket_memory_gb=2.0, # Could make this configurable max_coll_distance=self.max_node_distance,