forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hooks.py
59 lines (46 loc) · 1.88 KB
/
hooks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import weakref
import warnings
class RemovableHandle(object):
    """Handle that can later delete one entry from a hooks dictionary.

    A handle stores a unique integer ``id`` and a weak reference to the
    dictionary it was created for. The handle itself never inserts anything
    into that dictionary; ``remove()`` deletes whatever entry is stored under
    ``id``, provided the dictionary is still alive and contains that key.
    The handle can also be used as a context manager, in which case
    ``remove()`` is called when the ``with`` block exits.
    """

    # Id that the next newly created handle will receive (class-wide counter).
    next_id = 0

    def __init__(self, hooks_dict):
        # Hold the dictionary weakly so the handle does not keep it alive.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        handle_id = RemovableHandle.next_id
        RemovableHandle.next_id = handle_id + 1
        self.id = handle_id

    def remove(self):
        """Delete this handle's entry from the dictionary, if still present."""
        target = self.hooks_dict_ref()
        if target is None:
            # The dictionary has already been garbage collected.
            return
        if self.id in target:
            del target[self.id]

    def __getstate__(self):
        # Pickle state: (dictionary or None if it is gone, handle id).
        live_dict = self.hooks_dict_ref()
        return (live_dict, self.id)

    def __setstate__(self, state):
        saved_dict = state[0]
        # When no dictionary was saved, reference a temporary that is dropped
        # right away, which yields a dead weak reference.
        self.hooks_dict_ref = weakref.ref(
            collections.OrderedDict() if saved_dict is None else saved_dict
        )
        self.id = state[1]
        # Keep the class-wide counter ahead of every restored id.
        if self.id + 1 > RemovableHandle.next_id:
            RemovableHandle.next_id = self.id + 1

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Always remove on exit; exceptions are not suppressed.
        self.remove()
def unserializable_hook(f):
    """
    Decorator that flags a function as a hook which is not meant to be
    serialized.

    The function is tagged with a ``__torch_unserializable__`` attribute set
    to ``True`` and returned unchanged (no wrapper is created). The tag is
    intended to silence the warning that is otherwise raised when a tensor
    carrying a hook is serialized.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
def warn_if_has_hooks(tensor):
    """
    Warn about every backward hook on ``tensor`` that will be lost on
    serialization.

    ``tensor._backward_hooks`` is expected to be a mapping from hook id to
    hook callable, or a falsy value when there are no hooks. One warning is
    issued per hook, except for hooks carrying the
    ``__torch_unserializable__`` marker attribute, which are skipped.
    """
    hooks = tensor._backward_hooks
    if not hooks:
        return
    for _, hook in hooks.items():
        # The marker lives on the hook callable, not on its dictionary key,
        # so the check must look at the hook itself.
        if hasattr(hook, "__torch_unserializable__"):
            continue
        warnings.warn("backward hook {} on tensor will not be "
                      "serialized. If this is expected, you can "
                      "decorate the function with @torch.utils.hooks.unserializable_hook "
                      "to suppress this warning".format(repr(hook)))