Skip to content

Commit

Permalink
Redis backend chord optimization: Avoid save group at apply and pipel…
Browse files Browse the repository at this point in the history
…ine + O(1) join.

This change is backward incompatible and so is not enabled by default.

To enable this optimization you have to set the `new_join` option
and it must be enabled by all clients and workers part of the chord::

    redis://?new_join=1
  • Loading branch information
ask committed Feb 14, 2014
1 parent 5ec5463 commit f09b041
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 98 deletions.
1 change: 1 addition & 0 deletions celery/app/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def run(self, header, body, partial_args=(), interval=None,
if eager:
return header.apply(args=partial_args, task_id=group_id)

body.setdefault('chord_size', len(header.tasks))
results = [AsyncResult(prepare_member(task, body, group_id))
for task in header.tasks]

Expand Down
2 changes: 1 addition & 1 deletion celery/app/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def trace_task(uuid, args, kwargs, request=None):
# -* POST *-
if state not in IGNORE_STATES:
if task_request.chord:
on_chord_part_return(task)
on_chord_part_return(task, state, R)
if task_after_return:
task_after_return(
state, retval, uuid, args, kwargs, None,
Expand Down
25 changes: 17 additions & 8 deletions celery/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def process_cleanup(self):
def on_task_call(self, producer, task_id):
return {}

def on_chord_part_return(self, task, propagate=False):
def on_chord_part_return(self, task, state, result, propagate=False):
pass

def fallback_chord_unlock(self, group_id, body, result=None,
Expand Down Expand Up @@ -374,17 +374,26 @@ def incr(self, key):
def expire(self, key, value):
pass

def get_key_for_task(self, task_id):
def get_key_for_task(self, task_id, key=''):
"""Get the cache key for a task by id."""
return self.task_keyprefix + self.key_t(task_id)
key_t = self.key_t
return ''.join([
self.task_keyprefix, key_t(task_id), key_t(key),
])

def get_key_for_group(self, group_id):
def get_key_for_group(self, group_id, key=''):
"""Get the cache key for a group by id."""
return self.group_keyprefix + self.key_t(group_id)
key_t = self.key_t
return ''.join([
self.group_keyprefix, key_t(group_id), key_t(key),
])

def get_key_for_chord(self, group_id):
def get_key_for_chord(self, group_id, key=''):
"""Get the cache key for the chord waiting on group with given id."""
return self.chord_keyprefix + self.key_t(group_id)
key_t = self.key_t
return ''.join([
self.chord_keyprefix, key_t(group_id), key_t(key),
])

def _strip_prefix(self, key):
"""Takes bytes, emits string."""
Expand Down Expand Up @@ -479,7 +488,7 @@ def _apply_chord_incr(self, header, partial_args, group_id, body,
self.save_group(group_id, self.app.GroupResult(group_id, result))
return header(*partial_args, task_id=group_id)

def on_chord_part_return(self, task, propagate=None):
def on_chord_part_return(self, task, state, result, propagate=None):
if not self.implements_incr:
return
app = self.app
Expand Down
72 changes: 69 additions & 3 deletions celery/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from kombu.utils import cached_property, retry_over_time
from kombu.utils.url import _parse_url

from celery.exceptions import ImproperlyConfigured
from celery import states
from celery.canvas import maybe_signature
from celery.exceptions import ChordError, ImproperlyConfigured
from celery.five import string_t
from celery.utils import deprecated_property
from celery.utils import deprecated_property, strtobool
from celery.utils.functional import dictfilter
from celery.utils.log import get_logger
from celery.utils.timeutils import humanize_seconds
Expand Down Expand Up @@ -56,7 +58,7 @@ class RedisBackend(KeyValueStoreBackend):

def __init__(self, host=None, port=None, db=None, password=None,
expires=None, max_connections=None, url=None,
connection_pool=None, **kwargs):
connection_pool=None, new_join=False, **kwargs):
super(RedisBackend, self).__init__(**kwargs)
conf = self.app.conf
if self.redis is None:
Expand Down Expand Up @@ -90,6 +92,14 @@ def _get(key):
self.url = url
self.expires = self.prepare_expires(expires, type=int)

try:
new_join = strtobool(self.connparams.pop('new_join'))
except KeyError:
pass
if new_join:
self.apply_chord = self._new_chord_apply
self.on_chord_part_return = self._new_chord_return

self.connection_errors, self.channel_errors = get_redis_error_classes()

def _params_from_url(self, url, defaults):
Expand Down Expand Up @@ -165,6 +175,62 @@ def incr(self, key):
def expire(self, key, value):
return self.client.expire(key, value)

def _unpack_chord_result(self, tup, decode,
PROPAGATE_STATES=states.PROPAGATE_STATES):
_, tid, state, retval = decode(tup)
if state in PROPAGATE_STATES:
raise ChordError('Dependency {0} raised {1!r}'.format(tid, retval))
return retval

def _new_chord_apply(self, header, partial_args, group_id, body,
result=None, **options):
# avoids saving the group in the redis db.
return header(*partial_args, task_id=group_id)

def _new_chord_return(self, task, state, result, propagate=None,
PROPAGATE_STATES=states.PROPAGATE_STATES):
app = self.app
if propagate is None:
propagate = self.app.conf.CELERY_CHORD_PROPAGATES
request = task.request
tid, gid = request.id, request.group
if not gid or not tid:
return

client = self.client
jkey = self.get_key_for_group(gid, '.j')
result = self.encode_result(result, state)
_, readycount, _ = client.pipeline() \
.rpush(jkey, self.encode([1, tid, state, result])) \
.llen(jkey) \
.expire(jkey, 86400) \
.execute()

try:
callback = maybe_signature(request.chord, app=app)
total = callback['chord_size']
if readycount >= total:
decode, unpack = self.decode, self._unpack_chord_result
resl, _ = client.pipeline() \
.lrange(jkey, 0, total) \
.delete(jkey) \
.execute()
try:
callback.delay([unpack(tup, decode) for tup in resl])
except Exception as exc:
app._tasks[callback.task].backend.fail_from_current_stack(
callback.id,
exc=ChordError('Callback error: {0!r}'.format(exc)),
)
except ChordError as exc:
app._tasks[callback.task].backend.fail_from_current_stack(
callback.id, exc=exc,
)
except Exception as exc:
app._tasks[callback.task].backend.fail_from_current_stack(
callback.id, exc=ChordError('Join error: {0!r}').format(exc),
)

@property
def ConnectionPool(self):
if self._ConnectionPool is None:
Expand Down
2 changes: 1 addition & 1 deletion celery/tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def test_error_mail_disabled(self):

class test_defaults(AppCase):

def test_str_to_bool(self):
def test_strtobool(self):
for s in ('false', 'no', '0'):
self.assertFalse(defaults.strtobool(s))
for s in ('true', 'yes', '1'):
Expand Down
16 changes: 9 additions & 7 deletions celery/tests/backends/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_forget(self):
self.b.forget('SOMExx-N0nex1stant-IDxx-')

def test_on_chord_part_return(self):
self.b.on_chord_part_return(None)
self.b.on_chord_part_return(None, None, None)

def test_apply_chord(self, unlock='celery.chord_unlock'):
self.app.tasks[unlock] = Mock()
Expand Down Expand Up @@ -246,7 +246,7 @@ def setup(self):

def test_on_chord_part_return(self):
assert not self.b.implements_incr
self.b.on_chord_part_return(None)
self.b.on_chord_part_return(None, None, None)

def test_get_store_delete_result(self):
tid = uuid()
Expand Down Expand Up @@ -282,12 +282,14 @@ def test_get_many_times_out(self):
def test_chord_part_return_no_gid(self):
self.b.implements_incr = True
task = Mock()
state = 'SUCCESS'
result = 10
task.request.group = None
self.b.get_key_for_chord = Mock()
self.b.get_key_for_chord.side_effect = AssertionError(
'should not get here',
)
self.assertIsNone(self.b.on_chord_part_return(task))
self.assertIsNone(self.b.on_chord_part_return(task, state, result))

@contextmanager
def _chord_part_context(self, b):
Expand Down Expand Up @@ -315,14 +317,14 @@ def callback(result):

def test_chord_part_return_propagate_set(self):
with self._chord_part_context(self.b) as (task, deps, _):
self.b.on_chord_part_return(task, propagate=True)
self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=True)
self.assertFalse(self.b.expire.called)
deps.delete.assert_called_with()
deps.join_native.assert_called_with(propagate=True, timeout=3.0)

def test_chord_part_return_propagate_default(self):
with self._chord_part_context(self.b) as (task, deps, _):
self.b.on_chord_part_return(task, propagate=None)
self.b.on_chord_part_return(task, 'SUCCESS', 10, propagate=None)
self.assertFalse(self.b.expire.called)
deps.delete.assert_called_with()
deps.join_native.assert_called_with(
Expand All @@ -334,7 +336,7 @@ def test_chord_part_return_join_raises_internal(self):
with self._chord_part_context(self.b) as (task, deps, callback):
deps._failed_join_report = lambda: iter([])
deps.join_native.side_effect = KeyError('foo')
self.b.on_chord_part_return(task)
self.b.on_chord_part_return(task, 'SUCCESS', 10)
self.assertTrue(self.b.fail_from_current_stack.called)
args = self.b.fail_from_current_stack.call_args
exc = args[1]['exc']
Expand All @@ -348,7 +350,7 @@ def test_chord_part_return_join_raises_task(self):
self.app.AsyncResult('culprit'),
])
deps.join_native.side_effect = KeyError('foo')
b.on_chord_part_return(task)
b.on_chord_part_return(task, 'SUCCESS', 10)
self.assertTrue(b.fail_from_current_stack.called)
args = b.fail_from_current_stack.call_args
exc = args[1]['exc']
Expand Down
4 changes: 2 additions & 2 deletions celery/tests/backends/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def test_on_chord_part_return(self, restore):
tb.apply_chord(group(app=self.app), (), gid, {}, result=res)

self.assertFalse(deps.join_native.called)
tb.on_chord_part_return(task)
tb.on_chord_part_return(task, 'SUCCESS', 10)
self.assertFalse(deps.join_native.called)

tb.on_chord_part_return(task)
tb.on_chord_part_return(task, 'SUCCESS', 10)
deps.join_native.assert_called_with(propagate=True, timeout=3.0)
deps.delete.assert_called_with()

Expand Down
Loading

0 comments on commit f09b041

Please sign in to comment.