diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h index 4300217235b8..06bcc5d4f49b 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.h +++ b/aten/src/ATen/core/boxing/KernelFunction.h @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include namespace c10 { @@ -19,9 +17,6 @@ class OperatorHandle; struct OperatorKernel; class KernelFunction; -class KernelToken; -class SafeKernelFunction; - template using has_symint = std::disjunction< std::is_same, @@ -95,12 +90,6 @@ class TORCH_API KernelFunction final { BoxedKernel::BoxedKernelFunction_withDispatchKeys; KernelFunction(); - ~KernelFunction(); - - KernelFunction(const KernelFunction& other); - KernelFunction& operator=(const KernelFunction& other); - - KernelFunction(KernelFunction&&) noexcept = default; // Fast path for dispatch to allow not touching the boxed kernel in // the common case where unboxed is available. @@ -273,9 +262,6 @@ class TORCH_API KernelFunction final { // For testing internal invariants only bool _equalsBoxedAndUnboxed(const KernelFunction&) const; - // Register a token to be invalidated when this KernelFunction is destroyed - void registerToken(std::weak_ptr token) const; - private: explicit KernelFunction( std::unique_ptr functor, @@ -290,50 +276,6 @@ class TORCH_API KernelFunction final { BoxedKernel boxed_kernel_func_; void* unboxed_kernel_func_; void* sym_unboxed_kernel_func_; - // List of tokens that need to be invalidated when this KernelFunction is - // destroyed (lazy allocation to save memory when empty) - mutable std::unique_ptr>> tokens_; -}; - -// Token held by SafeKernelFunction that gets invalidated when KernelFunction is -// destroyed -class KernelToken { - public: - bool isValid() const; - void invalidate(); - - private: - std::atomic invalid_{false}; -}; - -class SafeKernelFunction { - public: - SafeKernelFunction( - const KernelFunction* kernel, - std::string debug, - std::shared_ptr opHandle); - - // Safe callBoxed - checks token validity first - void callBoxed( - const OperatorHandle& opHandle, - DispatchKeySet dispatchKeySet, - Stack* stack) const; - - // Get debug information - const std::string& debug() const { - return debug_; - } - - // Get the OpHandle that lives on this SafeKernelFunction - const OperatorHandle& opHandle() const { - return *opHandle_; - } - - private: - KernelFunction kernel_; - std::shared_ptr token_; - std::string debug_; - std::shared_ptr opHandle_; }; } // namespace c10 diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index 672309ec19a2..a89a0e8952b6 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -24,36 +24,6 @@ inline KernelFunction::KernelFunction() unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {} -inline KernelFunction::~KernelFunction() { - if (tokens_) { - for (auto& weak_token : *tokens_) { - if (auto token = weak_token.lock()) { - token->invalidate(); - } - } - } -} - -inline KernelFunction::KernelFunction(const KernelFunction& other) - : boxed_kernel_func_(other.boxed_kernel_func_), - unboxed_kernel_func_(other.unboxed_kernel_func_), - sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) { - // tokens_ is intentionally not copied as we only care about invalidating - // tokens if the original KernelFunction is destroyed -} - -inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) { - if (this != &other) { - boxed_kernel_func_ = other.boxed_kernel_func_; - unboxed_kernel_func_ = other.unboxed_kernel_func_; - sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_; - - // tokens_ is intentionally not copied as we only care about invalidating - // tokens if the original KernelFunction is destroyed - } - return *this; -} - inline KernelFunction::KernelFunction( std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, @@ -187,14 +157,6 @@ C10_ALWAYS_INLINE Return KernelFunction::call( std::forward(args)...); } -inline void KernelFunction::registerToken( - std::weak_ptr token) const { - if (!tokens_) { - tokens_ = std::make_unique>>(); - } - tokens_->push_back(std::move(token)); -} - inline KernelFunction KernelFunction::makeFromBoxedKernel( BoxedKernel boxed_fn) { return KernelFunction( @@ -355,38 +317,4 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { std::forward(lambda))); } -inline bool KernelToken::isValid() const { - return !invalid_.load(std::memory_order_acquire); -} - -inline void KernelToken::invalidate() { - invalid_.store(true, std::memory_order_release); -} - -inline SafeKernelFunction::SafeKernelFunction( - const KernelFunction* kernel, - std::string debug, - std::shared_ptr opHandle) - : kernel_(kernel ? *kernel : KernelFunction()), - token_(std::make_shared()), - debug_(std::move(debug)), - opHandle_(std::move(opHandle)) { - // Register the token with the original kernel so it gets invalidated when the - // kernel is destroyed - if (kernel) { - kernel->registerToken(token_); - } -} - -inline void SafeKernelFunction::callBoxed( - const OperatorHandle& opHandle, - DispatchKeySet dispatchKeySet, - Stack* stack) const { - TORCH_CHECK( - token_ && token_->isValid(), - "SafeKernelFunction has been invalidated ", - debug_); - kernel_.callBoxed(opHandle, dispatchKeySet, stack); -} - } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 43eb0028c70f..bc043df6a93e 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -487,10 +487,6 @@ class TORCH_API OperatorHandle { return operatorDef_->op.hasComputedKernelForDispatchKey(k); } - SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const { - return operatorDef_->op.getComputedKernelForDispatchKey(k); - } - std::string dumpComputedTable() const { return operatorDef_->op.dumpComputedTable(); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index c172e9b9c609..b4063fb720be 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -315,42 +315,6 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat return nullptr; } -SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey( - DispatchKey k) const { - TORCH_CHECK( - !isAliasDispatchKey(k), - "Alias keys do not have runtime kernel registrations."); - const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k); - TORCH_CHECK( - dispatchTable_[dispatch_ix].isValid(), - "no kernel for ", - k, - " for ", - name_); - - // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction - - // The KernelFunction object in dispatchTable_ is a copy of the KernelFunction - // in the AnnotatedKernel in kernels_. A KernelFunction is only truly - // deregistered when the kernel is removed from kernels_. However, the - // KernelFunction in dispatchTable_ might be removed before it is deregistered - // (when a newer kernel is registered). Therefore, here we want to return a - // SafeKernelFunction that is backed by the original KernelFunction in - // kernels_, so that we only invalidate it when the kernel is deregistered. - auto [annotatedKernel, _] = - computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); - - // Use findSchemaOrThrow to get OpHandle for the OperatorEntry - auto& dispatcher = c10::Dispatcher::singleton(); - auto opHandle = dispatcher.findSchemaOrThrow( - name_.name.c_str(), name_.overload_name.c_str()); - - return SafeKernelFunction( - &annotatedKernel.kernel, - annotatedKernel.debug, - std::make_shared(opHandle)); -} - const std::vector& OperatorEntry::getTags() const { #if defined C10_MOBILE TORCH_CHECK(false, "tags are not saved for Mobile"); diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 59b54ce1d9d3..83200ff9c94f 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -217,8 +217,6 @@ class TORCH_API OperatorEntry final { const KernelFunction& kernelForDispatchKey(DispatchKey k) const; // Returns true if the "computed table" has an entry for a particular key. bool hasComputedKernelForDispatchKey(DispatchKey k) const; - // Returns a KernelFunction corresponding to the kernel in dispatchTable - SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const; // Returns all the operator tags added at the time of registration const std::vector& getTags() const; void setReportErrorCallback_(std::unique_ptr callback); diff --git a/docs/source/library.md b/docs/source/library.md index b31ca95d5b6a..9d706e2e1080 100644 --- a/docs/source/library.md +++ b/docs/source/library.md @@ -56,7 +56,6 @@ via PyTorch's C++ operator registration APIs). .. autofunction:: infer_schema .. autoclass:: torch._library.custom_ops.CustomOpDef :members: set_kernel_enabled -.. autofunction:: get_kernel ``` ## Low-level APIs diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 491648494f6f..5a494f548742 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -11,7 +11,6 @@ import tempfile import typing import unittest -from functools import partial from pathlib import Path from typing import * # noqa: F403 @@ -4157,148 +4156,6 @@ def test_any_output_is_alias_to_input_or_output(self): ) ) - def test_library_get_kernel(self): - """Test registering a custom kernel, using it, then deregistering and verifying error.""" - - # Register a dummy kernel for arange to the CPU key that returns a tensor of ones - def dummy_arange_cpu( - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - size = max(0, int(end - start)) - return torch.ones(size, dtype=dtype, device=device) - - with torch.library._scoped_library("aten", "IMPL") as lib: - lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True) - - kernel = torch.library.get_kernel("aten::arange.start", "CPU") - dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU) - result = kernel.call_boxed(dispatch_keys, 0, 5) - - self.assertEqual(result, torch.ones(5)) - - # The kernel should now be invalidated after exiting the scoped_library context - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): - kernel.call_boxed(dispatch_keys, 0, 5) - - def test_library_get_kernel_with_conditional_dispatch(self): - """Test registering a custom kernel with conditional dispatch logic.""" - - def conditional_arange_cpu1( - original_kernel, - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - # If end is even, use the original kernel, otherwise return ones tensor - if end % 2 == 0: - op_handle = torch.ops.aten.arange.start._handle - return original_kernel.call_boxed( - dispatch_keys, - start, - end, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - ) - else: - size = max(0, int(end - start)) - return torch.ones(size, dtype=dtype, device=device) - - def conditional_arange_cpu2( - original_kernel, - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - # If start is even, use the original kernel, otherwise return twos tensor - if start % 2 == 0: - op_handle = torch.ops.aten.arange.start._handle - return original_kernel.call_boxed( - dispatch_keys, - start, - end, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - ) - else: - size = max(0, int(end - start)) - return torch.empty(size, dtype=dtype, device=device).fill_(2) - - original_kernel = torch.library.get_kernel("aten::arange.start", "CPU") - expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6) - expected_result3, expected_result4, expected_result5 = ( - torch.ones(5), - torch.arange(0, 6), - torch.ones(5).fill_(2), - ) - - with torch.library._scoped_library("aten", "IMPL") as lib2: - with torch.library._scoped_library("aten", "IMPL") as lib1: - lib1.impl( - "arange.start", - partial(conditional_arange_cpu1, original_kernel), - "CPU", - with_keyset=True, - ) - - self.assertEqual(torch.arange(0, 5), expected_result1) - self.assertEqual(torch.arange(0, 6), expected_result2) - new_original_kernel = torch.library.get_kernel( - "aten::arange.start", "CPU" - ) - lib2.impl( - "arange.start", - partial(conditional_arange_cpu2, new_original_kernel), - "CPU", - allow_override=True, - with_keyset=True, - ) - - self.assertEqual(torch.arange(0, 5), expected_result3) - self.assertEqual(torch.arange(0, 6), expected_result4) - self.assertEqual(torch.arange(1, 6), expected_result5) - - # The kernel should now be invalidated after destroying lib1 - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): - torch.arange(0, 5) - - # Should still work after destroying lib1 - self.assertEqual(torch.arange(1, 6), expected_result5) - - def test_library_get_kernel_invalid(self): - """Test that get_kernel raises an error when no kernel is available.""" - with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib: - lib.define("cpu_only_op(Tensor x) -> Tensor") - lib.impl("cpu_only_op", lambda x: x * 2, "CPU") - - cpu_kernel = torch.library.get_kernel( - "test_invalid_kernel::cpu_only_op", "CPU" - ) - self.assertIsNotNone(cpu_kernel) - - # CUDA should fail at the isValid() check since no CUDA kernel exists - with self.assertRaisesRegex( - RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op" - ): - torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA") - class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 80437aa1d833..5fe3f7e178b7 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1695,11 +1695,6 @@ class _DispatchModule: _after_ADInplaceOrView_keyset: DispatchKeySet _after_autograd_keyset: DispatchKeySet -class _SafeKernelFunction: - def call_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ... - @property - def op_handle(self) -> _DispatchOperatorHandle: ... - def _dispatch_library( kind: str, name: str, @@ -1737,10 +1732,6 @@ def _dispatch_has_computed_kernel_for_dispatch_key( name: str, dispatch: _dispatchkey, ) -> _bool: ... -def _dispatch_get_computed_kernel_for_dispatch_key( - name: str, - dispatch: _dispatchkey, -) -> _SafeKernelFunction: ... def _dispatch_find_dangling_impls() -> list[str]: ... def _dispatch_get_all_op_names() -> list[str]: ... def _dispatch_tls_set_dispatch_key_excluded( diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 9d6eb35c7178..07fa4ea5e1dd 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -620,43 +620,6 @@ void initDispatchBindings(PyObject* module) { c10::parseDispatchKey(dispatch)); }); - // Bind SafeKernelFunction class - py::class_(m, "_SafeKernelFunction") - .def( - "call_boxed", - [](const c10::SafeKernelFunction& self, - c10::DispatchKeySet keyset, - py::args args, - const py::kwargs& kwargs) { - const auto& op = self.opHandle(); - auto stack = torch::jit::createStackForSchema( - op.schema(), - std::move(args), - kwargs, - /*self=*/std::nullopt); - self.callBoxed(op, keyset, &stack); - return torch::jit::createPyObjectForStack(std::move(stack)); - }) - .def( - "__repr__", - [](const c10::SafeKernelFunction& self) { - return "SafeKernelFunction(debug='" + self.debug() + "')"; - }) - .def_property_readonly( - "op_handle", [](const c10::SafeKernelFunction& self) -> py::object { - return py::cast(self.opHandle()); - }); - - m.def( - "_dispatch_get_computed_kernel_for_dispatch_key", - [](const char* name, - c10::DispatchKey dispatch) -> c10::SafeKernelFunction { - auto op = - c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); - TORCH_CHECK(op, "operator ", name, " does not exist"); - return op->getComputedKernelForDispatchKey(dispatch); - }); - m.def("_dispatch_find_dangling_impls", []() -> std::vector { auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); diff --git a/torch/library.py b/torch/library.py index d36c18158148..372037f09dbe 100644 --- a/torch/library.py +++ b/torch/library.py @@ -45,7 +45,6 @@ "register_torch_dispatch", "register_vmap", "get_ctx", - "get_kernel", "custom_op", "triton_op", "wrap_triton", @@ -1476,80 +1475,6 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": return torch._library.fake_impl.global_ctx_getter() -def get_kernel( - op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey] -) -> torch._C._SafeKernelFunction: - """Returns the computed kernel for a given operator and dispatch key. - - This function retrieves the kernel that would be executed for a given - operator and dispatch key combination. The returned SafeKernelFunction - can be used to call the kernel in a boxed fashion. The intended use - case for this function is to retrieve the original kernel for a given - dispatch key and then register another kernel to the same dispatch key - that calls into the original kernel for certain cases. - - Args: - op: Operator name (along with the overload) or OpOverload object - Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef. - dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for. - Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value. - - Returns: - torch._C._SafeKernelFunction: A safe kernel function that can be used to - call the kernel. - - Raises: - RuntimeError: If the operator does not exist. - - Example: - >>> # Get the CPU kernel for torch.add - >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU") - >>> - >>> # You can also use DispatchKey enum - >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU) - >>> - >>> # Or use an OpOverload directly - >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU") - >>> - >>> # Example: Using get_kernel in a custom op with conditional dispatch - >>> # Get the original kernel for torch.sin - >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU") - >>> - >>> # If input has negative values, use original sin, otherwise return zeros - >>> def conditional_sin_impl(dispatch_keys, x): - >>> if (x < 0).any(): - >>> return original_sin_kernel.call_boxed(dispatch_keys, x) - >>> else: - >>> return torch.zeros_like(x) - >>> - >>> lib = torch.library.Library("aten", "IMPL") - >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet - >>> which needs to be the first argument to ``kernel.call_boxed`` - >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True) - >>> - >>> # Test the conditional behavior - >>> x_positive = torch.tensor([1.0, 2.0]) - >>> x_mixed = torch.tensor([-1.0, 2.0]) - >>> torch.sin(x_positive) - tensor([0., 0.]) - >>> torch.sin(x_mixed) - tensor([-0.8415, 0.9093]) - """ - if not isinstance(op, (str, torch._ops.OpOverload)): - raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}") - - if isinstance(op, torch._ops.OpOverload): - op = op._name - - if isinstance(dispatch_key, str): - try: - dispatch_key = torch._C.DispatchKey.__members__[dispatch_key] - except KeyError: - raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None - - return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key) - - _OPCHECK_DEFAULT_UTILS = ( "test_schema", "test_autograd_registration",