
More tests

1 parent 313b71f commit 8af6e5859891ddbcab744ec60895432dc697729c @ask committed Jun 4, 2012
@@ -22,6 +22,7 @@
from celery.utils import cached_property, uuid
from celery.utils.text import indent as textindent
+from . import app_or_default
from . import routes as _routes
#: Human readable queue declaration.
@@ -48,7 +49,7 @@ class Queues(dict):
#: The rest of the queues are then used for routing only.
_consume_from = None
- def __init__(self, queues, default_exchange=None, create_missing=True):
+ def __init__(self, queues=None, default_exchange=None, create_missing=True):
dict.__init__(self)
self.aliases = WeakValueDictionary()
self.default_exchange = default_exchange
@@ -65,11 +66,9 @@ def __getitem__(self, name):
return dict.__getitem__(self, name)
def __setitem__(self, name, queue):
- if self.default_exchange:
- if not queue.exchange or not queue.exchange.name:
- queue.exchange = self.default_exchange
- if queue.exchange.type == 'direct' and not queue.routing_key:
- queue.routing_key = name
+ if self.default_exchange and (not queue.exchange or
+ not queue.exchange.name):
+ queue.exchange = self.default_exchange
dict.__setitem__(self, name, queue)
if queue.alias:
self.aliases[queue.alias] = queue
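
A minimal sketch of the behaviour this hunk settles on, based on the tests added further down in this commit: Queues can now be built without an initial mapping, a queue added without an exchange inherits default_exchange, and the old automatic routing-key assignment for direct exchanges is gone. (The exchange and queue names below are illustrative.)

    from kombu import Exchange, Queue
    from celery.app.amqp import Queues

    default = Exchange("default", type="direct")
    queues = Queues(default_exchange=default)    # the queues argument is now optional

    queues.add(Queue("images"))                  # declared without an exchange
    assert queues["images"].exchange is default  # inherits the default exchange
    assert not queues["images"].routing_key      # routing key is no longer filled in
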
@@ -135,19 +134,16 @@ def consume_from(self):
class TaskProducer(Producer):
+ app = None
auto_declare = False
retry = False
retry_policy = None
def __init__(self, channel=None, exchange=None, *args, **kwargs):
- self.app = kwargs.get("app") or self.app
self.retry = kwargs.pop("retry", self.retry)
self.retry_policy = kwargs.pop("retry_policy",
self.retry_policy or {})
exchange = exchange or self.exchange
- if not isinstance(exchange, Exchange):
- exchange = Exchange(exchange,
- kwargs.get("exchange_type") or self.exchange_type)
self.queues = self.app.amqp.queues # shortcut
super(TaskProducer, self).__init__(channel, exchange, *args, **kwargs)
@@ -216,7 +212,21 @@ def delay_task(self, task_name, task_args=None, task_kwargs=None,
expires=expires,
queue=queue)
return task_id
-TaskPublisher = TaskProducer # compat
+
+class TaskPublisher(TaskProducer):
+ """Deprecated version of :class:`TaskProducer`."""
+
+ def __init__(self, channel=None, exchange=None, *args, **kwargs):
+ self.app = app_or_default(kwargs.pop("app", self.app))
+ self.retry = kwargs.pop("retry", self.retry)
+ self.retry_policy = kwargs.pop("retry_policy",
+ self.retry_policy or {})
+ exchange = exchange or self.exchange
+ if not isinstance(exchange, Exchange):
+ exchange = Exchange(exchange,
+ kwargs.pop("exchange_type", "direct"))
+ self.queues = self.app.amqp.queues # shortcut
+ super(TaskPublisher, self).__init__(channel, exchange, *args, **kwargs)
class TaskConsumer(Consumer):
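
A hedged sketch of the split this hunk makes: TaskProducer now expects its app and exchange to be bound up front (the app kwarg handling and string-to-Exchange coercion were removed), while the deprecated TaskPublisher keeps the old forgiving constructor for existing callers. This mirrors the compat tests added below; the exchange name is illustrative.

    from kombu import Exchange
    from celery import Celery
    from celery.app.amqp import TaskPublisher

    app = Celery(broker="memory://")

    # Old-style call sites keep working through the compat subclass:
    # a plain string exchange (plus an optional exchange_type) is still accepted.
    pub = TaskPublisher(exchange="foo", exchange_type="topic", app=app)
    assert isinstance(pub.exchange, Exchange)
    assert pub.exchange.type == "topic"
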
@@ -267,11 +277,6 @@ def TaskConsumer(self):
reverse="amqp.TaskConsumer")
get_task_consumer = TaskConsumer # XXX compat
- def queue_or_default(self, q):
- if q:
- return self.queues[q] if not isinstance(q, Queue) else q
- return self.default_queue
-
@cached_property
def TaskProducer(self):
"""Returns publisher used to send tasks.
@@ -283,7 +288,6 @@ def TaskProducer(self):
return self.app.subclass_with_self(TaskProducer,
reverse="amqp.TaskProducer",
exchange=self.default_exchange,
- exchange_type=self.default_exchange.type,
routing_key=conf.CELERY_DEFAULT_ROUTING_KEY,
serializer=conf.CELERY_TASK_SERIALIZER,
compression=conf.CELERY_MESSAGE_COMPRESSION,
@@ -121,7 +121,8 @@ def run(self, tasks, result, setid):
[subtask(task).apply(taskset_id=setid)
for task in tasks])
with app.default_producer() as pub:
- [subtask(task).apply_async(taskset_id=setid, publisher=pub)
+ [subtask(task).apply_async(taskset_id=setid, publisher=pub,
+ add_to_parent=False)
for task in tasks]
parent = get_current_worker_task()
if parent:
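
The add_to_parent=False here keeps a set applied from inside another task from attaching every member result to the caller; only the set's own result is recorded. A rough sketch of the intended effect, using illustrative task names:

    from celery import Celery, group

    app = Celery(broker="memory://")

    @app.task
    def add(x, y):
        return x + y

    @app.task
    def fan_out():
        # The member tasks are published with add_to_parent=False internally,
        # so when this runs in a worker, fan_out's request.children holds the
        # set result rather than one entry per member.
        return group(add.s(i, i) for i in range(10)).apply_async()
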
@@ -376,7 +376,8 @@ def delay(self, *args, **kwargs):
def apply_async(self, args=None, kwargs=None,
task_id=None, producer=None, connection=None, router=None,
- link=None, link_error=None, publisher=None, **options):
+ link=None, link_error=None, publisher=None, add_to_parent=True,
+ **options):
"""Apply tasks asynchronously by sending a message.
:keyword args: The positional arguments to pass on to the
@@ -459,6 +460,10 @@ def apply_async(self, args=None, kwargs=None,
if an error occurs while executing the task.
:keyword producer: :class:`~@amqp.TaskProducer` instance to use.
+ :keyword add_to_parent: If set to True (default) and the task
+ is applied while executing another task, then the result
+ will be appended to the parent task's ``request.children``
+ attribute.
:keyword publisher: Deprecated alias to ``producer``.
.. note::
@@ -495,9 +500,10 @@ def apply_async(self, args=None, kwargs=None,
errbacks=maybe_list(link_error),
**options)
result = self.AsyncResult(task_id)
- parent = get_current_worker_task()
- if parent:
- parent.request.children.append(result)
+ if add_to_parent:
+ parent = get_current_worker_task()
+ if parent:
+ parent.request.children.append(result)
return result
def retry(self, args=None, kwargs=None, exc=None, throw=True,
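
For the new keyword itself, a small sketch under the same assumptions (hypothetical task names): passing add_to_parent=False opts a single call out of being tracked on the calling task's request.children.

    from celery import Celery

    app = Celery(broker="memory://")

    @app.task
    def cleanup(path):
        pass

    @app.task
    def process(path):
        # Fire-and-forget: when process runs inside a worker, this child
        # result is not appended to process's request.children.
        cleanup.apply_async((path,), add_to_parent=False)
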
@@ -153,7 +153,7 @@ def run(self, *args, **kwargs):
if loglevel:
try:
kwargs["loglevel"] = mlevel(loglevel)
- except KeyError:
+ except KeyError: # pragma: no cover
self.die("Unknown level %r. Please use one of %s." % (
loglevel, "|".join(l for l in LOG_LEVELS.keys()
if isinstance(l, basestring))))
@@ -16,6 +16,7 @@
os.environ["EVENTLET_NOPATCH"] = "yes"
os.environ["GEVENT_NOPATCH"] = "yes"
os.environ["KOMBU_DISABLE_LIMIT_PROTECTION"] = "yes"
+os.environ["CELERY_BROKER_URL"] = "memory://"
try:
WindowsError = WindowsError # noqa
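
Setting CELERY_BROKER_URL to the in-memory transport before anything imports Celery means the test suite never needs a running broker. A minimal sketch of the same trick, assuming the environment variable is consulted when the default configuration is built:

    import os
    os.environ.setdefault("CELERY_BROKER_URL", "memory://")

    from celery import Celery

    app = Celery()
    # With no broker configured explicitly, the URL from the environment wins,
    # so connections go to kombu's in-memory transport.
    print(app.conf.BROKER_URL)   # memory://
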
@@ -1,9 +1,10 @@
from __future__ import absolute_import
from __future__ import with_statement
+from kombu import Exchange, Queue
from mock import Mock
-from celery.app.amqp import Queues
+from celery.app.amqp import Queues, TaskPublisher
from celery.tests.utils import AppCase
@@ -37,6 +38,22 @@ def test_publish_no_retry(self):
self.assertFalse(pub.connection.ensure.call_count)
+class test_compat_TaskPublisher(AppCase):
+
+ def test_compat_exchange_is_string(self):
+ producer = TaskPublisher(exchange="foo", app=self.app)
+ self.assertIsInstance(producer.exchange, Exchange)
+ self.assertEqual(producer.exchange.name, "foo")
+ self.assertEqual(producer.exchange.type, "direct")
+ producer = TaskPublisher(exchange="foo", exchange_type="topic",
+ app=self.app)
+ self.assertEqual(producer.exchange.type, "topic")
+
+ def test_compat_exchange_is_Exchange(self):
+ producer = TaskPublisher(exchange=Exchange("foo"))
+ self.assertEqual(producer.exchange.name, "foo")
+
+
class test_PublisherPool(AppCase):
def test_setup_nolimit(self):
@@ -100,3 +117,22 @@ def test_queues_format(self):
def test_with_defaults(self):
self.assertEqual(Queues(None), {})
+
+ def test_add(self):
+ q = Queues()
+ q.add("foo", exchange="ex", routing_key="rk")
+ self.assertIn("foo", q)
+ self.assertIsInstance(q["foo"], Queue)
+ self.assertEqual(q["foo"].routing_key, "rk")
+
+ def test_add_default_exchange(self):
+ ex = Exchange("fff", "fanout")
+ q = Queues(default_exchange=ex)
+ q.add(Queue("foo"))
+ self.assertEqual(q["foo"].exchange, ex)
+
+ def test_alias(self):
+ q = Queues()
+ q.add(Queue("foo", alias="barfoo"))
+ self.assertIs(q["barfoo"], q["foo"])
+