amqp related improvements

commit fa957dbc29537d56262faeecc3d5819a4b1b0595 · 1 parent 293a587 · @ask committed Jun 1, 2012
@@ -319,9 +319,10 @@ def router(self):
return self.Router()
@cached_property
- def publisher_pool(self):
+ def producer_pool(self):
return ProducerPool(self.app.pool, limit=self.app.pool.limit,
Producer=self.TaskProducer)
+ publisher_pool = producer_pool # compat alias
@cached_property
def default_exchange(self):
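Note on this hunk: the cached property moves to the new name, while the old name survives as a plain class attribute bound to the same descriptor, so code that reads app.amqp.publisher_pool keeps working and hits the same cache. A minimal sketch of the alias pattern (it assumes kombu's cached_property, which caches on the instance under the getter's name; the placeholder pool object is made up, the real property builds a ProducerPool):

    from kombu.utils import cached_property

    class AMQP(object):
        @cached_property
        def producer_pool(self):
            # the real code returns ProducerPool(app.pool, ...)
            return object()

        publisher_pool = producer_pool  # compat alias: same descriptor

    amqp = AMQP()
    # both names resolve to the one cached pool instance
    assert amqp.publisher_pool is amqp.producer_pool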
@@ -223,7 +223,7 @@ def default_producer(self, producer=None):
if producer:
yield producer
else:
- with self.amqp.publisher_pool.acquire(block=True) as producer:
+ with self.amqp.producer_pool.acquire(block=True) as producer:
yield producer
def with_default_connection(self, fun):
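The default_producer context manager is the other half of the rename: given no explicit producer it now borrows one from producer_pool with block=True and releases it on exit. A hedged usage sketch (app name and broker URL are illustrative):

    from celery import Celery

    app = Celery("example", broker="amqp://guest@localhost//")

    with app.default_producer() as producer:
        # producer was acquired from app.amqp.producer_pool and is
        # returned to the pool when the block exits
        pass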
@@ -9,8 +9,8 @@
:license: BSD, see LICENSE for more details.
"""
-
from __future__ import absolute_import
+from __future__ import with_statement
import logging
import sys
@@ -460,8 +460,8 @@ def delay(self, *args, **kwargs):
return self.apply_async(args, kwargs)
def apply_async(self, args=None, kwargs=None,
- task_id=None, publisher=None, connection=None,
- router=None, link=None, link_error=None, **options):
+ task_id=None, producer=None, connection=None, router=None,
+ link=None, link_error=None, publisher=None, **options):
"""Apply tasks asynchronously by sending a message.
:keyword args: The positional arguments to pass on to the
@@ -494,7 +494,7 @@ def apply_async(self, args=None, kwargs=None,
in the event of connection loss or failure. Default
is taken from the :setting:`CELERY_TASK_PUBLISH_RETRY`
setting. Note you need to handle the
- publisher/connection manually for this to work.
+ producer/connection manually for this to work.
:keyword retry_policy: Override the retry policy used. See the
:setting:`CELERY_TASK_PUBLISH_RETRY` setting.
@@ -543,11 +543,15 @@ def apply_async(self, args=None, kwargs=None,
:keyword link_error: A single, or a list of subtasks to apply
if an error occurs while executing the task.
+    :keyword producer: :class:`~@amqp.TaskProducer` instance to use.
+ :keyword publisher: Deprecated alias to ``producer``.
+
.. note::
If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will
be replaced by a local :func:`apply` call instead.
"""
+ producer = producer or publisher
app = self._get_app()
router = router or self.app.amqp.router
conf = app.conf
@@ -562,24 +566,19 @@ def apply_async(self, args=None, kwargs=None,
options = router.route(options, self.name, args, kwargs)
if connection:
- publisher = app.amqp.TaskProducer(connection)
- publish = publisher or app.amqp.publisher_pool.acquire(block=True)
- evd = None
- if conf.CELERY_SEND_TASK_SENT_EVENT:
- evd = app.events.Dispatcher(channel=publish.channel,
- buffer_while_offline=False)
-
- try:
- task_id = publish.delay_task(self.name, args, kwargs,
- task_id=task_id,
- event_dispatcher=evd,
- callbacks=maybe_list(link),
- errbacks=maybe_list(link_error),
- **options)
- finally:
- if not publisher:
- publish.release()
-
+ producer = app.amqp.TaskProducer(connection)
+ with app.default_producer(producer) as P:
+ evd = None
+ if conf.CELERY_SEND_TASK_SENT_EVENT:
+ evd = app.events.Dispatcher(channel=P.channel,
+ buffer_while_offline=False)
+
+ task_id = P.delay_task(self.name, args, kwargs,
+ task_id=task_id,
+ event_dispatcher=evd,
+ callbacks=maybe_list(link),
+ errbacks=maybe_list(link_error),
+ **options)
result = self.AsyncResult(task_id)
parent = get_current_worker_task()
if parent:
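The rewritten apply_async body above replaces the manual acquire/try/finally/release bookkeeping with the default_producer context manager: an explicit connection still builds a dedicated TaskProducer, a passed-in producer (or the deprecated publisher keyword, folded in by producer = producer or publisher) is used as-is, and otherwise a pooled producer is borrowed and released automatically. A sketch of the resulting call styles (the task and its arguments are illustrative, and actually sending requires a running broker):

    from celery import Celery

    app = Celery("example", broker="amqp://guest@localhost//")

    @app.task
    def add(x, y):
        return x + y

    add.apply_async((2, 2))                    # borrows from producer_pool

    with app.default_producer() as producer:   # reuse one producer/channel
        add.apply_async((2, 2), producer=producer)
        add.apply_async((4, 4), producer=producer)
        add.apply_async((8, 8), publisher=producer)  # deprecated alias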
@@ -6,16 +6,17 @@
import threading
import time
-from itertools import count
-
from kombu.entity import Exchange, Queue
from kombu.messaging import Consumer, Producer
from celery import states
from celery.exceptions import TimeoutError
+from celery.utils.log import get_logger
from .base import BaseDictBackend
+logger = get_logger(__name__)
+
class BacklogLimitExceeded(Exception):
"""Too much state history to fast-forward."""
@@ -39,6 +40,13 @@ class AMQPBackend(BaseDictBackend):
supports_native_join = True
+ retry_policy = {
+ "max_retries": 20,
+ "interval_start": 0,
+ "interval_step": 1,
+ "interval_max": 1,
+ }
+
def __init__(self, connection=None, exchange=None, exchange_type=None,
persistent=None, serializer=None, auto_delete=True,
**kwargs):
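These keys follow kombu's retry-policy convention: at most max_retries attempts, sleeping interval_start seconds before the first retry and adding interval_step per attempt, capped at interval_max. With the values above a failed result publish is retried immediately, then about once per second, for roughly 19 seconds in total. A small sketch of that schedule (approximating the timing of kombu's retry_over_time):

    def retry_schedule(policy):
        # approximate sleep lengths between successive retries
        interval = policy["interval_start"]
        for _ in range(policy["max_retries"]):
            yield min(interval, policy["interval_max"])
            interval += policy["interval_step"]

    policy = {"max_retries": 20, "interval_start": 0,
              "interval_step": 1, "interval_max": 1}
    assert sum(retry_schedule(policy)) == 19  # ~19s of total back-off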
@@ -83,19 +91,6 @@ def _create_binding(self, task_id):
auto_delete=self.auto_delete,
queue_arguments=self.queue_arguments)
- def _create_producer(self, task_id, connection):
- self._create_binding(task_id)(connection.default_channel).declare()
- return self.Producer(connection, exchange=self.exchange,
- routing_key=task_id.replace("-", ""),
- serializer=self.serializer)
-
- def _create_consumer(self, bindings, channel):
- return self.Consumer(channel, bindings, no_ack=True)
-
- def _publish_result(self, connection, task_id, meta):
- # cache single channel
- self._create_producer(task_id, connection).publish(meta)
-
def revive(self, channel):
pass
@@ -104,27 +99,18 @@ def _store_result(self, task_id, result, status, traceback=None,
interval_max=1):
"""Send task return value and status."""
with self.mutex:
- with self.app.pool.acquire(block=True) as conn:
-
- def errback(error, delay):
- print("Couldn't send result for %r: %r. Retry in %rs." % (
- task_id, error, delay))
-
- send = conn.ensure(self, self._publish_result,
- max_retries=max_retries,
- errback=errback,
- interval_start=interval_start,
- interval_step=interval_step,
- interval_max=interval_max)
- send(conn, task_id, {"task_id": task_id, "status": status,
- "result": self.encode_result(result, status),
- "traceback": traceback,
- "children": self.current_task_children()})
+ with self.app.amqp.producer_pool.acquire(block=True) as pub:
+ pub.publish({"task_id": task_id, "status": status,
+ "result": self.encode_result(result, status),
+ "traceback": traceback,
+ "children": self.current_task_children()},
+ exchange=self.exchange,
+ routing_key=task_id.replace("-", ""),
+ serializer=self.serializer,
+ retry=True, retry_policy=self.retry_policy,
+ declare=[self._create_binding(task_id)])
return result
- def get_task_meta(self, task_id, cache=True):
- return self.poll(task_id)
-
def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
**kwargs):
cached_meta = self._cache.get(task_id)
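The new _store_result drops the hand-rolled conn.ensure() dance, and its print-based errback, in favour of a single Producer.publish() call that carries the retry and declaration logic itself: retry/retry_policy re-send on recoverable connection errors, and declare=[...] has kombu declare the per-task result queue before publishing. A trimmed re-statement of the pattern, not the backend itself (the exchange name and serializer are placeholders):

    from kombu import Exchange, Queue

    def store_result(app, task_id, meta):
        exchange = Exchange("celeryresults", type="direct")
        binding = Queue(task_id.replace("-", ""), exchange,
                        routing_key=task_id.replace("-", ""),
                        auto_delete=True)
        with app.amqp.producer_pool.acquire(block=True) as producer:
            producer.publish(meta,
                             exchange=exchange,
                             routing_key=task_id.replace("-", ""),
                             serializer="json",
                             retry=True,
                             retry_policy={"max_retries": 20,
                                           "interval_start": 0,
                                           "interval_step": 1,
                                           "interval_max": 1},
                             declare=[binding])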
@@ -147,23 +133,30 @@ def wait_for(self, task_id, timeout=None, cache=True, propagate=True,
else:
return self.wait_for(task_id, timeout, cache)
- def poll(self, task_id, backlog_limit=100):
+ def get_task_meta(self, task_id, backlog_limit=1000):
+ # Polling and using basic_get
with self.app.pool.acquire_channel(block=True) as (_, channel):
binding = self._create_binding(task_id)(channel)
binding.declare()
latest, acc = None, None
- for i in count(): # fast-forward
+ for i in xrange(backlog_limit):
latest, acc = acc, binding.get(no_ack=True)
- if not acc:
+ if not acc: # no more messages
break
- if i > backlog_limit:
- raise self.BacklogLimitExceeded(task_id)
+ else:
+ raise self.BacklogLimitExceeded(task_id)
+
if latest:
+ # new state to report
payload = self._cache[task_id] = latest.payload
return payload
- elif task_id in self._cache: # use previously received state.
- return self._cache[task_id]
- return {"status": states.PENDING, "result": None}
+ else:
+ # no new state, use previous
+ try:
+ return self._cache[task_id]
+ except KeyError:
+ # result probably pending.
+ return {"status": states.PENDING, "result": None}
def drain_events(self, connection, consumer, timeout=None, now=time.time):
wait = connection.drain_events
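The renamed get_task_meta (formerly poll) also gains a hard bound: instead of an unbounded itertools.count with a manual overflow check, basic_get runs at most backlog_limit times, and Python's for ... else arms the exception only when the range is exhausted without draining the queue. A standalone illustration of that control flow, with the message source faked:

    class BacklogLimitExceeded(Exception):
        pass

    def fast_forward(messages, backlog_limit=1000):
        # keep only the newest message, bounded by backlog_limit
        it = iter(messages)
        latest = acc = None
        for _ in range(backlog_limit):
            latest, acc = acc, next(it, None)
            if not acc:        # drained: `latest` is the newest state
                break
        else:                  # never drained within the limit
            raise BacklogLimitExceeded()
        return latest

    assert fast_forward(["PENDING", "STARTED", "SUCCESS"]) == "SUCCESS"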
@@ -190,13 +183,12 @@ def callback(meta, message):
def consume(self, task_id, timeout=None):
with self.app.pool.acquire_channel(block=True) as (conn, channel):
binding = self._create_binding(task_id)
- with self._create_consumer(binding, channel) as consumer:
+ with self.Consumer(channel, binding, no_ack=True) as consumer:
return self.drain_events(conn, consumer, timeout).values()[0]
def get_many(self, task_ids, timeout=None, **kwargs):
with self.app.pool.acquire_channel(block=True) as (conn, channel):
ids = set(task_ids)
- cached_ids = set()
for task_id in ids:
try:
cached = self._cache[task_id]
@@ -205,11 +197,10 @@ def get_many(self, task_ids, timeout=None, **kwargs):
else:
if cached["status"] in states.READY_STATES:
yield task_id, cached
- cached_ids.add(task_id)
+ ids.discard(task_id)
- ids ^= cached_ids
bindings = [self._create_binding(task_id) for task_id in task_ids]
- with self._create_consumer(bindings, channel) as consumer:
+ with self.Consumer(channel, bindings, no_ack=True) as consumer:
while ids:
r = self.drain_events(conn, consumer, timeout)
ids ^= set(r)
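One caution on the get_many change: the first loop iterates ids while ids.discard() mutates it, and CPython raises RuntimeError when a set changes size during iteration, so in practice the iteration needs a snapshot (or the old separate cached_ids set). A minimal safe variant of the same idea:

    ids = set(["id-1", "id-2"])
    cache = {"id-1": {"status": "SUCCESS"}}

    for task_id in list(ids):      # iterate a snapshot, mutate the set
        meta = cache.get(task_id)
        if meta and meta["status"] == "SUCCESS":
            ids.discard(task_id)

    assert ids == set(["id-2"])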
@@ -238,12 +229,11 @@ def delete_taskset(self, taskset_id):
"delete_taskset is not supported by this backend.")
def __reduce__(self, args=(), kwargs={}):
- kwargs.update(
- dict(connection=self._connection,
- exchange=self.exchange.name,
- exchange_type=self.exchange.type,
- persistent=self.persistent,
- serializer=self.serializer,
- auto_delete=self.auto_delete,
- expires=self.expires))
+ kwargs.update(connection=self._connection,
+ exchange=self.exchange.name,
+ exchange_type=self.exchange.type,
+ persistent=self.persistent,
+ serializer=self.serializer,
+ auto_delete=self.auto_delete,
+ expires=self.expires)
return super(AMQPBackend, self).__reduce__(args, kwargs)
@@ -261,7 +261,7 @@ def State(self):
@contextmanager
def default_dispatcher(self, hostname=None, enabled=True,
buffer_while_offline=False):
- with self.app.amqp.publisher_pool.acquire(block=True) as pub:
+ with self.app.amqp.producer_pool.acquire(block=True) as pub:
with self.Dispatcher(pub.connection, hostname, enabled,
pub.channel, buffer_while_offline) as d:
yield d
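Event dispatching follows the same rename: default_dispatcher borrows a pooled producer and reuses its connection and channel for the Dispatcher. A hedged usage sketch (it assumes the context manager is reachable as app.events.default_dispatcher; the hostname and event fields are invented):

    with app.events.default_dispatcher(hostname="worker1.example.com") as d:
        d.send("worker-heartbeat", freq=2.0)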
@@ -290,7 +290,7 @@ def trace_task(task, uuid, args, kwargs, request=None, **opts):
def trace_task_ret(task, uuid, args, kwargs, request):
- task.__tracer__(uuid, args, kwargs, request)
+ return task.__tracer__(uuid, args, kwargs, request)[0]
def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
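The one-line trace_task_ret fix matters because the previous version dropped the tracer's return value, so callers always saw None. Assuming __tracer__ returns a tuple whose first element is the task's result (consistent with the [0] indexing added here), a contrived before/after:

    def tracer(uuid, args, kwargs, request):
        # stand-in for task.__tracer__; the tuple shape is an assumption
        return ("the-result", 0.001)

    def trace_task_ret_old(uuid, args, kwargs, request=None):
        tracer(uuid, args, kwargs, request)      # result dropped

    def trace_task_ret_new(uuid, args, kwargs, request=None):
        return tracer(uuid, args, kwargs, request)[0]

    assert trace_task_ret_old("id", (), {}) is None
    assert trace_task_ret_new("id", (), {}) == "the-result"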
@@ -46,9 +46,9 @@ def test_setup_nolimit(self):
delattr(self.app, "_pool")
except AttributeError:
pass
- self.app.amqp.__dict__.pop("publisher_pool", None)
+ self.app.amqp.__dict__.pop("producer_pool", None)
try:
- pool = self.app.amqp.publisher_pool
+ pool = self.app.amqp.producer_pool
self.assertEqual(pool.limit, self.app.pool.limit)
self.assertFalse(pool._resource.queue)
@@ -68,9 +68,9 @@ def test_setup(self):
delattr(self.app, "_pool")
except AttributeError:
pass
- self.app.amqp.__dict__.pop("publisher_pool", None)
+ self.app.amqp.__dict__.pop("producer_pool", None)
try:
- pool = self.app.amqp.publisher_pool
+ pool = self.app.amqp.producer_pool
self.assertEqual(pool.limit, self.app.pool.limit)
self.assertTrue(pool._resource.queue)
@@ -355,7 +355,8 @@ def on_retry(self, exc_info):
if _does_info:
info(self.retry_msg.strip(), {
"id": self.id, "name": self.name,
- "exc": safe_repr(exc_info.exception.exc)}, exc_info=exc_info.exc_info)
+ "exc": safe_repr(exc_info.exception.exc)},
+ exc_info=exc_info.exc_info)
def on_failure(self, exc_info):
"""Handler called if the task raised an exception."""
