Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 0 additions & 58 deletions aten/src/ATen/core/boxing/KernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <atomic>
#include <memory>
#include <type_traits>

namespace c10 {
Expand All @@ -19,9 +17,6 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;

class KernelToken;
class SafeKernelFunction;

template <typename T>
using has_symint = std::disjunction<
std::is_same<c10::SymInt, T>,
Expand Down Expand Up @@ -95,12 +90,6 @@ class TORCH_API KernelFunction final {
BoxedKernel::BoxedKernelFunction_withDispatchKeys;

KernelFunction();
~KernelFunction();

KernelFunction(const KernelFunction& other);
KernelFunction& operator=(const KernelFunction& other);

KernelFunction(KernelFunction&&) noexcept = default;

// Fast path for dispatch to allow not touching the boxed kernel in
// the common case where unboxed is available.
Expand Down Expand Up @@ -273,9 +262,6 @@ class TORCH_API KernelFunction final {
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

// Register a token to be invalidated when this KernelFunction is destroyed
void registerToken(std::weak_ptr<KernelToken> token) const;

private:
explicit KernelFunction(
std::unique_ptr<OperatorKernel> functor,
Expand All @@ -290,50 +276,6 @@ class TORCH_API KernelFunction final {
BoxedKernel boxed_kernel_func_;
void* unboxed_kernel_func_;
void* sym_unboxed_kernel_func_;
// List of tokens that need to be invalidated when this KernelFunction is
// destroyed (lazy allocation to save memory when empty)
mutable std::unique_ptr<std::vector<std::weak_ptr<KernelToken>>> tokens_;
};

// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
// destroyed
class KernelToken {
public:
bool isValid() const;
void invalidate();

private:
std::atomic<bool> invalid_{false};
};

class SafeKernelFunction {
public:
SafeKernelFunction(
const KernelFunction* kernel,
std::string debug,
std::shared_ptr<OperatorHandle> opHandle);

// Safe callBoxed - checks token validity first
void callBoxed(
const OperatorHandle& opHandle,
DispatchKeySet dispatchKeySet,
Stack* stack) const;

// Get debug information
const std::string& debug() const {
return debug_;
}

// Get the OpHandle that lives on this SafeKernelFunction
const OperatorHandle& opHandle() const {
return *opHandle_;
}

private:
KernelFunction kernel_;
std::shared_ptr<KernelToken> token_;
std::string debug_;
std::shared_ptr<OperatorHandle> opHandle_;
};

} // namespace c10
Expand Down
72 changes: 0 additions & 72 deletions aten/src/ATen/core/boxing/KernelFunction_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,6 @@ inline KernelFunction::KernelFunction()
unboxed_kernel_func_(nullptr),
sym_unboxed_kernel_func_(nullptr) {}

inline KernelFunction::~KernelFunction() {
if (tokens_) {
for (auto& weak_token : *tokens_) {
if (auto token = weak_token.lock()) {
token->invalidate();
}
}
}
}

inline KernelFunction::KernelFunction(const KernelFunction& other)
: boxed_kernel_func_(other.boxed_kernel_func_),
unboxed_kernel_func_(other.unboxed_kernel_func_),
sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) {
// tokens_ is intentionally not copied as we only care about invalidating
// tokens if the original KernelFunction is destroyed
}

inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) {
if (this != &other) {
boxed_kernel_func_ = other.boxed_kernel_func_;
unboxed_kernel_func_ = other.unboxed_kernel_func_;
sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_;

// tokens_ is intentionally not copied as we only care about invalidating
// tokens if the original KernelFunction is destroyed
}
return *this;
}

inline KernelFunction::KernelFunction(
std::unique_ptr<OperatorKernel> functor,
InternalBoxedKernelFunction* boxed_kernel_func,
Expand Down Expand Up @@ -187,14 +157,6 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
std::forward<Args>(args)...);
}

