Commit 4092828
Remove incorrect THP{Cpp,}Function_traverse PyObject traversals (pytorch#102860)

Fixes pytorch#102174

Pull Request resolved: pytorch#102860
Approved by: https://github.com/albanD
soulitzer authored and alimoezzi committed Jun 3, 2023
1 parent 7cdd3c4 commit 4092828
Showing 3 changed files with 102 additions and 51 deletions.
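For context, the failure mode this change guards against (exercised by the new test below) is a hook whose closure forms a reference cycle with the node it is registered on; with the old traversals, a gc.collect() could clear that cycle and silently drop the hook while the graph still needed it. The following is a minimal sketch of that shape using only public torch APIs, mirroring the regression test rather than the exact reproduction in pytorch#102174:

```python
import gc
import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()
grad_fn_b = b.grad_fn          # Python wrapper around the C++ node

calls = []

def hook(*args):
    grad_fn_b                  # closure keeps the wrapper alive -> cycle:
    calls.append(1)            # hook -> closure -> wrapper -> node -> hook

b.grad_fn.register_hook(hook)

out = b.clone()
gc.collect()                   # must not break the cycle: the graph (out)
                               # still co-owns the underlying node
out.backward(retain_graph=True)
assert calls, "hook should still fire after gc.collect()"
```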
72 changes: 72 additions & 0 deletions test/test_autograd.py
@@ -7462,6 +7462,78 @@ def forward(self, x):
         gc.collect()
         self.assertIsNone(ref_())
 
+    @parametrize("use_custom_function", [True, False])
+    @parametrize("use_tensor_hook", [True, False])
+    def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
+        # This creates a cycle between the hook and grad_fn_b
+        # hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp)
+        # -> dict -> hook
+        #
+        # This test is testing that the grad_fn_b (python) only traverses the
+        # dict if it is the only one holding a reference to the grad_fn_b (cpp)
+        # shared_ptr
+        #
+        # See: https://github.com/pytorch/pytorch/issues/102174
+        class Function(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, grad):
+                return grad
+
+        class Test():
+            pass
+
+        count = [0]
+
+        def scope():
+            a = torch.tensor(1., requires_grad=True)
+            if use_custom_function:
+                b = Function.apply(a)
+            else:
+                b = a.clone()
+            grad_fn_b = b.grad_fn
+            obj = Test()
+
+            def hook(*args):
+                # Make sure this hook's closure holds onto grad_fn_b
+                # This forms a cycle between the hook and grad_fn_b
+                # We also hold onto a sentinel object 'obj' to track
+                # whether this cycle is still alive. See 'ref' below.
+                grad_fn_b
+                obj
+                count[0] += 1
+            if use_tensor_hook:
+                b.register_hook(hook)
+            else:
+                b.grad_fn.register_hook(hook)
+            c = b.clone()
+            ref = weakref.ref(obj)
+            return c, ref
+
+        with disable_gc():
+            out, ref = scope()
+            out.backward(retain_graph=True)
+
+            gc.collect()
+
+            # Make sure gc does not clear the cycle noted above.
+            # e.g. the hook is alive and gets fired even after gc runs
+            out.backward(retain_graph=True)
+            self.assertEqual(count[0], 2)
+
+            # ref is still alive because the use_count of the cpp grad_fn
+            # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
+            # rest of the graph holds onto the shared_ptr
+            self.assertIsNotNone(ref())
+
+            # Then delete the rest of the graph and check that ref is dead
+            del out
+            gc.collect()
+            self.assertIsNone(ref())
+
     def test_full_backward_hook_double_backward(self):
         x = torch.rand(1, requires_grad=True)
         y = torch.rand_like(x)
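The new test above uses disable_gc() from PyTorch's internal test utilities so that collection happens only at the explicit gc.collect() calls. A minimal sketch of such a helper (an assumption about its shape, not the actual implementation), using only the standard library:

```python
import contextlib
import gc

@contextlib.contextmanager
def disable_gc():
    # Turn off automatic garbage collection for the duration of the block;
    # explicit gc.collect() calls inside still work as usual.
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
```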
49 changes: 28 additions & 21 deletions torch/csrc/autograd/python_cpp_function.cpp
@@ -76,30 +76,37 @@ PyObject* THPCppFunction_call(
 }
 
 int THPCppFunction_traverse(PyObject* self, visitproc visit, void* arg) {
-  auto& fn = *((THPCppFunction*)self)->cdata;
-  for (const auto& hook : fn.tensor_pre_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+  if ((((THPCppFunction*)self)->cdata).use_count() == 1) {
+    // The fields traversed below are owned by the cpp grad_fn, which we own a
+    // reference to. We should only traverse them, however, if we are the only
+    // owner of the grad_fn; otherwise we risk prematurely gc'ing the grad_fn.
+    //
+    // See: https://github.com/pytorch/pytorch/issues/102174
+    auto& fn = *((THPCppFunction*)self)->cdata;
+    for (const auto& hook : fn.tensor_pre_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
-  }
-  // NOTE [retains_grad_hook PyObject traversal]
-  // In theory this shouldn't be necessary, because retains_grad_hooks should
-  // not contain any PyFunctionTensorPreHooks. The alternative is to have a
-  // check that actually guarantees this.
-  for (const auto& pair : fn.retains_grad_hooks()) {
-    if (auto pyhook =
-            dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
-      Py_VISIT(pyhook->dict);
+    // NOTE [retains_grad_hook PyObject traversal]
+    // In theory this shouldn't be necessary, because retains_grad_hooks should
+    // not contain any PyFunctionTensorPreHooks. The alternative is to have a
+    // check that actually guarantees this.
+    for (const auto& pair : fn.retains_grad_hooks()) {
+      if (auto pyhook =
+              dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
-  }
-  for (const auto& hook : fn.pre_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+    for (const auto& hook : fn.pre_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
-  }
-  for (const auto& hook : fn.post_hooks()) {
-    if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
-      Py_VISIT(pyhook->dict);
+    for (const auto& hook : fn.post_hooks()) {
+      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
+        Py_VISIT(pyhook->dict);
+      }
     }
   }
   return 0;
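For reference, the hook containers visited above are populated from Python roughly as follows. This is a short sketch using public autograd APIs; the mapping of each call to tensor_pre_hooks(), retains_grad_hooks(), pre_hooks(), and post_hooks() is stated here as an assumption based on the hook names, not taken from this diff:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = a.clone()

# Tensor hook: presumably lands in the node's tensor_pre_hooks()
b.register_hook(lambda grad: grad)

# retain_grad(): presumably tracked via retains_grad_hooks()
b.retain_grad()

# Node-level hooks: presumably pre_hooks() / post_hooks()
b.grad_fn.register_prehook(lambda grad_outputs: None)
b.grad_fn.register_hook(lambda grad_inputs, grad_outputs: None)

b.sum().backward()
```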
32 changes: 2 additions & 30 deletions torch/csrc/autograd/python_function.cpp
@@ -206,36 +206,8 @@ auto PyNode::name() const -> std::string {

 // Traverse and clear are required for supporting Python's GC cycle handling.
 static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
-  // cdata could be null if the PyNode has already gone out of scope
-  // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn
-  // only).
-  //
-  // TODO: I'm not really sure if we're actually obligated to traverse PyObject
-  // that is stored in PyNode, since we don't really own that C++ object.
-  if (auto cdata = self->cdata.lock()) {
-    for (const auto& hook : cdata->tensor_pre_hooks()) {
-      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
-        Py_VISIT(pyhook->dict);
-      }
-    }
-    // See NOTE [retains_grad_hook PyObject traversal]
-    for (const auto& pair : cdata->retains_grad_hooks()) {
-      if (auto pyhook =
-              dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
-        Py_VISIT(pyhook->dict);
-      }
-    }
-    for (const auto& hook : cdata->pre_hooks()) {
-      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
-        Py_VISIT(pyhook->dict);
-      }
-    }
-    for (const auto& hook : cdata->post_hooks()) {
-      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
-        Py_VISIT(pyhook->dict);
-      }
-    }
-  }
+  // NB: We should not traverse PyObjects stored on PyNode, since we only hold
+  // a weak reference to the PyNode.
   Py_VISIT(self->to_save);
   Py_VISIT(self->non_differentiable);
   Py_VISIT(self->dirty_tensors);
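A rough Python analogy for the ownership rule in the new comment: an object holding only a weak reference neither keeps the target alive nor should report the target's contents to the GC; it only reports what it owns directly. A toy sketch with the standard weakref module (Node and FunctionWrapper are illustrative stand-ins, not PyTorch classes):

```python
import weakref

class Node:
    # Stands in for the C++ PyNode; it owns its own hook dict.
    def __init__(self):
        self.hooks = {}

class FunctionWrapper:
    # Stands in for THPFunction: refers to the node only weakly and owns
    # only its own attributes (to_save, etc.).
    def __init__(self, node):
        self.cdata = weakref.ref(node)   # non-owning reference
        self.to_save = []                # directly owned

graph_node = Node()
fn = FunctionWrapper(graph_node)

assert fn.cdata() is graph_node
del graph_node                           # the wrapper alone cannot keep it alive
assert fn.cdata() is None
```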
