Commit abc54f9

Revert "Revert "[functorch] Refactor life handle storage (pytorch#90317)"" (pytorch#90856)

Adds the fix for -Wsign-compare.

See the original PR (pytorch#90317) for the commit message.
Pull Request resolved: pytorch#90856
Approved by: https://github.com/samdow
zou3519 authored and pytorchmergebot committed Dec 15, 2022
1 parent 81f351a commit abc54f9
Showing 7 changed files with 87 additions and 57 deletions.
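For context, a minimal sketch (not the PyTorch code itself; the function and variable names are illustrative) of the kind of -Wsign-compare fix this commit carries: comparing an unsigned container size against a signed int64_t level triggers the warning unless one side is explicitly cast, which is what the added (int64_t) cast in DynamicLayer.cpp does.

```cpp
#include <cstdint>
#include <vector>

bool level_is_on_stack(const std::vector<int>& stack, int64_t level) {
  // stack.size() returns size_t (unsigned). Casting it to int64_t keeps the
  // comparison signed-vs-signed, which silences -Wsign-compare.
  return static_cast<int64_t>(stack.size()) >= level && level >= 1;
}
```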
19 changes: 13 additions & 6 deletions aten/src/ATen/functorch/ADInterpreters.cpp
@@ -77,11 +77,12 @@ static void autogradBasedTransformProcess(
static void autogradBasedTransformSendToNext(
const c10::OperatorHandle& op,
torch::jit::Stack* stack,
int64_t current_level,
const Interpreter& interpreter,
TransformType transform_type,
optional<bool> prev_grad_mode,
optional<bool> prev_fwd_grad_mode,
bool grad_special_case) {
auto current_level = interpreter.level();
if (transform_type == TransformType::Grad) {
TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
}
@@ -110,7 +111,7 @@ static void autogradBasedTransformSendToNext(
// if (c10::show_dispatch_trace_enabled()) {
// std::cout << "wrap " << current_level << std::endl;
// }
return makeTensorWrapper(tensor, current_level, is_immutable);
return makeTensorWrapper(tensor, interpreter, is_immutable);
};

// TODO: we only need to do the following (marked with !) on in-place functions
@@ -208,8 +209,11 @@ void GradInterpreterPtr::sendToNextInterpreterImpl(
torch::jit::Stack* stack,
bool grad_special_case) {
autogradBasedTransformSendToNext(
op, stack, level(),
TransformType::Grad, prevGradMode(), nullopt, grad_special_case);
op, stack, *base_,
TransformType::Grad,
prevGradMode(),
nullopt,
grad_special_case);
}

void JvpInterpreterPtr::processImpl(
@@ -223,8 +227,11 @@ void JvpInterpreterPtr::sendToNextInterpreterImpl(
torch::jit::Stack* stack,
bool grad_special_case) {
autogradBasedTransformSendToNext(
op, stack, level(),
TransformType::Jvp, nullopt, prevFwdGradMode(), grad_special_case);
op, stack, *base_,
TransformType::Jvp,
nullopt,
prevFwdGradMode(),
grad_special_case);
}

}} // namespace at::functorch
45 changes: 14 additions & 31 deletions aten/src/ATen/functorch/DynamicLayer.cpp
@@ -74,15 +74,6 @@ RandomnessType DynamicLayer::randomness() const {
return VmapInterpreterPtr(&interpreter_).randomness();
}

// Maps level to life handle, see NOTE: [Life handles and lexically scoped transforms]
// for details
using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
DynmetaData kDynMetaDataSingleton;

static DynmetaData& getGlobalDynmetaData() {
return kDynMetaDataSingleton;
}

// functorch stores some TLS. Inside the TLS is the stack of transforms.
// Unfortunately, since functorch isn't a part of libtorch, we have
// a level of indirection. FuncTorchTLSBase is the interface that lives in libtorch,
@@ -166,10 +157,16 @@ static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
return getRawFunctorchTLS()->dynamicLayerStack;
}

std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
auto it = getGlobalDynmetaData().find(level);
TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
return it->second;
const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(
(int64_t)dynamicLayerStack.size() >= level && level >= 1,
"If you're trying to construct a tensor with the current level (",
level,
") then the interpreter for that level must be on the DynamicLayerStack ");

auto& dynamic_layer = dynamicLayerStack[level - 1];
return dynamic_layer.interpreter().is_alive_ptr();
}

optional<DynamicLayer> maybeCurrentDynamicLayer() {
@@ -209,11 +206,6 @@ void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
dynamicLayerStackAccessor() = stack;
}

bool areTransformsActive() {
const auto& data = getGlobalDynmetaData();
return !data.empty();
}

DynamicLayer popDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
@@ -262,32 +254,23 @@ int64_t initAndPushDynamicLayer(
DynamicLayer new_layer(transform_type, layerId, batch_size, randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views);
pushDynamicLayer(std::move(new_layer));

auto& data = getGlobalDynmetaData();
// NB: this function should be called while holding the GIL to avoid races
new_layer.interpreter().set_is_alive(true);

TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
if (transform_type == TransformType::Grad) {
TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
}
if (transform_type == TransformType::Jvp) {
TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
}
data[layerId] = std::make_shared<bool>(true);
return layerId;
}

DynamicLayer popDynamicLayerAndDeleteMetadata() {
auto result = popDynamicLayer();
auto level = result.layerId();

// TODO: is this lock safe? No one else should be writing to the same bucket
auto& data = getGlobalDynmetaData();
auto it = data.find(level);
if (it == data.end()) {
return result;
}
// invalidate the thing
*(it->second) = false;
data.erase(level);
// NB: this function should be called while holding the GIL to avoid races
result.interpreter().set_is_alive(false);
return result;
}

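To make the storage change in this file concrete, here is a simplified sketch (hypothetical types, not the functorch ones) of the new lookup: instead of consulting a global map keyed by level, the life handle is read off the interpreter sitting at position level - 1 on the dynamic layer stack.

```cpp
// Sketch under the assumption that levels are 1-indexed, as the assert in
// getLifeHandleForLevel above suggests. FakeInterpreter is a placeholder.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct FakeInterpreter {
  int64_t level;
  std::shared_ptr<bool> is_alive = std::make_shared<bool>(false);
};

const std::shared_ptr<bool>& lifeHandleForLevel(
    const std::vector<FakeInterpreter>& stack, int64_t level) {
  // Level N corresponds to stack[N - 1]; callers must ensure the level is
  // still on the stack, which is what the TORCH_INTERNAL_ASSERT above checks.
  return stack.at(static_cast<std::size_t>(level - 1)).is_alive;
}
```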
8 changes: 1 addition & 7 deletions aten/src/ATen/functorch/DynamicLayer.h
@@ -80,10 +80,6 @@ TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);

// NB: Not lock safe, you should only call this from Python where the GIL will
// prevent race conditions.
TORCH_API bool areTransformsActive();

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle" that is a boolean that tells us if the
@@ -92,9 +88,7 @@ TORCH_API bool areTransformsActive();
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes from the scope of the transform, then somehow
// it must know it escaped; it can tell by querying the life handle.
//
// NB: not lock safe. TODO: does it need a lock?
TORCH_API std::shared_ptr<bool> getLifeHandleForLevel(int64_t level);
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);

// Returns if an operator is in-place. An operator is inplace if:
// 1. The first argument is a Tensor and it is being written to
16 changes: 15 additions & 1 deletion aten/src/ATen/functorch/Interpreter.h
@@ -155,17 +155,31 @@ struct Interpreter {
return *savedLocalDispatchKeySet_;
}

// An Interpreter is alive if we are currently inside the ongoing transform
// for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
// corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
bool is_alive() const {
return *is_alive_;
}
const std::shared_ptr<bool>& is_alive_ptr() const {
return is_alive_;
}
void set_is_alive(bool alive) {
*is_alive_ = alive;
}

// Please don't use this
explicit Interpreter() = default;

private:
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
type_(type), level_(level), meta_(meta) {}
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(meta) {}

// fields
TransformType type_;
int64_t level_;
optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
std::shared_ptr<bool> is_alive_;
InterpreterMeta meta_;
};

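A small illustration of the lifecycle the new comment describes, using plain std types rather than the real Interpreter (whose parameterized constructor is private): the handle is flipped to true when the transform is pushed, back to false when it is popped, and anything that captured the shared_ptr observes the change.

```cpp
// Standalone sketch of the is_alive life handle semantics; the push/pop
// comments refer to initAndPushDynamicLayer / popDynamicLayerAndDeleteMetadata.
#include <cassert>
#include <memory>

int main() {
  auto is_alive = std::make_shared<bool>(false);  // Interpreter construction
  *is_alive = true;                               // push: set_is_alive(true)

  std::shared_ptr<bool> handle = is_alive;        // e.g. a TensorWrapper copies it
  assert(*handle);                                // inside the transform: alive

  *is_alive = false;                              // pop: set_is_alive(false)
  assert(!*handle);                               // an escaped holder sees "dead"
  return 0;
}
```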
43 changes: 32 additions & 11 deletions aten/src/ATen/functorch/TensorWrapper.cpp
@@ -58,19 +58,22 @@ void dumpTensorCout(const Tensor& tensor) {
std::cout << std::endl;
}

c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, bool should_be_alive) {
c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
if (should_be_alive) {
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel(level));
} else {
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool>(false));
}
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
}

Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable) {
// use makeTensorWrapper instead to avoid potential footguns:
// unsafeMakeTensorWrapper doesn't check that level and life_handle
// refer to the same interpreter
static Tensor unsafeMakeTensorWrapper(
const Tensor& tensor,
int64_t level,
bool is_immutable,
const std::shared_ptr<bool>& life_handle) {
auto wrapped = maybeGetTensorWrapper(tensor);
if (wrapped) {
TORCH_INTERNAL_ASSERT(wrapped->level() < level);
@@ -80,20 +83,38 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable)
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
auto life_handle = getLifeHandleForLevel(level);
auto result = at::detail::make_tensor<TensorWrapper>(key_set, tensor, level, std::move(life_handle), is_immutable);
auto result = at::detail::make_tensor<TensorWrapper>(
key_set, tensor, level, life_handle, is_immutable);
TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::FuncTorchGradWrapper));
return result;
}

Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable) {
auto life_handle = getLifeHandleForLevel(level);
return unsafeMakeTensorWrapper(
tensor,
level,
is_immutable,
getLifeHandleForLevel(level));
}

Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable) {
return unsafeMakeTensorWrapper(
tensor,
interpreter.level(),
is_immutable,
interpreter.is_alive_ptr());
}


bool TensorWrapper::is_alive() const {
return *is_alive_;
}

c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
dest_impl->set_version_counter(version_counter);

// TODO: is this even right?
@@ -104,7 +125,7 @@ c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive());
auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
dest_impl->set_version_counter(version_counter);

// TODO: is this even right?
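One way to read the shallow_copy_and_detach change above: the copy now receives the shared life handle (is_alive_) rather than a boolean snapshot of is_alive(), so a detached copy goes dead together with the original when the transform exits. A hedged sketch of the difference, with placeholder names:

```cpp
// Illustrative only; Wrapper is a stand-in, not functorch's TensorWrapper.
#include <cassert>
#include <memory>

struct Wrapper {
  std::shared_ptr<bool> life;  // shared handle: tracks the transform's state
  bool snapshot;               // plain bool: frozen at creation time
};

int main() {
  auto alive = std::make_shared<bool>(true);
  Wrapper copy{alive, *alive};  // "detach" while the transform is live

  *alive = false;               // the transform ends
  assert(!*copy.life);          // shared handle reflects the change
  assert(copy.snapshot);        // the bool snapshot is now stale
  return 0;
}
```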
12 changes: 12 additions & 0 deletions aten/src/ATen/functorch/TensorWrapper.h
@@ -8,6 +8,7 @@

#include <ATen/functorch/Macros.h>
#include <ATen/Tensor.h>
#include <ATen/functorch/Interpreter.h>

namespace at {
namespace functorch {
@@ -89,7 +90,18 @@ struct TORCH_API TensorWrapper : public c10::TensorImpl {
std::shared_ptr<bool> is_alive_;
};

// There are two variants of makeTensorWrapper: one that accepts a level
// and one that accepts an Interpreter.
//
// The one that accepts a level tries to automatically get the life handle from the
// interpreter on the DynamicLayerStack.
// It needs to be used with caution: if the interpreter is not on the
// DynamicLayerStack, then we won't be able to find the life handle.
//
// In practice this isn't a problem: when we're constructing TensorWrapper in
// Python, the corresponding interpreter is on the stack.
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable=false);
TORCH_API Tensor makeTensorWrapper(const Tensor& tensor, const Interpreter& interpreter, bool is_immutable=false);
TORCH_API TensorWrapper* maybeGetTensorWrapper(const Tensor& tensor);
TORCH_API void dumpTensor(std::ostream & ss, const Tensor& tensor);
TORCH_API void dumpTensorCout(const Tensor& tensor);
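Assuming the includes and namespace shown in this header, caller-side usage of the two overloads described in the new comment might look like the following sketch (the two wrapper functions are hypothetical):

```cpp
#include <ATen/functorch/TensorWrapper.h>

// Level-based overload: fine when the interpreter for `level` is known to be
// on the DynamicLayerStack (e.g. when constructing wrappers from Python).
at::Tensor wrap_with_level(const at::Tensor& t, int64_t level) {
  return at::functorch::makeTensorWrapper(t, level, /*is_immutable=*/false);
}

// Interpreter-based overload: uses the interpreter's own life handle, so it
// also works when that interpreter is not currently on the stack.
at::Tensor wrap_with_interpreter(
    const at::Tensor& t, const at::functorch::Interpreter& interp) {
  return at::functorch::makeTensorWrapper(t, interp, /*is_immutable=*/false);
}
```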
1 change: 0 additions & 1 deletion torch/csrc/functorch/init.cpp
@@ -458,7 +458,6 @@ void initFuncTorchBindings(PyObject* module) {
m.def("dump_tensor", &dump_tensor, "dump_tensor");
m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
m.def("reshape_dim_outof", &at::functorch::reshape_dim_outof);
m.def("are_transforms_active", &at::functorch::areTransformsActive);
// various debugging things. Maybe we should offer these as first-class APIs
// on Tensors?
m.def("is_batchedtensor", &is_batchedtensor);