Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ask committed Jun 4, 2012
1 parent 313b71f commit 8af6e58
Show file tree
Hide file tree
Showing 12 changed files with 394 additions and 30 deletions.
38 changes: 21 additions & 17 deletions celery/app/amqp.py
Expand Up @@ -22,6 +22,7 @@
from celery.utils import cached_property, uuid from celery.utils import cached_property, uuid
from celery.utils.text import indent as textindent from celery.utils.text import indent as textindent


from . import app_or_default
from . import routes as _routes from . import routes as _routes


#: Human readable queue declaration. #: Human readable queue declaration.
Expand All @@ -48,7 +49,7 @@ class Queues(dict):
#: The rest of the queues are then used for routing only. #: The rest of the queues are then used for routing only.
_consume_from = None _consume_from = None


def __init__(self, queues, default_exchange=None, create_missing=True): def __init__(self, queues=None, default_exchange=None, create_missing=True):
dict.__init__(self) dict.__init__(self)
self.aliases = WeakValueDictionary() self.aliases = WeakValueDictionary()
self.default_exchange = default_exchange self.default_exchange = default_exchange
Expand All @@ -65,11 +66,9 @@ def __getitem__(self, name):
return dict.__getitem__(self, name) return dict.__getitem__(self, name)


def __setitem__(self, name, queue): def __setitem__(self, name, queue):
if self.default_exchange: if self.default_exchange and (not queue.exchange or
if not queue.exchange or not queue.exchange.name: not queue.exchange.name):
queue.exchange = self.default_exchange queue.exchange = self.default_exchange
if queue.exchange.type == 'direct' and not queue.routing_key:
queue.routing_key = name
dict.__setitem__(self, name, queue) dict.__setitem__(self, name, queue)
if queue.alias: if queue.alias:
self.aliases[queue.alias] = queue self.aliases[queue.alias] = queue
Expand Down Expand Up @@ -135,19 +134,16 @@ def consume_from(self):




class TaskProducer(Producer): class TaskProducer(Producer):
app = None
auto_declare = False auto_declare = False
retry = False retry = False
retry_policy = None retry_policy = None


def __init__(self, channel=None, exchange=None, *args, **kwargs): def __init__(self, channel=None, exchange=None, *args, **kwargs):
self.app = kwargs.get("app") or self.app
self.retry = kwargs.pop("retry", self.retry) self.retry = kwargs.pop("retry", self.retry)
self.retry_policy = kwargs.pop("retry_policy", self.retry_policy = kwargs.pop("retry_policy",
self.retry_policy or {}) self.retry_policy or {})
exchange = exchange or self.exchange exchange = exchange or self.exchange
if not isinstance(exchange, Exchange):
exchange = Exchange(exchange,
kwargs.get("exchange_type") or self.exchange_type)
self.queues = self.app.amqp.queues # shortcut self.queues = self.app.amqp.queues # shortcut
super(TaskProducer, self).__init__(channel, exchange, *args, **kwargs) super(TaskProducer, self).__init__(channel, exchange, *args, **kwargs)


Expand Down Expand Up @@ -216,7 +212,21 @@ def delay_task(self, task_name, task_args=None, task_kwargs=None,
expires=expires, expires=expires,
queue=queue) queue=queue)
return task_id return task_id
TaskPublisher = TaskProducer # compat
class TaskPublisher(TaskProducer):
"""Deprecated version of :class:`TaskProducer`."""

def __init__(self, channel=None, exchange=None, *args, **kwargs):
self.app = app_or_default(kwargs.pop("app", self.app))
self.retry = kwargs.pop("retry", self.retry)
self.retry_policy = kwargs.pop("retry_policy",
self.retry_policy or {})
exchange = exchange or self.exchange
if not isinstance(exchange, Exchange):
exchange = Exchange(exchange,
kwargs.pop("exchange_type", "direct"))
self.queues = self.app.amqp.queues # shortcut
super(TaskPublisher, self).__init__(channel, exchange, *args, **kwargs)




