Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CallbackRegistry fix #4118

Merged
merged 6 commits into from Mar 12, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 42 additions & 10 deletions lib/matplotlib/cbook.py
Expand Up @@ -361,9 +361,14 @@ class _BoundMethodProxy(object):
Minor bugfixes by Michael Droettboom
'''
def __init__(self, cb):
self._hash = hash(cb)
self._destroy_callbacks = []
try:
try:
self.inst = ref(cb.im_self)
if six.PY3:
self.inst = ref(cb.__self__, self._destroy)
else:
self.inst = ref(cb.im_self, self._destroy)
except TypeError:
self.inst = None
if six.PY3:
Expand All @@ -377,6 +382,16 @@ def __init__(self, cb):
self.func = cb
self.klass = None

def add_destroy_callback(self, callback):
self._destroy_callbacks.append(_BoundMethodProxy(callback))

def _destroy(self, wk):
for callback in self._destroy_callbacks:
try:
callback(self)
except ReferenceError:
pass

def __getstate__(self):
d = self.__dict__.copy()
# de-weak reference inst
Expand Down Expand Up @@ -433,6 +448,9 @@ def __ne__(self, other):
'''
return not self.__eq__(other)

def __hash__(self):
return self._hash


class CallbackRegistry(object):
"""
Expand Down Expand Up @@ -492,17 +510,32 @@ def connect(self, s, func):
func will be called
"""
self._func_cid_map.setdefault(s, WeakKeyDictionary())
if func in self._func_cid_map[s]:
return self._func_cid_map[s][func]
# Note proxy not needed in python 3.
# TODO rewrite this when support for python2.x gets dropped.
proxy = _BoundMethodProxy(func)
if proxy in self._func_cid_map[s]:
return self._func_cid_map[s][proxy]

proxy.add_destroy_callback(self._remove_proxy)
self._cid += 1
cid = self._cid
self._func_cid_map[s][func] = cid
self._func_cid_map[s][proxy] = cid
self.callbacks.setdefault(s, dict())
proxy = _BoundMethodProxy(func)
self.callbacks[s][cid] = proxy
return cid

def _remove_proxy(self, proxy):
for signal, proxies in list(six.iteritems(self._func_cid_map)):
try:
del self.callbacks[signal][proxies[proxy]]
except KeyError:
pass

if len(self.callbacks[signal]) == 0:
del self.callbacks[signal]
del self._func_cid_map[signal]


def disconnect(self, cid):
"""
disconnect the callback registered with callback id *cid*
Expand All @@ -513,7 +546,7 @@ def disconnect(self, cid):
except KeyError:
continue
else:
for category, functions in list(
for signal, functions in list(
six.iteritems(self._func_cid_map)):
for function, value in list(six.iteritems(functions)):
if value == cid:
Expand All @@ -527,11 +560,10 @@ def process(self, s, *args, **kwargs):
"""
if s in self.callbacks:
for cid, proxy in list(six.iteritems(self.callbacks[s])):
# Clean out dead references
if proxy.inst is not None and proxy.inst() is None:
del self.callbacks[s][cid]
else:
try:
proxy(*args, **kwargs)
except ReferenceError:
self._remove_proxy(proxy)


class Scheduler(threading.Thread):
Expand Down
46 changes: 45 additions & 1 deletion lib/matplotlib/tests/test_cbook.py
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from numpy.testing.utils import (assert_array_equal, assert_approx_equal,
assert_array_almost_equal)
from nose.tools import assert_equal, raises, assert_true
from nose.tools import assert_equal, assert_not_equal, raises, assert_true

import matplotlib.cbook as cbook
import matplotlib.colors as mcolors
Expand Down Expand Up @@ -243,3 +243,47 @@ def test_label_error(self):
def test_bad_dims(self):
data = np.random.normal(size=(34, 34, 34))
results = cbook.boxplot_stats(data)


class Test_callback_registry(object):
def setup(self):
self.signal = 'test'
self.callbacks = cbook.CallbackRegistry()

def connect(self, s, func):
return self.callbacks.connect(s, func)

def is_empty(self):
assert_equal(self.callbacks._func_cid_map, {})
assert_equal(self.callbacks.callbacks, {})

def is_not_empty(self):
assert_not_equal(self.callbacks._func_cid_map, {})
assert_not_equal(self.callbacks.callbacks, {})

def test_callback_complete(self):
# ensure we start with an empty registry
self.is_empty()

# create a class for testing
mini_me = Test_callback_registry()

# test that we can add a callback
cid1 = self.connect(self.signal, mini_me.dummy)
assert_equal(type(cid1), int)
self.is_not_empty()

# test that we don't add a second callback
cid2 = self.connect(self.signal, mini_me.dummy)
assert_equal(cid1, cid2)
self.is_not_empty()
assert_equal(len(self.callbacks._func_cid_map), 1)
assert_equal(len(self.callbacks.callbacks), 1)

del mini_me

# check we now have no callbacks registered
self.is_empty()

def dummy(self):
pass