Permalink
Browse files

100% coverage for celery.worker.hub

  • Loading branch information...
1 parent 0ac5bbd commit a686b662f93ddd293fc2e76deb2a6431f65d1b33 @ask committed Jun 6, 2012
Showing with 162 additions and 127 deletions.
  1. +2 −9 celery/app/amqp.py
  2. +6 −7 celery/app/task.py
  3. +4 −28 celery/task/trace.py
  4. +34 −16 celery/tests/worker/test_request.py
  5. +113 −30 celery/worker/hub.py
  6. +3 −37 celery/worker/job.py
View
@@ -30,10 +30,6 @@
. %(name)s exchange:%(exchange)s(%(exchange_type)s) binding:%(routing_key)s
"""
-TASK_BARE = 0x004
-TASK_DEFAULT = 0
-
-
class Queues(dict):
"""Queue name⇒ declaration mapping.
@@ -154,7 +150,7 @@ def delay_task(self, task_name, task_args=None, task_kwargs=None,
queue=None, now=None, retries=0, chord=None, callbacks=None,
errbacks=None, mandatory=None, priority=None, immediate=None,
routing_key=None, serializer=None, delivery_mode=None,
- compression=None, bare=False, **kwargs):
+ compression=None, **kwargs):
"""Send task message."""
# merge default and custom policy
_rp = (dict(self.retry_policy, **retry_policy) if retry_policy
@@ -174,8 +170,6 @@ def delay_task(self, task_name, task_args=None, task_kwargs=None,
expires = now + timedelta(seconds=expires)
eta = eta and eta.isoformat()
expires = expires and expires.isoformat()
- flags = TASK_DEFAULT
- flags |= TASK_BARE if bare else 0
body = {"task": task_name,
"id": task_id,
@@ -186,8 +180,7 @@ def delay_task(self, task_name, task_args=None, task_kwargs=None,
"expires": expires,
"utc": self.utc,
"callbacks": callbacks,
- "errbacks": errbacks,
- "flags": flags}
+ "errbacks": errbacks}
if taskset_id:
body["taskset"] = taskset_id
if chord:
View
@@ -34,13 +34,12 @@
from .annotations import resolve_all as resolve_all_annotations
from .registry import _unpickle_task
-#: extracts options related to publishing a message from a dict.
-extract_exec_options = mattrgetter("queue", "routing_key",
- "exchange", "immediate",
- "mandatory", "priority",
- "serializer", "delivery_mode",
- "compression", "expires", "bare")
-
+#: extracts attributes related to publishing a message from an object.
+extract_exec_options = mattrgetter(
+ "queue", "routing_key", "exchange",
+ "immediate", "mandatory", "priority", "expires",
+ "serializer", "delivery_mode", "compression",
+)
#: Billiard sets this when execv is enabled.
#: We use it to find out the name of the original ``__main__``
View
@@ -130,30 +130,6 @@ def handle_failure(self, task, store_errors=True):
del(tb)
-def execute_bare(task, uuid, args, kwargs, request=None, Info=TraceInfo):
- R = I = None
- kwargs = kwdict(kwargs)
- try:
- try:
- R = retval = task(*args, **kwargs)
- state = SUCCESS
- except Exception, exc:
- I = Info(FAILURE, exc)
- state, retval = I.state, I.retval
- R = I.handle_error_state(task)
- except BaseException, exc:
- raise
- except: # pragma: no cover
- # For Python2.5 where raising strings are still allowed
- # (but deprecated)
- I = Info(FAILURE, None)
- state, retval = I.state, I.retval
- R = I.handle_error_state(task)
- except Exception, exc:
- R = report_internal_error(task, exc)
- return R
-
-
def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
Info=TraceInfo, eager=False, propagate=False):
# If the task doesn't define a custom __call__ method
@@ -282,16 +258,16 @@ def trace_task(uuid, args, kwargs, request=None):
return trace_task
-def trace_task(task, uuid, args, kwargs, request=None, **opts):
+def trace_task(task, uuid, args, kwargs, request={}, **opts):
try:
if task.__trace__ is None:
task.__trace__ = build_tracer(task.name, task, **opts)
- return task.__trace__(uuid, args, kwargs, request)
+ return task.__trace__(uuid, args, kwargs, request)[0]
except Exception, exc:
- return report_internal_error(task, exc), None
+ return report_internal_error(task, exc)
-def trace_task_ret(task, uuid, args, kwargs, request):
+def trace_task_ret(task, uuid, args, kwargs, request={}):
return _tasks[task].__trace__(uuid, args, kwargs, request)[0]
@@ -21,18 +21,23 @@
from celery.datastructures import ExceptionInfo
from celery.exceptions import (RetryTaskError,
WorkerLostError, InvalidTaskError)
-from celery.task.trace import eager_trace_task, TraceInfo, mro_lookup
+from celery.task.trace import (
+ trace_task,
+ trace_task_ret,
+ TraceInfo,
+ mro_lookup,
+ build_tracer,
+)
from celery.result import AsyncResult
from celery.task import task as task_dec
from celery.task.base import Task
from celery.utils import uuid
from celery.worker import job as module
-from celery.worker.job import Request, TaskRequest, execute_and_trace
+from celery.worker.job import Request, TaskRequest
from celery.worker.state import revoked
from celery.tests.utils import Case
-
scratch = {"ACK": False}
some_kwargs_scratchpad = {}
@@ -68,8 +73,10 @@ def mro(cls):
def jail(task_id, name, args, kwargs):
request = {"id": task_id}
- return eager_trace_task(current_app.tasks[name],
- task_id, args, kwargs, request=request, eager=False)[0]
+ task = current_app.tasks[name]
+ task.__trace__ = None # rebuild
+ return trace_task(task,
+ task_id, args, kwargs, request=request, eager=False)
def on_ack(*args, **kwargs):
@@ -221,6 +228,7 @@ def send(self, event, **fields):
class test_TaskRequest(Case):
+
def test_task_wrapper_repr(self):
tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
self.assertTrue(repr(tw))
@@ -262,8 +270,11 @@ def test_on_retry(self):
einfo = ExceptionInfo()
tw.on_failure(einfo)
self.assertIn("task-retried", tw.eventer.sent)
- tw._does_info = False
- tw.on_failure(einfo)
+ prev, module._does_info = module._does_info, False
+ try:
+ tw.on_failure(einfo)
+ finally:
+ module._does_info = prev
einfo.internal = True
tw.on_failure(einfo)
@@ -408,8 +419,11 @@ def test_on_accepted_acks_early(self):
tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
self.assertTrue(tw.acknowledged)
- tw._does_debug = False
- tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
+ prev, module._does_debug = module._does_debug, False
+ try:
+ tw.on_accepted(pid=os.getpid(), time_accepted=time.time())
+ finally:
+ module._does_debug = prev
def test_on_accepted_acks_late(self):
tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -432,9 +446,12 @@ def test_on_success_acks_early(self):
tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
tw.time_start = 1
tw.on_success(42)
- tw._does_info = False
- tw.on_success(42)
- self.assertFalse(tw.acknowledged)
+ prev, module._does_info = module._does_info, False
+ try:
+ tw.on_success(42)
+ self.assertFalse(tw.acknowledged)
+ finally:
+ module._does_info = prev
def test_on_success_BaseException(self):
tw = TaskRequest(mytask.name, uuid(), [1], {"f": "x"})
@@ -539,8 +556,10 @@ def test_on_timeout(self, warn, error):
finally:
mytask.ignore_result = False
- def test_execute_and_trace(self):
- res = execute_and_trace(mytask.name, uuid(), [4], {})
+ def test_trace_task_ret(self):
+ mytask.__trace__ = build_tracer(mytask.name, mytask,
+ current_app.loader, "test")
+ res = trace_task_ret(mytask.name, uuid(), [4], {})
self.assertEqual(res, 4 ** 4)
def test_execute_safe_catches_exception(self):
@@ -554,8 +573,7 @@ def raising():
with self.assertWarnsRegex(RuntimeWarning,
r'Exception raised outside'):
- res = execute_and_trace(raising.name, uuid(),
- [], {})
+ res = trace_task(raising, uuid(), [], {})
self.assertIsInstance(res, ExceptionInfo)
def test_worker_task_trace_handle_retry(self):
Oops, something went wrong.

0 comments on commit a686b66

Please sign in to comment.