class TaskConsumer(Consumer): class TaskConsumer(Consumer):
Expand Down Expand Up @@ -267,11 +277,6 @@ def TaskConsumer(self):
reverse="amqp.TaskConsumer") reverse="amqp.TaskConsumer")
get_task_consumer = TaskConsumer # XXX compat get_task_consumer = TaskConsumer # XXX compat


def queue_or_default(self, q):
if q:
return self.queues[q] if not isinstance(q, Queue) else q
return self.default_queue

@cached_property @cached_property
def TaskProducer(self): def TaskProducer(self):
"""Returns publisher used to send tasks. """Returns publisher used to send tasks.
Expand All @@ -283,7 +288,6 @@ def TaskProducer(self):
return self.app.subclass_with_self(TaskProducer, return self.app.subclass_with_self(TaskProducer,
reverse="amqp.TaskProducer", reverse="amqp.TaskProducer",
exchange=self.default_exchange, exchange=self.default_exchange,
exchange_type=self.default_exchange.type,
routing_key=conf.CELERY_DEFAULT_ROUTING_KEY, routing_key=conf.CELERY_DEFAULT_ROUTING_KEY,
serializer=conf.CELERY_TASK_SERIALIZER, serializer=conf.CELERY_TASK_SERIALIZER,
compression=conf.CELERY_MESSAGE_COMPRESSION, compression=conf.CELERY_MESSAGE_COMPRESSION,
Expand Down
3 changes: 2 additions & 1 deletion celery/app/builtins.py
Expand Up @@ -121,7 +121,8 @@ def run(self, tasks, result, setid):
[subtask(task).apply(taskset_id=setid) [subtask(task).apply(taskset_id=setid)
for task in tasks]) for task in tasks])
with app.default_producer() as pub: with app.default_producer() as pub:
[subtask(task).apply_async(taskset_id=setid, publisher=pub) [subtask(task).apply_async(taskset_id=setid, publisher=pub,
add_to_parent=False)
for task in tasks] for task in tasks]
parent = get_current_worker_task() parent = get_current_worker_task()
if parent: if parent:
Expand Down
14 changes: 10 additions & 4 deletions celery/app/task.py
Expand Up @@ -376,7 +376,8 @@ def delay(self, *args, **kwargs):


def apply_async(self, args=None, kwargs=None, def apply_async(self, args=None, kwargs=None,
task_id=None, producer=None, connection=None, router=None, task_id=None, producer=None, connection=None, router=None,
link=None, link_error=None, publisher=None, **options): link=None, link_error=None, publisher=None, add_to_parent=True,
**options):
"""Apply tasks asynchronously by sending a message. """Apply tasks asynchronously by sending a message.
:keyword args: The positional arguments to pass on to the :keyword args: The positional arguments to pass on to the
Expand Down Expand Up @@ -459,6 +460,10 @@ def apply_async(self, args=None, kwargs=None,
if an error occurs while executing the task. if an error occurs while executing the task.
:keyword producer: :class:~@amqp.TaskProducer` instance to use. :keyword producer: :class:~@amqp.TaskProducer` instance to use.
:keyword add_to_parent: If set to True (default) and the task
is applied while executing another task, then the result
will be appended to the parent tasks ``request.children``
attribute.
:keyword publisher: Deprecated alias to ``producer``. :keyword publisher: Deprecated alias to ``producer``.
.. note:: .. note::
Expand Down Expand Up @@ -495,9 +500,10 @@ def apply_async(self, args=None, kwargs=None,
errbacks=maybe_list(link_error), errbacks=maybe_list(link_error),
**options) **options)
result = self.AsyncResult(task_id) result = self.AsyncResult(task_id)
parent = get_current_worker_task() if add_to_parent:
if parent: parent = get_current_worker_task()
parent.request.children.append(result) if parent:
parent.request.children.append(result)
return result return result


