Optimizations

Commit c9b40117add43912e5a8a984c2ddf525d69e127a (1 parent: 354af2b), committed by @ask on May 31, 2012

celery/app/task.py (72 changed lines)
@@ -22,9 +22,8 @@
from celery import states
from celery.__compat__ import class_property
from celery.state import get_current_task
-from celery.datastructures import ExceptionInfo
+from celery.datastructures import AttributeDict, ExceptionInfo
from celery.exceptions import MaxRetriesExceededError, RetryTaskError
-from celery.local import LocalStack
from celery.result import EagerResult
from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
from celery.utils.functional import mattrgetter, maybe_list
@@ -44,6 +43,10 @@
class Context(object):
+ __slots__ = ("logfile", "loglevel", "hostname",
+ "id", "args", "kwargs", "retries", "is_eager",
+ "delivery_info", "taskset", "chord", "called_directly",
+ "callbacks", "errbacks", "children", "__dict__")
# Default context
logfile = None
loglevel = None
@@ -59,32 +62,24 @@ class Context(object):
called_directly = True
callbacks = None
errbacks = None
- _children = None # see property
+ children = None # see property
def __init__(self, *args, **kwargs):
- self.update(*args, **kwargs)
+ up = self.update = self.__dict__.update
+ self.clear = self.__dict__.clear
+ self.get = self.__dict__.get
+ up(*args, children=[], **kwargs)
- def update(self, *args, **kwargs):
- self.__dict__.update(*args, **kwargs)
-
- def clear(self):
- self.__dict__.clear()
+ def __repr__(self):
+ return repr(self._vars())
- def get(self, key, default=None):
- try:
- return getattr(self, key)
- except AttributeError:
- return default
+ def __reduce__(self):
+ return self.__class__, (self._vars(), )
- def __repr__(self):
- return "<Context: %r>" % (vars(self, ))
+ def _vars(self):
+ return dict((k, v) for k, v in vars(self).iteritems()
+ if k not in ("clear", "get", "update"))
- @property
- def children(self):
- # children must be an empty list for every thread
- if self._children is None:
- self._children = []
- return self._children
class TaskType(type):
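
The Context rewrite above replaces three Python-level methods with bound dict built-ins captured at construction time: update, clear and get now dispatch straight to the C implementations, skipping a Python stack frame per call. Since the cached callables live in the instance __dict__ themselves, _vars() filters them back out for repr() and pickling. A minimal sketch of the same trick, using a hypothetical Record class:

    class Record(object):
        """Attribute bag whose update/get/clear are dict built-ins."""

        def __init__(self, *args, **kwargs):
            d = self.__dict__
            # Cache bound dict methods on the instance: calls go
            # straight to the C implementation, with no Python frame.
            self.update = d.update
            self.clear = d.clear
            self.get = d.get
            d.update(*args, **kwargs)

        def _vars(self):
            # Exclude the cached methods when exposing state.
            return dict((k, v) for k, v in vars(self).iteritems()
                        if k not in ("clear", "get", "update"))

    r = Record(id="a3f1", retries=0)
    r.update(retries=1)
    assert r.get("retries") == 1
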
@@ -176,7 +171,7 @@ class BaseTask(object):
#: If disabled the worker will not forward magic keyword arguments.
#: Deprecated and scheduled for removal in v3.0.
- accept_magic_kwargs = False
+ accept_magic_kwargs = None
#: Destination queue. The queue needs to exist
#: in :setting:`CELERY_QUEUES`. The `routing_key`, `exchange` and
@@ -317,7 +312,6 @@ def bind(self, app):
for attr_name, config_name in self.from_config:
if getattr(self, attr_name, None) is None:
setattr(self, attr_name, conf[config_name])
- self.accept_magic_kwargs = app.accept_magic_kwargs
if self.accept_magic_kwargs is None:
self.accept_magic_kwargs = app.accept_magic_kwargs
if self.backend is None:
@@ -327,6 +321,7 @@ def bind(self, app):
if not was_bound:
self.annotate()
+ from celery.utils.threads import LocalStack
self.request_stack = LocalStack()
self.request_stack.push(Context())
@@ -696,12 +691,16 @@ def apply(self, args=None, kwargs=None, **options):
# Make sure we get the task instance, not class.
task = app._tasks[self.name]
- request = {"id": task_id,
- "retries": retries,
- "is_eager": True,
- "logfile": options.get("logfile"),
- "loglevel": options.get("loglevel", 0),
- "delivery_info": {"is_eager": True}}
+ request = AttributeDict({"id": task_id,
+ "retries": retries,
+ "is_eager": True,
+ "logfile": options.get("logfile"),
+ "loglevel": options.get("loglevel", 0),
+ "delivery_info": {"is_eager": True},
+ "callbacks": [],
+ "errbacks": [],
+ "chord": None,
+ "called_directly": True})
if self.accept_magic_kwargs:
default_kwargs = {"task_name": task.name,
"task_id": task_id,
@@ -852,19 +851,6 @@ def send_error_email(self, context, exc, **kwargs):
if self.send_error_emails and not self.disable_error_emails:
self.ErrorMail(self, **kwargs).send(context, exc)
- def execute(self, request, pool, loglevel, logfile, **kwargs):
- """The method the worker calls to execute the task.
-
- :param request: A :class:`~celery.worker.job.Request`.
- :param pool: A task pool.
- :param loglevel: Current loglevel.
- :param logfile: Name of the currently used logfile.
-
- :keyword consumer: The :class:`~celery.worker.consumer.Consumer`.
-
- """
- request.execute_using_pool(pool, loglevel, logfile)
-
def push_request(self, *args, **kwargs):
self.request_stack.push(Context(*args, **kwargs))
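
Also of note above: accept_magic_kwargs now defaults to None rather than False, a sentinel meaning "not configured on this task", so bind() inherits the app's setting only when the task did not set one explicitly, and writes the attribute once instead of on every bind. A sketch of the sentinel pattern (Task and bind here are schematic, not the real classes):

    class Task(object):
        accept_magic_kwargs = None    # None means "not configured here"

    def bind(task, app):
        # Inherit from the app only when the task made no choice.
        if task.accept_magic_kwargs is None:
            task.accept_magic_kwargs = app.accept_magic_kwargs
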

celery/apps/worker.py (9 changed lines)
@@ -19,9 +19,13 @@
from celery.utils.imports import qualname
from celery.utils.log import LOG_LEVELS, get_logger, mlevel, set_in_sighandler
from celery.utils.text import pluralize
-from celery.utils.threads import active_count as active_thread_count
from celery.worker import WorkController
+def active_thread_count():
+ from threading import enumerate
+ return sum(1 for t in enumerate()
+ if not t.name.startswith("Dummy-"))
+
try:
from greenlet import GreenletExit
IGNORE_ERRORS = (GreenletExit, )
@@ -243,6 +247,9 @@ def _handle_request(signum, frame):
callback(worker)
safe_say("celeryd: %s shutdown (MainProcess)" % how)
if active_thread_count() > 1:
+ print("SET SHOULD STOP")
+ import threading
+ print("THREADS: %r" % (list(threading.enumerate(), )))
setattr(state, {"Warm": "should_stop",
"Cold": "should_terminate"}[how], True)
else:

celery/concurrency/base.py (12 changed lines)
@@ -12,11 +12,13 @@
logger = get_logger("celery.concurrency")
+_pid = None
+
def apply_target(target, args=(), kwargs={}, callback=None,
accept_callback=None, pid=None, **_):
if accept_callback:
- accept_callback(pid or os.getpid(), time.time())
+ accept_callback(pid)
callback(target(*args, **kwargs))
@@ -116,13 +118,7 @@ def apply_async(self, target, args=[], kwargs={}, **options):
otherwise the thread which handles the result will get blocked.
"""
- if self._does_debug:
- logger.debug("TaskPool: Apply %s (args:%s kwargs:%s)",
- target, safe_repr(args), safe_repr(kwargs))
-
- return self.on_apply(target, args, kwargs,
- waitforslot=self.putlocks,
- **options)
+ raise NotImplementedError("apply_async")
def _get_info(self):
return {}
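
The base class no longer implements dispatch itself: on_apply is gone and each pool overrides apply_async directly, so WorkController can call pool.apply_async with no intermediate hook. A schematic sketch of the contract, with a hypothetical InlinePool as the override:

    class BasePool(object):
        def apply_async(self, target, args=(), kwargs={}, **options):
            # Concrete pools must override this hot-path method.
            raise NotImplementedError("apply_async")

    class InlinePool(BasePool):
        def apply_async(self, target, args=(), kwargs={}, **options):
            return target(*args, **kwargs)   # run in the calling thread

    assert InlinePool().apply_async(sum, ([1, 2, 3], )) == 6
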

celery/concurrency/eventlet.py (17 changed lines)
@@ -17,8 +17,9 @@
def apply_target(target, args=(), kwargs={}, callback=None,
- accept_callback=None, getpid=None):
- return base.apply_target(target, args, kwargs, callback, accept_callback,
+ accept_callback=None, getpid=None,
+ apply=base.apply_target):
+ return apply(target, args, kwargs, callback, accept_callback,
pid=getpid())
@@ -105,6 +106,8 @@ def __init__(self, *args, **kwargs):
def on_start(self):
self._pool = self.Pool(self.limit)
+ self._quick_put = self._pool.spawn_n
+ self._quick_send_applysig = signals.eventlet_pool_apply.send
signals.eventlet_pool_started.send(sender=self)
def on_stop(self):
@@ -113,10 +116,10 @@ def on_stop(self):
self._pool.waitall()
signals.eventlet_pool_postshutdown.send(sender=self)
- def on_apply(self, target, args=None, kwargs=None, callback=None,
+ def apply_async(self, target, args=(), kwargs={}, callback=None,
accept_callback=None, **_):
- signals.eventlet_pool_apply.send(sender=self,
+ self._quick_send_applysig(sender=self,
target=target, args=args, kwargs=kwargs)
- self._pool.spawn_n(apply_target, target, args, kwargs,
- callback, accept_callback,
- self.getpid)
+ self._quick_put(apply_target, target, args, kwargs,
+ callback, accept_callback,
+ self.getpid)
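
The _quick_put and _quick_send_applysig assignments in on_start prebind the callables the hot path needs, so each apply_async does a single instance-dict lookup instead of an attribute chain like self._pool.spawn_n. A toy demonstration of the pattern (illustrative only; timings vary):

    from timeit import timeit

    class Box(object):
        def __init__(self):
            self.data = []
            # Prebind the hot-path callable once, the _quick_* pattern.
            self._quick_put = self.data.append

    b = Box()
    slow = timeit(lambda: b.data.append(0))  # attribute lookup + bind
    fast = timeit(lambda: b._quick_put(0))   # single dict lookup
    print("%.3f vs %.3f" % (slow, fast))     # fast is typically smaller
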

celery/concurrency/gevent.py (7 changed lines)
@@ -87,15 +87,16 @@ def __init__(self, *args, **kwargs):
def on_start(self):
self._pool = self.Pool(self.limit)
+ self._quick_spawn = self._pool.spawn
def on_stop(self):
if self._pool is not None:
self._pool.join()
- def on_apply(self, target, args=None, kwargs=None, callback=None,
+ def apply_async(self, target, args=(), kwargs={}, callback=None,
accept_callback=None, **_):
- return self._pool.spawn(apply_target, target, args, kwargs,
- callback, accept_callback)
+ return self._quick_spawn(apply_target, target, args, kwargs,
+ callback, accept_callback)
def grow(self, n=1):
self._pool._semaphore.counter += n

celery/concurrency/processes/__init__.py (6 changed lines)
@@ -48,6 +48,10 @@ def process_initializer(app, hostname):
app.loader.init_worker()
app.loader.init_worker_process()
app.finalize()
+ from celery.task.trace import build_tracer
+ for name, task in app.tasks.iteritems():
+ task.__tracer__ = build_tracer(name, task, app.loader, hostname)
+
signals.worker_process_init.send(sender=None)
@@ -67,7 +71,7 @@ def on_start(self):
self._pool = self.Pool(processes=self.limit,
initializer=process_initializer,
**self.options)
- self.on_apply = self._pool.apply_async
+ self.apply_async = self._pool.apply_async
def did_start_ok(self):
return self._pool.did_start_ok()
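
process_initializer now builds one tracer closure per registered task as soon as a pool process starts, so handling a message later is a plain call on task.__tracer__ rather than a rebuild. A self-contained sketch of the precomputation idea, with schematic Task/build_tracer stand-ins:

    class Task(object):
        __tracer__ = None

    def build_tracer(name, task):
        prefix = "%s: " % name            # computed once per process

        def trace(arg):
            return prefix + repr(arg)
        return trace

    # At worker-process start-up, as in process_initializer above:
    tasks = {"add": Task()}
    for name, task in tasks.items():
        task.__tracer__ = build_tracer(name, task)

    assert tasks["add"].__tracer__(4) == "add: 4"
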

celery/concurrency/solo.py (2 changed lines)
@@ -11,7 +11,7 @@ class TaskPool(BasePool):
def __init__(self, *args, **kwargs):
super(TaskPool, self).__init__(*args, **kwargs)
- self.on_apply = apply_target
+ self.apply_async = apply_target
def _get_info(self):
return {"max-concurrency": 1,

celery/concurrency/threads.py (8 changed lines)
@@ -29,18 +29,20 @@ def on_start(self):
# threadpool stores all work requests until they are processed
# we don't need this dict, and it occupies way too much memory.
self._pool.workRequests = NullDict()
+ self._quick_put = self._pool.putRequest
+ self._quick_clear = self._pool._results_queue.queue.clear
def on_stop(self):
self._pool.dismissWorkers(self.limit, do_join=True)
- def on_apply(self, target, args=None, kwargs=None, callback=None,
+ def apply_async(self, target, args=None, kwargs=None, callback=None,
accept_callback=None, **_):
req = self.WorkRequest(apply_target, (target, args, kwargs, callback,
accept_callback))
- self._pool.putRequest(req)
+ self._quick_put(req)
# threadpool also has callback support,
# but for some reason the callback is not triggered
# before you've collected the results.
# Clear the results (if any), so it doesn't grow too large.
- self._pool._results_queue.queue.clear()
+ self._quick_clear()
return req

celery/datastructures.py (6 changed lines)
@@ -386,12 +386,13 @@ class LimitedSet(object):
:keyword expires: Time in seconds, before a membership expires.
"""
- __slots__ = ("maxlen", "expires", "_data")
+ __slots__ = ("maxlen", "expires", "_data", "__len__")
def __init__(self, maxlen=None, expires=None):
self.maxlen = maxlen
self.expires = expires
self._data = {}
+ self.__len__ = self._data.__len__
def add(self, value):
"""Add a new member."""
@@ -434,9 +435,6 @@ def as_dict(self):
def __iter__(self):
return iter(self._data.keys())
- def __len__(self):
- return len(self._data.keys())
-
def __repr__(self):
return "LimitedSet(%r)" % (self._data.keys(), )

celery/local.py (82 changed lines)
@@ -98,12 +98,12 @@ def __repr__(self):
return '<%s unbound>' % self.__class__.__name__
return repr(obj)
+
def __nonzero__(self):
try:
return bool(self._get_current_object())
except RuntimeError: # pragma: no cover
return False
-
def __unicode__(self):
try:
return unicode(self._get_current_object())
@@ -277,86 +277,6 @@ def __delattr__(self, name):
raise AttributeError(name)
-class LocalStack(object):
- """This class works similar to a :class:`Local` but keeps a stack
- of objects instead. This is best explained with an example::
-
- >>> ls = LocalStack()
- >>> ls.push(42)
- >>> ls.top
- 42
- >>> ls.push(23)
- >>> ls.top
- 23
- >>> ls.pop()
- 23
- >>> ls.top
- 42
-
- They can be force released by using a :class:`LocalManager` or with
- the :func:`release_local` function but the correct way is to pop the
- item from the stack after using. When the stack is empty it will
- no longer be bound to the current context (and as such released).
-
- By calling the stack without arguments it returns a proxy that resolves to
- the topmost item on the stack.
-
- """
-
- def __init__(self):
- self._local = Local()
-
- def __release_local__(self):
- self._local.__release_local__()
-
- def _get__ident_func__(self):
- return self._local.__ident_func__
-
- def _set__ident_func__(self, value):
- object.__setattr__(self._local, '__ident_func__', value)
- __ident_func__ = property(_get__ident_func__, _set__ident_func__)
- del _get__ident_func__, _set__ident_func__
-
- def __call__(self):
- def _lookup():
- rv = self.top
- if rv is None:
- raise RuntimeError('object unbound')
- return rv
- return Proxy(_lookup)
-
- def push(self, obj):
- """Pushes a new item to the stack"""
- rv = getattr(self._local, 'stack', None)
- if rv is None:
- self._local.stack = rv = []
- rv.append(obj)
- return rv
-
- def pop(self):
- """Removes the topmost item from the stack, will return the
- old value or `None` if the stack was already empty.
- """
- stack = getattr(self._local, 'stack', None)
- if stack is None:
- return None
- elif len(stack) == 1:
- release_local(self._local)
- return stack[-1]
- else:
- return stack.pop()
-
- @property
- def top(self):
- """The topmost item on the stack. If the stack is empty,
- `None` is returned.
- """
- try:
- return self._local.stack[-1]
- except (AttributeError, IndexError):
- return None
-
-
class LocalManager(object):
"""Local objects cannot manage themselves. For that you need a local
manager. You can pass a local manager multiple locals or add them later

celery/state.py (3 changed lines)
@@ -2,7 +2,8 @@
import threading
-from celery.local import Proxy, LocalStack
+from celery.local import Proxy
+from celery.utils.threads import LocalStack
default_app = None

celery/task/trace.py (196 changed lines)
@@ -21,7 +21,9 @@
import os
import socket
import sys
+import logging
+from time import time
from warnings import warn
from kombu.utils import kwdict
@@ -33,22 +35,24 @@
from celery.datastructures import ExceptionInfo
from celery.exceptions import RetryTaskError
from celery.utils.serialization import get_pickleable_exception
+from celery.utils.encoding import safe_repr
from celery.utils.log import get_logger
+from celery.utils.text import truncate
-_logger = get_logger(__name__)
-
-send_prerun = signals.task_prerun.send
-prerun_receivers = signals.task_prerun.receivers
-send_postrun = signals.task_postrun.send
-postrun_receivers = signals.task_postrun.receivers
-send_success = signals.task_success.send
-success_receivers = signals.task_success.receivers
STARTED = states.STARTED
SUCCESS = states.SUCCESS
RETRY = states.RETRY
FAILURE = states.FAILURE
EXCEPTION_STATES = states.EXCEPTION_STATES
+_logger = get_logger(__name__)
+info = _logger.info
+
+#: Format string used to log task success.
+success_msg = """\
+ Task %(name)s[%(id)s] succeeded in %(runtime)ss: %(return_value)s
+""".strip()
+
def mro_lookup(cls, attr, stop=()):
"""Returns the first node by MRO order that defines an attribute.
@@ -71,6 +75,12 @@ def task_has_custom(task, attr):
return mro_lookup(task.__class__, attr, stop=(BaseTask, object))
+def repr_result(result, maxlen=46):
+ # 46 is the length needed to fit
+ # "the quick brown fox jumps over the lazy dog" :)
+ return truncate(safe_repr(result), maxlen)
+
+
class TraceInfo(object):
__slots__ = ("state", "retval")
@@ -153,11 +163,24 @@ def execute_bare(task, uuid, args, kwargs, request=None, Info=TraceInfo):
def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
- Info=TraceInfo, eager=False, propagate=False):
+ Info=TraceInfo, eager=False, propagate=False, kwdict=kwdict, time=time):
# If the task doesn't define a custom __call__ method
# we optimize it away by simply calling the run method directly,
# saving the extra method call and a line less in the stack trace.
fun = task if task_has_custom(task, "__call__") else task.run
+ prerun_receivers = signals.task_prerun.receivers
+ send_prerun = signals.task_prerun.send if prerun_receivers else None
+
+ postrun_receivers = signals.task_postrun.receivers
+ send_postrun = signals.task_postrun.send if postrun_receivers else None
+ success_receivers = signals.task_success.receivers
+ send_success = signals.task_success.send if success_receivers else None
+ STARTED = states.STARTED
+ SUCCESS = states.SUCCESS
+ RETRY = states.RETRY
+ FAILURE = states.FAILURE
+ EXCEPTION_STATES = states.EXCEPTION_STATES
+ _does_info = _logger.isEnabledFor(logging.INFO)
loader = loader or current_app.loader
backend = task.backend
@@ -187,98 +210,88 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
pop_request = request_stack.pop
on_chord_part_return = backend.on_chord_part_return
+ push_current = _task_stack.push
+ pop_current = _task_stack.pop
+
from celery import canvas
subtask = canvas.subtask
- def trace_task(uuid, args, kwargs, request=None):
+ def trace_task(uuid, args, kwargs, request=None, kwdict=kwdict, time=time):
R = I = None
kwargs = kwdict(kwargs)
+ push_current(task)
+ push_request(request)
+ # -*- PRE -*-
+ if send_prerun:
+ send_prerun(sender=task, task_id=uuid, task=task,
+ args=args, kwargs=kwargs)
+ loader_task_init(uuid, task)
+ if track_started:
+ store_result(uuid, {"pid": pid,
+ "hostname": hostname}, STARTED)
+
+ # -*- TRACE -*-
try:
- _task_stack.push(task)
- task_request = Context(request or {}, args=args,
- called_directly=False, kwargs=kwargs)
- push_request(task_request)
- try:
- # -*- PRE -*-
- if prerun_receivers:
- send_prerun(sender=task, task_id=uuid, task=task,
- args=args, kwargs=kwargs)
- loader_task_init(uuid, task)
- if track_started:
- store_result(uuid, {"pid": pid,
- "hostname": hostname}, STARTED)
-
- # -*- TRACE -*-
- try:
- R = retval = fun(*args, **kwargs)
- state = SUCCESS
- except RetryTaskError, exc:
- I = Info(RETRY, exc)
- state, retval = I.state, I.retval
- R = I.handle_error_state(task, eager=eager)
- except Exception, exc:
- if propagate:
- raise
- I = Info(FAILURE, exc)
- state, retval = I.state, I.retval
- R = I.handle_error_state(task, eager=eager)
- [subtask(errback).apply_async((uuid, ))
- for errback in task_request.errbacks or []]
- except BaseException, exc:
- raise
- except: # pragma: no cover
- # For Python2.5 where raising strings are still allowed
- # (but deprecated)
- if propagate:
- raise
- I = Info(FAILURE, None)
- state, retval = I.state, I.retval
- R = I.handle_error_state(task, eager=eager)
- [subtask(errback).apply_async((uuid, ))
- for errback in task_request.errbacks or []]
- else:
- if publish_result:
- store_result(uuid, retval, SUCCESS)
- # callback tasks must be applied before the result is
- # stored, so that result.children is populated.
- [subtask(callback).apply_async((retval, ))
- for callback in task_request.callbacks or []]
- if task_on_success:
- task_on_success(retval, uuid, args, kwargs)
- if success_receivers:
- send_success(sender=task, result=retval)
-
- # -* POST *-
- if task_request.chord:
- on_chord_part_return(task)
- if task_after_return:
- task_after_return(state, retval, uuid, args, kwargs, None)
- if postrun_receivers:
- send_postrun(sender=task, task_id=uuid, task=task,
- args=args, kwargs=kwargs,
- retval=retval, state=state)
- finally:
- _task_stack.pop()
- pop_request()
- if not eager:
- try:
- backend_cleanup()
- loader_cleanup()
- except (KeyboardInterrupt, SystemExit, MemoryError):
- raise
- except Exception, exc:
- _logger.error("Process cleanup failed: %r", exc,
- exc_info=True)
+ time_start = time()
+ R = fun(*args, **kwargs)
+ state = SUCCESS
+ if publish_result:
+ store_result(uuid, R, SUCCESS)
+ # callback tasks must be applied before the result is
+ # stored, so that result.children is populated.
+ [subtask(callback).apply_async((R, ))
+ for callback in request.callbacks or ()]
+ if task_on_success:
+ task_on_success(R, uuid, args, kwargs)
+ if send_success:
+ send_success(sender=task, result=R)
+ if _does_info:
+ info(success_msg, {
+ "id": uuid, "name": name,
+ "return_value": repr_result(R),
+ "runtime": time() - time_start})
+ except RetryTaskError, exc:
+ I = Info(RETRY, exc)
+ state, R = I.state, I.retval
+ R = I.handle_error_state(task, eager=eager)
except Exception, exc:
- if eager:
+ if propagate:
raise
- R = report_internal_error(task, exc)
+ I = Info(FAILURE, exc)
+ state, R = I.state, I.retval
+ R = I.handle_error_state(task, eager=eager)
+ [subtask(errback).apply_async((uuid, ))
+ for errback in request.errbacks or []]
+ except BaseException, exc:
+ raise
+ finally:
+ # -* POST *-
+ if request.chord:
+ on_chord_part_return(task)
+ if task_after_return:
+ task_after_return(state, R, uuid, args, kwargs, None)
+ if send_postrun:
+ send_postrun(sender=task, task_id=uuid, task=task,
+ args=args, kwargs=kwargs,
+ retval=R, state=state)
+ pop_current()
+ pop_request()
+ if not eager:
+ try:
+ backend_cleanup()
+ loader_cleanup()
+ except (KeyboardInterrupt, SystemExit, MemoryError):
+ raise
+ except Exception, exc:
+ _logger.error("Process cleanup failed: %r", exc,
+ exc_info=True)
return R, I
return trace_task
-def trace_task(task, uuid, args, kwargs, request=None, **opts):
+def trace_task(task, uuid, args, kwargs, request=None,
+ build_tracer=build_tracer, **opts):
try:
if task.__tracer__ is None:
task.__tracer__ = build_tracer(task.name, task, **opts)
@@ -287,6 +300,13 @@ def trace_task(task, uuid, args, kwargs, request=None, **opts):
return report_internal_error(task, exc), None
+def trace_task_ret(task, uuid, args, kwargs, request):
+ try:
+ task.__tracer__(uuid, args, kwargs, request)
+ except Exception, exc:
+ return report_internal_error(task, exc)
+
+
def eager_trace_task(task, uuid, args, kwargs, request=None, **opts):
opts.setdefault("eager", True)
return build_tracer(task.name, task, **opts)(
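
The restructured build_tracer snapshots its whole environment when the tracer is created: signal .send methods are captured only if receivers exist at build time, the state constants and the logging.INFO check become closure locals, and kwdict/time ride along as default arguments, so the per-task call avoids global and attribute lookups entirely. The trade-off is that receivers connected or log levels changed after the tracer is built stay invisible until it is rebuilt. A minimal sketch of the technique, with hypothetical names:

    import logging
    from time import time

    logger = logging.getLogger(__name__)

    def build_handler(name, receivers=()):
        # Resolved once at build time, captured as closure locals:
        send = receivers[0] if receivers else None  # None: skip sends
        does_info = logger.isEnabledFor(logging.INFO)
        info = logger.info

        def handle(value, time=time):   # default arg: one more local
            start = time()
            if send:
                send(value)
            if does_info:
                info("%s handled %r in %ss", name, value, time() - start)
            return value
        return handle

    handle = build_handler("add")
    assert handle(42) == 42
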

celery/utils/threads.py (15 changed lines)
@@ -81,3 +81,18 @@ def stop(self):
self._is_stopped.wait()
if self.is_alive():
self.join(1e100)
+
+
+class LocalStack(threading.local):
+
+ def __init__(self):
+ self.stack = []
+ self.push = self.stack.append
+ self.pop = self.stack.pop
+
+ @property
+ def top(self):
+ try:
+ return self.stack[-1]
+ except (AttributeError, IndexError):
+ return None
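
This LocalStack is far smaller than the werkzeug-derived class it replaces in celery/local.py: subclasses of threading.local re-run __init__ on first access from each new thread, so every thread silently gets its own list (and its own prebound push/pop); under eventlet or gevent monkey-patching, the patched threading.local keeps this per-greenlet. A usage sketch (import path as added above):

    import threading
    from celery.utils.threads import LocalStack

    stack = LocalStack()
    stack.push(1)

    def worker():
        # __init__ ran again in this thread: fresh, empty stack.
        assert stack.top is None
        stack.push(2)
        assert stack.pop() == 2

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    assert stack.top == 1
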

celery/worker/__init__.py (2 changed lines)
@@ -353,7 +353,7 @@ def process_task_sem(self, req):
def process_task(self, req):
"""Process task by sending it to the pool of workers."""
try:
- req.task.execute(req, self.pool, self.loglevel, self.logfile)
+ req.execute_using_pool(self.pool)
except Exception, exc:
logger.critical("Internal error: %r\n%s",
exc, traceback.format_exc(), exc_info=True)

celery/worker/consumer.py (30 changed lines)
@@ -89,6 +89,7 @@
from celery.app import app_or_default
from celery.datastructures import AttributeDict
from celery.exceptions import InvalidTaskError, SystemTerminate
+from celery.task.trace import build_tracer
from celery.utils import timer2
from celery.utils.functional import noop
from celery.utils.log import get_logger
@@ -145,6 +146,9 @@
Trying to re-establish the connection...\
"""
+task_reserved = state.task_reserved
+revoked_tasks = state.revoked
+
logger = get_logger(__name__)
info, warn, error, crit = (logger.info, logger.warn,
logger.error, logger.critical)
@@ -336,11 +340,16 @@ def __init__(self, ready_queue,
if hub:
hub.on_init.append(self.on_poll_init)
self.hub = hub
+ self._quick_put = self.ready_queue.put
def update_strategies(self):
S = self.strategies
- for task in self.app.tasks.itervalues():
- S[task.name] = task.start_strategy(self.app, self)
+ app = self.app
+ loader = app.loader
+ hostname = self.hostname
+ for name, task in self.app.tasks.iteritems():
+ S[name] = task.start_strategy(app, self)
+ task.__tracer__ = build_tracer(name, task, loader, hostname)
def start(self):
"""Start the consumer.
@@ -456,21 +465,22 @@ def on_task_received(body, message):
else:
sleep(min(time_to_sleep, 0.1))
- def on_task(self, task):
+ def on_task(self, task, task_reserved=task_reserved):
"""Handle received task.
If the task has an `eta` we enter it into the ETA schedule,
otherwise we move it to the ready queue for immediate processing.
"""
- if task.revoked():
+ if (revoked_tasks or task.expires) and task.revoked():
return
if self._does_info:
info("Got task from broker: %s", task.shortinfo())
- if self.event_dispatcher.enabled:
- self.event_dispatcher.send("task-received", uuid=task.id,
+ ev = self.event_dispatcher
+ if ev and ev.enabled:
+ ev.send("task-received", uuid=task.id,
name=task.name, args=safe_repr(task.args),
kwargs=safe_repr(task.kwargs),
retries=task.request_dict.get("retries", 0),
@@ -489,8 +499,8 @@ def on_task(self, task):
self.timer.apply_at(eta, self.apply_eta_task, (task, ),
priority=6)
else:
- state.task_reserved(task)
- self.ready_queue.put(task)
+ task_reserved(task)
+ self._quick_put(task)
def on_control(self, body, message):
"""Process remote control command message."""
@@ -505,8 +515,8 @@ def on_control(self, body, message):
def apply_eta_task(self, task):
"""Method called by the timer to apply a task with an
ETA/countdown."""
- state.task_reserved(task)
- self.ready_queue.put(task)
+ task_reserved(task)
+ self._quick_put(task)
self.qos.decrement_eventually()
def _message_report(self, body, message):
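
Two micro-optimizations recur through this consumer diff: module-level aliases (task_reserved = state.task_reserved) flatten attribute chains, and on_task freezes the alias further through a default argument, bound once at definition time and read as a fast local afterwards. A small self-contained illustration of the default-argument form:

    import math

    def fast_floor(x, _floor=math.floor):
        # _floor was resolved once, at def time; each call now does a
        # LOAD_FAST instead of LOAD_GLOBAL math + LOAD_ATTR floor.
        return _floor(x)

    assert fast_floor(2.7) == 2.0

The usual caveat applies: rebinding math.floor later would not affect fast_floor, which here is exactly the point.
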

celery/worker/job.py (236 changed lines)
@@ -14,7 +14,6 @@
import logging
import time
-import socket
import sys
from datetime import datetime
@@ -25,10 +24,12 @@
from celery import current_app
from celery import exceptions
from celery.app import app_or_default
-from celery.datastructures import ExceptionInfo
+from celery.app.task import Context
+from celery.datastructures import AttributeDict, ExceptionInfo
from celery.task.trace import (
build_tracer,
trace_task,
+ trace_task_ret,
report_internal_error,
execute_bare,
)
@@ -47,13 +48,33 @@
_does_debug = logger.isEnabledFor(logging.DEBUG)
_does_info = logger.isEnabledFor(logging.INFO)
+#: Format string used to log task failure.
+error_msg = """\
+ Task %(name)s[%(id)s] raised exception: %(exc)s
+""".strip()
+
+#: Format string used to log internal error.
+internal_error_msg = """\
+ Task %(name)s[%(id)s] INTERNAL ERROR: %(exc)s
+""".strip()
+
+#: Format string used to log task retry.
+retry_msg = """Task %(name)s[%(id)s] retry: %(exc)s"""
+
+
# Localize
tz_to_local = timezone.to_local
tz_or_local = timezone.tz_or_local
tz_utc = timezone.utc
NEEDS_KWDICT = sys.version_info <= (2, 6)
+_current_app_for_proc = None
+
+task_accepted = state.task_accepted
+task_ready = state.task_ready
+revoked_tasks = state.revoked
+
def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
"""This is a pickleable method used as a target when applying to pools.
@@ -63,7 +84,10 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
>>> trace_task(name, *args, **kwargs)[0]
"""
- task = current_app.tasks[name]
+ global _current_app_for_proc
+ if _current_app_for_proc is None:
+ _current_app_for_proc = current_app._get_current_object()
+ task = _current_app_for_proc.tasks[name]
try:
hostname = opts.get("hostname")
setps("celeryd", name, hostname, rate_limit=True)
@@ -79,83 +103,62 @@ def execute_and_trace(name, uuid, args, kwargs, request=None, **opts):
class Request(object):
"""A request for task execution."""
- __slots__ = ("app", "name", "id", "args", "kwargs",
- "on_ack", "delivery_info", "hostname",
- "callbacks", "errbacks",
- "eventer", "connection_errors",
- "task", "eta", "expires", "flags",
- "request_dict", "acknowledged", "success_msg",
- "error_msg", "retry_msg", "time_start", "worker_pid",
- "_already_revoked", "_terminate_on_ack", "_tzlocal")
-
- #: Format string used to log task success.
- success_msg = """\
- Task %(name)s[%(id)s] succeeded in %(runtime)ss: %(return_value)s
- """
-
- #: Format string used to log task failure.
- error_msg = """\
- Task %(name)s[%(id)s] raised exception: %(exc)s
- """
-
- #: Format string used to log internal error.
- internal_error_msg = """\
- Task %(name)s[%(id)s] INTERNAL ERROR: %(exc)s
- """
-
- #: Format string used to log task retry.
- retry_msg = """Task %(name)s[%(id)s] retry: %(exc)s"""
+ eta = None
+ started = False
+ acknowledged = _already_revoked = False
+ worker_pid = _terminate_on_ack = None
+ _tzlocal = None
+ expires = None
+ delivery_info = {}
+ flags = 0
+ args = ()
def __init__(self, body, on_ack=noop,
hostname=None, eventer=None, app=None,
connection_errors=None, request_dict=None,
- delivery_info=None, task=None, **opts):
- self.app = app or app_or_default(app)
- name = self.name = body["task"]
+ delivery_info=None, task=None, Context=Context, **opts):
+ self.app = app
+ self.name = body["task"]
self.id = body["id"]
- self.args = body.get("args", [])
- self.kwargs = body.get("kwargs", {})
+ self.args = body["args"]
+ try:
+ self.kwargs = body["kwargs"]
+ if NEEDS_KWDICT:
+ self.kwargs = kwdict(self.kwargs)
+ except KeyError:
+ self.kwargs = {}
try:
- self.kwargs.items
- except AttributeError:
- raise exceptions.InvalidTaskError(
- "Task keyword arguments is not a mapping")
- if NEEDS_KWDICT:
- self.kwargs = kwdict(self.kwargs)
- eta = body.get("eta")
- expires = body.get("expires")
- utc = body.get("utc", False)
- self.flags = body.get("flags", False)
+ self.flags = body["flags"]
+ except KeyError:
+ pass
self.on_ack = on_ack
- self.hostname = hostname or socket.gethostname()
+ self.hostname = hostname
self.eventer = eventer
self.connection_errors = connection_errors or ()
- self.task = task or self.app.tasks[name]
- self.acknowledged = self._already_revoked = False
- self.time_start = self.worker_pid = self._terminate_on_ack = None
- self._tzlocal = None
-
- # timezone means the message is timezone-aware, and the only timezone
- # supported at this point is UTC.
- if eta is not None:
- tz = tz_utc if utc else self.tzlocal
- self.eta = tz_to_local(maybe_iso8601(eta), self.tzlocal, tz)
- else:
- self.eta = None
- if expires is not None:
- tz = tz_utc if utc else self.tzlocal
- self.expires = tz_to_local(maybe_iso8601(expires),
+ self.task = task or self.app._tasks[self.name]
+ utc = body.get("utc")
+ if "eta" in body:
+ eta = body["eta"]
+ if eta:
+ tz = tz_utc if utc else self.tzlocal
+ self.eta = tz_to_local(maybe_iso8601(eta), self.tzlocal, tz)
+ if "expires" in body:
+ expires = body["expires"]
+ if expires:
+ tz = tz_utc if utc else self.tzlocal
+ self.expires = tz_to_local(maybe_iso8601(expires),
self.tzlocal, tz)
- else:
- self.expires = None
-
- delivery_info = {} if delivery_info is None else delivery_info
- self.delivery_info = {
- "exchange": delivery_info.get("exchange"),
- "routing_key": delivery_info.get("routing_key"),
- }
-
- self.request_dict = body
+ if delivery_info:
+ self.delivery_info = {
+ "exchange": delivery_info.get("exchange"),
+ "routing_key": delivery_info.get("routing_key"),
+ }
+
+ self.request_dict = AttributeDict(
+ {"called_directly": False,
+ "callbacks": [],
+ "errbacks": [],
+ "chord": None}, **body)
@classmethod
def from_message(cls, message, body, **kwargs):
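
The rewritten Request.__init__ above leans on two idioms: class-level defaults (eta, expires, flags and friends) mean an instance only stores attributes that differ from the default, and try/except KeyError replaces dict.get, which is cheaper whenever the looked-up key is usually present. A condensed sketch with a hypothetical Msg class:

    class Msg(object):
        flags = 0          # class-level default: no per-instance write
        expires = None

        def __init__(self, body):
            self.id = body["id"]
            try:
                self.flags = body["flags"]   # EAFP: set only if present
            except KeyError:
                pass

    m = Msg({"id": "a3f1"})
    assert (m.flags, m.expires) == (0, None)

The shared class defaults also shrink each instance's __dict__, which matters for a worker creating one Request per message.
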
@@ -190,15 +193,10 @@ def extend_with_default_kwargs(self, loglevel, logfile):
kwargs.update(extend_with)
return kwargs
- def execute_using_pool(self, pool, loglevel=None, logfile=None):
+ def execute_using_pool(self, pool, **kwargs):
"""Like :meth:`execute`, but using a worker pool.
- :param pool: A :class:`multiprocessing.Pool` instance.
-
- :keyword loglevel: The loglevel used by the task.
-
- :keyword logfile: The logfile used by the task.
-
+ :param pool: A :class:`celery.concurrency.base.TaskPool` instance.
"""
task = self.task
if self.flags & 0x004:
@@ -210,21 +208,18 @@ def execute_using_pool(self, pool, loglevel=None, logfile=None):
error_callback=self.on_failure,
soft_timeout=task.soft_time_limit,
timeout=task.time_limit)
- if self.revoked():
+ if (revoked_tasks or self.expires) and self.revoked():
return
hostname = self.hostname
kwargs = self.kwargs
- if self.task.accept_magic_kwargs:
+ if task.accept_magic_kwargs:
kwargs = self.extend_with_default_kwargs(loglevel, logfile)
request = self.request_dict
- request.update({"loglevel": loglevel, "logfile": logfile,
- "hostname": hostname, "is_eager": False,
+ request.update({"hostname": hostname, "is_eager": False,
"delivery_info": self.delivery_info})
- result = pool.apply_async(execute_and_trace,
- args=(self.name, self.id, self.args, kwargs),
- kwargs={"hostname": hostname,
- "request": request},
+ result = pool.apply_async(trace_task_ret,
+ (task, self.id, self.args, kwargs, request),
accept_callback=self.on_accepted,
timeout_callback=self.on_timeout,
callback=self.on_success,
@@ -240,7 +235,7 @@ def execute(self, loglevel=None, logfile=None):
:keyword logfile: The logfile used by the task.
"""
- if self.revoked():
+ if (revoked_tasks or self.expires) and self.revoked():
return
# acknowledge task as being processed.
@@ -264,12 +259,12 @@ def execute(self, loglevel=None, logfile=None):
def maybe_expire(self):
"""If expired, mark the task as revoked."""
if self.expires and datetime.now(self.tzlocal) > self.expires:
- state.revoked.add(self.id)
+ revoked_tasks.add(self.id)
if self.store_errors:
self.task.backend.mark_as_revoked(self.id)
def terminate(self, pool, signal=None):
- if self.time_start:
+ if self.started:
return pool.terminate_job(self.worker_pid, signal)
else:
self._terminate_on_ack = (True, pool, signal)
@@ -280,26 +275,24 @@ def revoked(self):
return True
if self.expires:
self.maybe_expire()
- if self.id in state.revoked:
+ if self.id in revoked_tasks:
warn("Skipping revoked task: %s[%s]", self.name, self.id)
- self.send_event("task-revoked", uuid=self.id)
+ if self.eventer and self.eventer.enabled:
+ self.eventer.send("task-revoked", uuid=self.id)
self.acknowledge()
self._already_revoked = True
return True
return False
- def send_event(self, type, **fields):
- if self.eventer and self.eventer.enabled:
- self.eventer.send(type, **fields)
-
- def on_accepted(self, pid, time_accepted):
+ def on_accepted(self, pid, *args):
"""Handler called when task is accepted by worker pool."""
+ self.started = True
self.worker_pid = pid
- self.time_start = time_accepted
- state.task_accepted(self)
+ task_accepted(self)
if not self.task.acks_late:
self.acknowledge()
- self.send_event("task-started", uuid=self.id, pid=pid)
+ if self.eventer and self.eventer.enabled:
+ self.eventer.send("task-started", uuid=self.id, pid=pid)
if _does_debug:
debug("Task accepted: %s[%s] pid:%r", self.name, self.id, pid)
if self._terminate_on_ack is not None:
@@ -308,7 +301,7 @@ def on_accepted(self, pid, time_accepted):
def on_timeout(self, soft, timeout):
"""Handler called if the task times out."""
- state.task_ready(self)
+ task_ready(self)
if soft:
warn("Soft time limit (%ss) exceeded for %s[%s]",
timeout, self.name, self.id)
@@ -321,37 +314,29 @@ def on_timeout(self, soft, timeout):
if self.store_errors:
self.task.backend.mark_as_failure(self.id, exc)
- def on_success(self, ret_value, now=None):
+ def on_success(self, ret_value):
"""Handler called if the task was successfully processed."""
if isinstance(ret_value, ExceptionInfo):
if isinstance(ret_value.exception, (
SystemExit, KeyboardInterrupt)):
raise ret_value.exception
return self.on_failure(ret_value)
- state.task_ready(self)
-
+ task_ready(self)
if self.task.acks_late:
self.acknowledge()
if self.eventer and self.eventer.enabled:
now = time.time()
- runtime = self.time_start and (time.time() - self.time_start) or 0
- self.send_event("task-succeeded", uuid=self.id,
- result=safe_repr(ret_value), runtime=runtime)
-
- if _does_info:
- now = now or time.time()
- runtime = self.time_start and (time.time() - self.time_start) or 0
- info(self.success_msg.strip(), {
- "id": self.id, "name": self.name,
- "return_value": self.repr_result(ret_value),
- "runtime": runtime})
+ runtime = 0 #self.time_start and (time.time() - self.time_start) or 0
+ self.eventer.send("task-succeeded", uuid=self.id,
+ result=safe_repr(ret_value), runtime=runtime)
def on_retry(self, exc_info):
"""Handler called if the task should be retried."""
- self.send_event("task-retried", uuid=self.id,
- exception=safe_repr(exc_info.exception.exc),
- traceback=safe_str(exc_info.traceback))
+ if self.eventer and self.eventer.enabled:
+ self.eventer.send("task-retried", uuid=self.id,
+ exception=safe_repr(exc_info.exception.exc),
+ traceback=safe_str(exc_info.traceback))
if _does_info:
info(self.retry_msg.strip(), {
@@ -360,7 +345,7 @@ def on_retry(self, exc_info):
def on_failure(self, exc_info):
"""Handler called if the task raised an exception."""
- state.task_ready(self)
+ task_ready(self)
if not exc_info.internal:
@@ -387,15 +372,16 @@ def _log_error(self, einfo):
safe_repr(self.args),
safe_repr(self.kwargs),
)
- format = self.error_msg
+ format = error_msg
description = "raised exception"
severity = logging.ERROR
- self.send_event("task-failed", uuid=self.id,
- exception=exception,
- traceback=traceback)
+ if self.eventer and self.eventer.enabled:
+ self.eventer.send("task-failed", uuid=self.id,
+ exception=exception,
+ traceback=traceback)
if internal:
- format = self.internal_error_msg
+ format = internal_error_msg
description = "INTERNAL ERROR"
severity = logging.CRITICAL
@@ -427,18 +413,12 @@ def acknowledge(self):
self.on_ack(logger, self.connection_errors)
self.acknowledged = True
- def repr_result(self, result, maxlen=46):
- # 46 is the length needed to fit
- # "the quick brown fox jumps over the lazy dog" :)
- return truncate(safe_repr(result), maxlen)
-
def info(self, safe=False):
return {"id": self.id,
"name": self.name,
"args": self.args if safe else safe_repr(self.args),
"kwargs": self.kwargs if safe else safe_repr(self.kwargs),
"hostname": self.hostname,
- "time_start": self.time_start,
"acknowledged": self.acknowledged,
"delivery_info": self.delivery_info,
"worker_pid": self.worker_pid}
