[PyTorch] Save a single add instruction in the dispatcher (pytorch#52543)

Summary:
Pull Request resolved: pytorch#52543

This saves one (1) add instruction. New code comments should
explain exactly why. In short, we store a direct pointer in
`OperatorHandle` in addition to the `std::list<OperatorDef>::iterator`
because converting the latter to the former requires an add instruction.

It is not clear to me whether this is a particularly great tradeoff,
but I spent (more) time on it (than I expected), so here it is for
review.
ghstack-source-id: 122147199

Test Plan:
Inspect the assembly for at::empty in the benchmark code -- confirm the add
instruction disappeared.

Compare empty benchmark performance to baseline with perf stat.

Baseline:
```
          5,077.43 msec task-clock                #    1.000 CPUs utilized            ( +-  0.25% )
               405      context-switches          #    0.080 K/sec                    ( +-  1.37% )
                 3      cpu-migrations            #    0.001 K/sec                    ( +- 18.22% )
            12,259      page-faults               #    0.002 M/sec                    ( +-  0.10% )
    10,089,754,343      cycles                    #    1.987 GHz                      ( +-  0.25% )  (50.04%)
    29,516,000,227      instructions              #    2.93  insn per cycle           ( +-  0.04% )  (50.08%)
     5,662,629,032      branches                  # 1115.256 M/sec                    ( +-  0.02% )  (50.08%)
         1,955,729      branch-misses             #    0.03% of all branches          ( +-  0.88% )  (50.04%)

            5.0796 +- 0.0128 seconds time elapsed  ( +-  0.25% )
```

After:
```
          5,017.77 msec task-clock                #    1.001 CPUs utilized            ( +-  0.19% )
               400      context-switches          #    0.080 K/sec                    ( +-  3.09% )
                 4      cpu-migrations            #    0.001 K/sec                    ( +- 46.91% )
            12,240      page-faults               #    0.002 M/sec                    ( +-  0.37% )
     9,960,189,535      cycles                    #    1.985 GHz                      ( +-  0.19% )  (50.02%)
    29,467,149,773      instructions              #    2.96  insn per cycle           ( +-  0.11% )  (50.03%)
     5,661,074,219      branches                  # 1128.206 M/sec                    ( +-  0.02% )  (50.07%)
         2,032,712      branch-misses             #    0.04% of all branches          ( +-  1.35% )  (50.07%)

            5.0151 +- 0.0101 seconds time elapsed  ( +-  0.20% )
```

1.2% cycles win, outside the noise.
0.16% instruction count win, barely outside the noise.

I am surprised at the size of the cycles win.

Reviewed By: bhosmer

Differential Revision: D26564192

fbshipit-source-id: 71f731ba54ec1cb407673db691eaf77a257de4a9
swolchok authored and aocsa committed Mar 5, 2021
1 parent 6dbad1e commit b066add
Showing 2 changed files with 51 additions and 36 deletions.
42 changes: 22 additions & 20 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
@@ -134,15 +134,15 @@ RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::strin
OperatorName op_name = schema.operator_name();
auto op = findOrRegisterName_(op_name);

TORCH_CHECK(op.operatorIterator_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
" Each overload's schema should only be registered with a single call to def().",
" Duplicate registration: ", debug, ". Original registration: ", op.operatorIterator_->op.debug());
op.operatorIterator_->op.registerSchema(std::move(schema), std::move(debug));
" Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
listeners_->callOnOperatorRegistered(op);

// NB: do not increment the counts until AFTER error checking
++op.operatorIterator_->def_count;
++op.operatorIterator_->def_and_impl_count;
++op.operatorDef_->def_count;
++op.operatorDef_->def_and_impl_count;

return RegistrationHandleRAII([this, op, op_name] {
deregisterDef_(op, op_name);
@@ -156,17 +156,17 @@ void Dispatcher::deregisterDef_(const OperatorHandle& op, const OperatorName& op
TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);

// reduce def_count and actually deregister if no references left
TORCH_INTERNAL_ASSERT(op.operatorIterator_->def_count > 0);
TORCH_INTERNAL_ASSERT(op.operatorIterator_->def_and_impl_count > 0);
TORCH_INTERNAL_ASSERT(op.operatorDef_->def_count > 0);
TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);

--op.operatorIterator_->def_count;
--op.operatorIterator_->def_and_impl_count;
if (0 == op.operatorIterator_->def_count) {
--op.operatorDef_->def_count;
--op.operatorDef_->def_and_impl_count;
if (0 == op.operatorDef_->def_count) {
// note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
// TODO: check that listeners are not relying on prepareForDeregistration()
// invariant
listeners_->callOnOperatorDeregistered(op);
op.operatorIterator_->op.deregisterSchema();
op.operatorDef_->op.deregisterSchema();
}

cleanup(op, op_name);
@@ -184,7 +184,7 @@ RegistrationHandleRAII Dispatcher::registerImpl(

auto op = findOrRegisterName_(op_name);

auto handle = op.operatorIterator_->op.registerKernel(
auto handle = op.operatorDef_->op.registerKernel(
*this,
dispatch_key,
std::move(kernel),
@@ -193,7 +193,7 @@ RegistrationHandleRAII Dispatcher::registerImpl(
std::move(debug)
);

++op.operatorIterator_->def_and_impl_count;
++op.operatorDef_->def_and_impl_count;

return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
deregisterImpl_(op, op_name, dispatch_key, handle);
@@ -203,20 +203,20 @@ RegistrationHandleRAII Dispatcher::registerImpl(
void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, c10::optional<DispatchKey> dispatch_key, std::list<impl::AnnotatedKernel>::iterator handle) {
std::lock_guard<std::mutex> lock(mutex_);

op.operatorIterator_->op.deregisterKernel_(*this, dispatch_key, handle);
op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);

TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);

TORCH_INTERNAL_ASSERT(op.operatorIterator_->def_and_impl_count > 0);
--op.operatorIterator_->def_and_impl_count;
TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
--op.operatorDef_->def_and_impl_count;

cleanup(op, op_name);
}

RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
std::lock_guard<std::mutex> lock(mutex_);
auto op = findOrRegisterName_(op_name);
++op.operatorIterator_->def_and_impl_count;
++op.operatorDef_->def_and_impl_count;
return RegistrationHandleRAII(
[this, op, op_name] { deregisterName_(op, op_name); });
}
@@ -226,14 +226,16 @@ void Dispatcher::deregisterName_(
const OperatorName& op_name) {
std::lock_guard<std::mutex> lock(mutex_);
TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
TORCH_INTERNAL_ASSERT(op.operatorIterator_->def_and_impl_count > 0);
--op.operatorIterator_->def_and_impl_count;
TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
--op.operatorDef_->def_and_impl_count;
cleanup(op, op_name);
}

// Test if the operator entry is completely dead, and if so remove it completely
void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
if (0 == op.operatorIterator_->def_and_impl_count) {
if (0 == op.operatorDef_->def_and_impl_count) {
// NOTE: Making this call fast is the only reason OperatorHandle
// stores operatorIterator_!
operators_.erase(op.operatorIterator_);
operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
operatorLookupTable.erase(op_name);
45 changes: 29 additions & 16 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -291,31 +291,31 @@ class TORCH_API OperatorHandle {
OperatorHandle& operator=(const OperatorHandle&) = default;

const OperatorName& operator_name() const {
return operatorIterator_->op.operator_name();
return operatorDef_->op.operator_name();
}

bool hasSchema() const {
return operatorIterator_->op.hasSchema();
return operatorDef_->op.hasSchema();
}

const FunctionSchema& schema() const {
return operatorIterator_->op.schema();
return operatorDef_->op.schema();
}

const std::string& debug() const {
return operatorIterator_->op.debug();
return operatorDef_->op.debug();
}

std::string dumpState() const {
return operatorIterator_->op.dumpState();
return operatorDef_->op.dumpState();
}

std::string dumpComputedTable() const {
return operatorIterator_->op.dumpComputedTable();
return operatorDef_->op.dumpComputedTable();
}

void checkInvariants() const {
return operatorIterator_->op.checkInvariants();
return operatorDef_->op.checkInvariants();
}

template<class FuncType>
@@ -327,7 +327,7 @@ class TORCH_API OperatorHandle {
// in core library this won't happen, because all the static registrations
// will be done by the time a typed() handle is acquired.
#if !defined C10_MOBILE
operatorIterator_->op.assertSignatureIsCorrect<FuncType>();
operatorDef_->op.assertSignatureIsCorrect<FuncType>();
#endif
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
@@ -342,10 +342,23 @@ class TORCH_API OperatorHandle {

private:
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorIterator_(std::move(operatorIterator)) {}
: operatorDef_(&*operatorIterator), operatorIterator_(operatorIterator) {}
friend class Dispatcher;
template<class> friend class TypedOperatorHandle;

// Storing a direct pointer to the OperatorDef even though we
// already have the iterator saves an instruction in the critical
// dispatch path. The iterator is effectively a
// pointer-to-std::list-node, and (at least in libstdc++'s
// implementation) the element is at an offset 16 bytes from that,
// because the prev/next pointers come first in the list node
// struct. So, an add instruction would be necessary to convert from the
// iterator to an OperatorDef*.
Dispatcher::OperatorDef* operatorDef_;

// We need to store this iterator in order to make
// Dispatcher::cleanup() fast -- it runs a lot on program
// termination (and presumably library unloading).
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};

@@ -377,7 +390,7 @@ class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {

private:
explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: OperatorHandle(std::move(operatorIterator)) {}
: OperatorHandle(operatorIterator) {}
friend class OperatorHandle;
};

@@ -396,7 +409,7 @@ inline Return Dispatcher::callWithDispatchKeySlowPath(const TypedOperatorHandle<
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
if (op.operatorIterator_->op.isObserved()) {
if (op.operatorDef_->op.isObserved()) {
if (guard.needsInputs()) {
runRecordFunction(guard, op, dispatchKey, impl::boxArgs(args...));
} else {
@@ -411,13 +424,13 @@ template<class Return, class... Args>
template<class Return, class... Args>
C10_ALWAYS_INLINE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
auto dispatchKeySet = op.operatorIterator_->op.dispatchKeyExtractor()
auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
.template getDispatchKeySetUnboxed<Args...>(
DispatchKeySet::FULL,
args...
);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKeySet.highestPriorityTypeId());
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
@@ -437,13 +450,13 @@ template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// do not use RecordFunction on redispatch
const KernelFunction& kernel = op.operatorIterator_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorIterator_->op;
const auto& entry = op.operatorDef_->op;
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
@@ -471,7 +484,7 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const

inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorIterator_->op;
const auto& entry = op.operatorDef_->op;
const auto& kernel = entry.lookup(dispatchKeySet.highestPriorityTypeId());
return kernel.callBoxed(op, dispatchKeySet, stack);
}
