Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TransformWrapper pickling fixes #4915

Merged
merged 3 commits into from Oct 8, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 29 additions & 0 deletions lib/matplotlib/tests/test_pickle.py
Expand Up @@ -12,6 +12,7 @@

from matplotlib.testing.decorators import cleanup, image_comparison
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms


def depth_getter(obj,
Expand Down Expand Up @@ -252,6 +253,34 @@ def test_polar():
plt.draw()


class TransformBlob(object):
def __init__(self):
self.identity = mtransforms.IdentityTransform()
self.identity2 = mtransforms.IdentityTransform()
# Force use of the more complex composition.
self.composite = mtransforms.CompositeGenericTransform(
self.identity,
self.identity2)
# Check parent -> child links of TransformWrapper.
self.wrapper = mtransforms.TransformWrapper(self.composite)
# Check child -> parent links of TransformWrapper.
self.composite2 = mtransforms.CompositeGenericTransform(
self.wrapper,
self.identity)


def test_transform():
obj = TransformBlob()
pf = pickle.dumps(obj)
del obj

obj = pickle.loads(pf)
# Check parent -> child links of TransformWrapper.
assert_equal(obj.wrapper._child, obj.composite)
# Check child -> parent links of TransformWrapper.
assert_equal(list(obj.wrapper._parents.values()), [obj.composite2])


if __name__ == '__main__':
import nose
nose.runmodule(argv=['-s'])
18 changes: 14 additions & 4 deletions lib/matplotlib/transforms.py
Expand Up @@ -1533,6 +1533,10 @@ def __init__(self, child):
msg = ("'child' must be an instance of"
" 'matplotlib.transform.Transform'")
raise ValueError(msg)
self._init(child)
self.set_children(child)

def _init(self, child):
Transform.__init__(self)
self.input_dims = child.input_dims
self.output_dims = child.output_dims
Expand All @@ -1548,12 +1552,18 @@ def __str__(self):
return str(self._child)

def __getstate__(self):
# only store the child
return {'child': self._child}
# only store the child and parents
return {
'child': self._child,
# turn the weakkey dictionary into a normal dictionary
'parents': dict(six.iteritems(self._parents))
}

def __setstate__(self, state):
# re-initialise the TransformWrapper with the state's child
self.__init__(state['child'])
self._init(state['child'])
# turn the normal dictionary back into a WeakValueDictionary
self._parents = WeakValueDictionary(state['parents'])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not clear that this is going to work, are there going to be other refs to these instances of the objects will be held so they won't get immediately gc'd

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same method used in the grandparent TransformNode, so if it were going to fail, it would have done so a while ago.

There should be a strong reference from parent -> child assuming the transform's implemented correctly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fair enough.


def __repr__(self):
return "TransformWrapper(%r)" % self._child
Expand All @@ -1564,7 +1574,6 @@ def frozen(self):

def _set(self, child):
self._child = child
self.set_children(child)

self.transform = child.transform
self.transform_affine = child.transform_affine
Expand Down Expand Up @@ -1593,6 +1602,7 @@ def set(self, child):
" output dimensions as the current child.")
raise ValueError(msg)

self.set_children(child)
self._set(child)

self._invalid = 0
Expand Down