inline void KernelFunction::registerToken(
std::weak_ptr<KernelToken> token) const {
if (!tokens_) {
tokens_ = std::make_unique<std::vector<std::weak_ptr<KernelToken>>>();
}
tokens_->push_back(std::move(token));
}

inline KernelFunction KernelFunction::makeFromBoxedKernel(
BoxedKernel boxed_fn) {
return KernelFunction(
Expand Down Expand Up @@ -355,38 +317,4 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
std::forward<Lambda>(lambda)));
}

inline bool KernelToken::isValid() const {
return !invalid_.load(std::memory_order_acquire);
}

inline void KernelToken::invalidate() {
invalid_.store(true, std::memory_order_release);
}

inline SafeKernelFunction::SafeKernelFunction(
const KernelFunction* kernel,
std::string debug,
std::shared_ptr<OperatorHandle> opHandle)
: kernel_(kernel ? *kernel : KernelFunction()),
token_(std::make_shared<KernelToken>()),
debug_(std::move(debug)),
opHandle_(std::move(opHandle)) {
// Register the token with the original kernel so it gets invalidated when the
// kernel is destroyed
if (kernel) {
kernel->registerToken(token_);
}
}

inline void SafeKernelFunction::callBoxed(
const OperatorHandle& opHandle,
DispatchKeySet dispatchKeySet,
Stack* stack) const {
TORCH_CHECK(
token_ && token_->isValid(),
"SafeKernelFunction has been invalidated ",
debug_);
kernel_.callBoxed(opHandle, dispatchKeySet, stack);
}

} // namespace c10
4 changes: 0 additions & 4 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,6 @@ class TORCH_API OperatorHandle {
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
}

SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
return operatorDef_->op.getComputedKernelForDispatchKey(k);
}

std::string dumpComputedTable() const {
return operatorDef_->op.dumpComputedTable();
}
Expand Down
36 changes: 0 additions & 36 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,42 +315,6 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
return nullptr;
}

SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
DispatchKey k) const {
TORCH_CHECK(
!isAliasDispatchKey(k),
"Alias keys do not have runtime kernel registrations.");
const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
TORCH_CHECK(
dispatchTable_[dispatch_ix].isValid(),
"no kernel for ",
k,
" for ",
name_);

// Get the KernelFunction object from kernels_ to pass to SafeKernelFunction

// The KernelFunction object in dispatchTable_ is a copy of the KernelFunction
// in the AnnotatedKernel in kernels_. A KernelFunction is only truly
// deregistered when the kernel is removed from kernels_. However, the
// KernelFunction in dispatchTable_ might be removed before it is deregistered
// (when a newer kernel is registered). Therefore, here we want to return a
// SafeKernelFunction that is backed by the original KernelFunction in
// kernels_, so that we only invalidate it when the kernel is deregistered.
auto [annotatedKernel, _] =
computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);

// Use findSchemaOrThrow to get OpHandle for the OperatorEntry
auto& dispatcher = c10::Dispatcher::singleton();
auto opHandle = dispatcher.findSchemaOrThrow(
name_.name.c_str(), name_.overload_name.c_str());

return SafeKernelFunction(
&annotatedKernel.kernel,
annotatedKernel.debug,
std::make_shared<OperatorHandle>(opHandle));
}

const std::vector<at::Tag>& OperatorEntry::getTags() const {
#if defined C10_MOBILE
TORCH_CHECK(false, "tags are not saved for Mobile");
Expand Down
2 changes: 0 additions & 2 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ class TORCH_API OperatorEntry final {
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
// Returns true if the "computed table" has an entry for a particular key.
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
// Returns a KernelFunction corresponding to the kernel in dispatchTable
SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;
void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
Expand Down
1 change: 0 additions & 1 deletion docs/source/library.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: infer_schema
.. autoclass:: torch._library.custom_ops.CustomOpDef
:members: set_kernel_enabled
.. autofunction:: get_kernel
```

## Low-level APIs
Expand Down
Loading