def retry(self, args=None, kwargs=None, exc=None, throw=True, def retry(self, args=None, kwargs=None, exc=None, throw=True,
Expand Down
2 changes: 1 addition & 1 deletion celery/bin/celeryd.py
Expand Up @@ -153,7 +153,7 @@ def run(self, *args, **kwargs):
if loglevel: if loglevel:
try: try:
kwargs["loglevel"] = mlevel(loglevel) kwargs["loglevel"] = mlevel(loglevel)
except KeyError: except KeyError: # pragma: no cover
self.die("Unknown level %r. Please use one of %s." % ( self.die("Unknown level %r. Please use one of %s." % (
loglevel, "|".join(l for l in LOG_LEVELS.keys() loglevel, "|".join(l for l in LOG_LEVELS.keys()
if isinstance(l, basestring)))) if isinstance(l, basestring))))
Expand Down
1 change: 1 addition & 0 deletions celery/tests/__init__.py
Expand Up @@ -16,6 +16,7 @@
os.environ["EVENTLET_NOPATCH"] = "yes" os.environ["EVENTLET_NOPATCH"] = "yes"
os.environ["GEVENT_NOPATCH"] = "yes" os.environ["GEVENT_NOPATCH"] = "yes"
os.environ["KOMBU_DISABLE_LIMIT_PROTECTION"] = "yes" os.environ["KOMBU_DISABLE_LIMIT_PROTECTION"] = "yes"
os.environ["CELERY_BROKER_URL"] = "memory://"


try: try:
WindowsError = WindowsError # noqa WindowsError = WindowsError # noqa
Expand Down
38 changes: 37 additions & 1 deletion celery/tests/app/test_amqp.py
@@ -1,9 +1,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement from __future__ import with_statement


from kombu import Exchange, Queue
from mock import Mock from mock import Mock


from celery.app.amqp import Queues from celery.app.amqp import Queues, TaskPublisher
from celery.tests.utils import AppCase from celery.tests.utils import AppCase




Expand Down Expand Up @@ -37,6 +38,22 @@ def test_publish_no_retry(self):
self.assertFalse(pub.connection.ensure.call_count) self.assertFalse(pub.connection.ensure.call_count)




class test_compat_TaskPublisher(AppCase):

def test_compat_exchange_is_string(self):
producer = TaskPublisher(exchange="foo", app=self.app)
self.assertIsInstance(producer.exchange, Exchange)
self.assertEqual(producer.exchange.name, "foo")
self.assertEqual(producer.exchange.type, "direct")
producer = TaskPublisher(exchange="foo", exchange_type="topic",
app=self.app)
self.assertEqual(producer.exchange.type, "topic")

def test_compat_exchange_is_Exchange(self):
producer = TaskPublisher(exchange=Exchange("foo"))
self.assertEqual(producer.exchange.name, "foo")


class test_PublisherPool(AppCase): class test_PublisherPool(AppCase):


def test_setup_nolimit(self): def test_setup_nolimit(self):
Expand Down Expand Up @@ -100,3 +117,22 @@ def test_queues_format(self):


def test_with_defaults(self): def test_with_defaults(self):
self.assertEqual(Queues(None), {}) self.assertEqual(Queues(None), {})

def test_add(self):
q = Queues()
q.add("foo", exchange="ex", routing_key="rk")
self.assertIn("foo", q)
self.assertIsInstance(q["foo"], Queue)
self.assertEqual(q["foo"].routing_key, "rk")

def test_add_default_exchange(self):
ex = Exchange("fff", "fanout")
q = Queues(default_exchange=ex)
q.add(Queue("foo"))
self.assertEqual(q["foo"].exchange, ex)

def test_alias(self):
q = Queues()
q.add(Queue("foo", alias="barfoo"))
self.assertIs(q["barfoo"], q["foo"])

0 comments on commit 8af6e58

Please sign in to comment.