
current_task and Task.request are now LocalStacks. Maybe fixes #521

Commit e2b052c2d8e4b25e43575ba1799b2d292eedc7cd (1 parent: 55bbcc4), committed by @ask on May 16, 2012
celery/__compat__.py (2 changed lines)
@@ -89,7 +89,7 @@ class class_property(object):
def __init__(self, fget=None, fset=None):
assert fget and isinstance(fget, classmethod)
- assert fset and isinstance(fset, classmethod)
+ assert isinstance(fset, classmethod) if fset else True
self.__get = fget
self.__set = fset
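The relaxed assertion makes the setter optional: a class_property may now be built from a getter alone, which the read-only Task.request property added in celery/task/base.py below relies on. A minimal sketch of the idea; the descriptor plumbing here is simplified and is not celery's actual __compat__ implementation:

    class class_property(object):
        """Simplified sketch of a property that also resolves on the class."""

        def __init__(self, fget=None, fset=None):
            assert fget and isinstance(fget, classmethod)
            # The setter is now optional; validate it only when it is given.
            assert isinstance(fset, classmethod) if fset else True
            self.__get = fget
            self.__set = fset

        def __get__(self, obj, type=None):
            # Delegate to the wrapped classmethod so access on the class and
            # on an instance both end up calling fget(cls).
            return self.__get.__get__(obj, type)()

    class Example(object):
        @classmethod
        def _get_name(cls):
            return cls.__name__
        name = class_property(_get_name)   # getter only, no setter needed

    assert Example.name == "Example" == Example().name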
celery/app/state.py (9 changed lines)
@@ -2,7 +2,7 @@
import threading
-from celery.local import Proxy
+from celery.local import Proxy, LocalStack
default_app = None
@@ -12,11 +12,10 @@ class _TLS(threading.local):
#: sets this, so it will always contain the last instantiated app,
#: and is the default app returned by :func:`app_or_default`.
current_app = None
-
- #: The currently executing task.
- current_task = None
_tls = _TLS()
+_task_stack = LocalStack()
+
def set_default_app(app):
global default_app
@@ -28,7 +27,7 @@ def get_current_app():
def get_current_task():
- return getattr(_tls, "current_task", None)
+ return _task_stack.top
current_app = Proxy(get_current_app)
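The current task is now tracked on a module-level LocalStack instead of a plain thread-local attribute, so nested pushes unwind correctly. A rough illustration of the stack semantics; strings stand in for task instances, and in a real worker it is the tracer that does the pushing and popping:

    from celery.app.state import _task_stack, get_current_task

    _task_stack.push("outer")          # worker starts running a task
    _task_stack.push("inner")          # e.g. a subtask invoked directly
    assert get_current_task() == "inner"
    _task_stack.pop()
    assert get_current_task() == "outer"
    _task_stack.pop()
    assert get_current_task() is None  # an empty stack unbinds the context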
celery/app/task.py (27 changed lines)
@@ -14,7 +14,6 @@
import logging
import sys
-import threading
from kombu import Exchange
from kombu.utils import cached_property
@@ -24,6 +23,7 @@
from celery.__compat__ import class_property
from celery.datastructures import ExceptionInfo
from celery.exceptions import MaxRetriesExceededError, RetryTaskError
+from celery.local import LocalStack
from celery.result import EagerResult
from celery.utils import fun_takes_kwargs, uuid, maybe_reraise
from celery.utils.functional import mattrgetter, maybe_list
@@ -43,7 +43,7 @@
"compression", "expires")
-class Context(threading.local):
+class Context(object):
# Default context
logfile = None
loglevel = None
@@ -61,8 +61,11 @@ class Context(threading.local):
errbacks = None
_children = None # see property
- def update(self, d, **kwargs):
- self.__dict__.update(d, **kwargs)
+ def __init__(self, *args, **kwargs):
+ self.update(*args, **kwargs)
+
+ def update(self, *args, **kwargs):
+ self.__dict__.update(*args, **kwargs)
def clear(self):
self.__dict__.clear()
@@ -172,9 +175,6 @@ class BaseTask(object):
#: Deprecated and scheduled for removal in v3.0.
accept_magic_kwargs = False
- #: Request context (set when task is applied).
- request = Context()
-
#: Destination queue. The queue needs to exist
#: in :setting:`CELERY_QUEUES`. The `routing_key`, `exchange` and
#: `exchange_type` attributes will be ignored if this is set.
@@ -324,6 +324,9 @@ def bind(self, app):
if not was_bound:
self.annotate()
+ self.request_stack = LocalStack()
+ self.request_stack.push(Context())
+
# PeriodicTask uses this to add itself to the PeriodicTask schedule.
self.on_bound(app)
@@ -845,6 +848,12 @@ def execute(self, request, pool, loglevel, logfile, **kwargs):
"""
request.execute_using_pool(pool, loglevel, logfile)
+ def push_request(self, *args, **kwargs):
+ self.request_stack.push(Context(*args, **kwargs))
+
+ def pop_request(self):
+ self.request_stack.pop()
+
def __repr__(self):
"""`repr(task)`"""
return "<@task: %s>" % (self.name, )
@@ -854,5 +863,9 @@ def logger(self):
return self.get_logger()
@property
+ def request(self):
+ return self.request_stack.top
+
+ @property
def __name__(self):
return self.__class__.__name__
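Each bound task now owns a request_stack of Context objects: bind() seeds it with a default Context, push_request/pop_request manage it, and request is a read-only property returning request_stack.top. A minimal sketch of the new surface, assuming the commit's API on a task created with current_app.task:

    from celery import current_app

    @current_app.task
    def add(x, y):
        return x + y

    # Pushing a Context makes it the current request for this task in this
    # thread/greenlet; popping restores whatever was current before.
    add.push_request(id="manual-1", args=(2, 2), kwargs={})
    try:
        assert add.request.id == "manual-1"   # request is request_stack.top
    finally:
        add.pop_request()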
celery/contrib/batches.py (4 changed lines)
@@ -79,14 +79,14 @@ def consume_queue(queue):
def apply_batches_task(task, args, loglevel, logfile):
- task.request.update({"loglevel": loglevel, "logfile": logfile})
+ task.push_request(loglevel=loglevel, logfile=logfile)
try:
result = task(*args)
except Exception, exc:
result = None
task.logger.error("Error: %r", exc, exc_info=True)
finally:
- task.request.clear()
+ task.pop_request()
return result
celery/events/__init__.py (5 changed lines)
@@ -91,6 +91,8 @@ def __init__(self, connection=None, hostname=None, enabled=True,
self.on_disabled = set()
self.enabled = enabled
+ if not connection and channel:
+ self.connection = channel.connection.client
if self.enabled:
self.enable()
@@ -151,8 +153,7 @@ def copy_buffer(self, other):
def close(self):
"""Close the event dispatcher."""
self.mutex.locked() and self.mutex.release()
- if self.publisher is not None:
- self.publisher = None
+ self.publisher = None
class EventReceiver(object):
celery/local.py (212 changed lines)
@@ -7,12 +7,25 @@
needs to be loaded as soon as possible, and that
shall not load any third party modules.
+ Parts of this module is Copyright by Werkzeug Team.
+
:copyright: (c) 2009 - 2012 by Ask Solem.
:license: BSD, see LICENSE for more details.
"""
from __future__ import absolute_import
+# since each thread has its own greenlet we can just use those as identifiers
+# for the context. If greenlets are not available we fall back to the
+# current thread ident.
+try:
+ from greenlet import getcurrent as get_ident
+except ImportError: # pragma: no cover
+ try:
+ from thread import get_ident # noqa
+ except ImportError: # pragma: no cover
+ from dummy_thread import get_ident # noqa
+
def try_import(module, default=None):
"""Try to import and return module, or return
@@ -201,3 +214,202 @@ def maybe_evaluate(obj):
return obj.__maybe_evaluate__()
except AttributeError:
return obj
+
+
+def release_local(local):
+ """Releases the contents of the local for the current context.
+ This makes it possible to use locals without a manager.
+
+ Example::
+
+ >>> loc = Local()
+ >>> loc.foo = 42
+ >>> release_local(loc)
+ >>> hasattr(loc, 'foo')
+ False
+
+ With this function one can release :class:`Local` objects as well
+ as :class:`StackLocal` objects. However it is not possible to
+ release data held by proxies that way, one always has to retain
+ a reference to the underlying local object in order to be able
+ to release it.
+
+ .. versionadded:: 0.6.1
+ """
+ local.__release_local__()
+
+
+class Local(object):
+ __slots__ = ('__storage__', '__ident_func__')
+
+ def __init__(self):
+ object.__setattr__(self, '__storage__', {})
+ object.__setattr__(self, '__ident_func__', get_ident)
+
+ def __iter__(self):
+ return iter(self.__storage__.items())
+
+ def __call__(self, proxy):
+ """Create a proxy for a name."""
+ return Proxy(self, proxy)
+
+ def __release_local__(self):
+ self.__storage__.pop(self.__ident_func__(), None)
+
+ def __getattr__(self, name):
+ try:
+ return self.__storage__[self.__ident_func__()][name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ ident = self.__ident_func__()
+ storage = self.__storage__
+ try:
+ storage[ident][name] = value
+ except KeyError:
+ storage[ident] = {name: value}
+
+ def __delattr__(self, name):
+ try:
+ del self.__storage__[self.__ident_func__()][name]
+ except KeyError:
+ raise AttributeError(name)
+
+
+class LocalStack(object):
+ """This class works similar to a :class:`Local` but keeps a stack
+ of objects instead. This is best explained with an example::
+
+ >>> ls = LocalStack()
+ >>> ls.push(42)
+ >>> ls.top
+ 42
+ >>> ls.push(23)
+ >>> ls.top
+ 23
+ >>> ls.pop()
+ 23
+ >>> ls.top
+ 42
+
+ They can be force released by using a :class:`LocalManager` or with
+ the :func:`release_local` function but the correct way is to pop the
+ item from the stack after using. When the stack is empty it will
+ no longer be bound to the current context (and as such released).
+
+ By calling the stack without arguments it returns a proxy that resolves to
+ the topmost item on the stack.
+
+ """
+
+ def __init__(self):
+ self._local = Local()
+
+ def __release_local__(self):
+ self._local.__release_local__()
+
+ def _get__ident_func__(self):
+ return self._local.__ident_func__
+
+ def _set__ident_func__(self, value):
+ object.__setattr__(self._local, '__ident_func__', value)
+ __ident_func__ = property(_get__ident_func__, _set__ident_func__)
+ del _get__ident_func__, _set__ident_func__
+
+ def __call__(self):
+ def _lookup():
+ rv = self.top
+ if rv is None:
+ raise RuntimeError('object unbound')
+ return rv
+ return Proxy(_lookup)
+
+ def push(self, obj):
+ """Pushes a new item to the stack"""
+ rv = getattr(self._local, 'stack', None)
+ if rv is None:
+ self._local.stack = rv = []
+ rv.append(obj)
+ return rv
+
+ def pop(self):
+ """Removes the topmost item from the stack, will return the
+ old value or `None` if the stack was already empty.
+ """
+ stack = getattr(self._local, 'stack', None)
+ if stack is None:
+ return None
+ elif len(stack) == 1:
+ release_local(self._local)
+ return stack[-1]
+ else:
+ return stack.pop()
+
+ @property
+ def top(self):
+ """The topmost item on the stack. If the stack is empty,
+ `None` is returned.
+ """
+ try:
+ return self._local.stack[-1]
+ except (AttributeError, IndexError):
+ return None
+
+
+class LocalManager(object):
+ """Local objects cannot manage themselves. For that you need a local
+ manager. You can pass a local manager multiple locals or add them later
+ by appending them to `manager.locals`. Everytime the manager cleans up
+ it, will clean up all the data left in the locals for this context.
+
+ The `ident_func` parameter can be added to override the default ident
+ function for the wrapped locals.
+
+ .. versionchanged:: 0.6.1
+ Instead of a manager the :func:`release_local` function can be used
+ as well.
+
+ .. versionchanged:: 0.7
+ `ident_func` was added.
+ """
+
+ def __init__(self, locals=None, ident_func=None):
+ if locals is None:
+ self.locals = []
+ elif isinstance(locals, Local):
+ self.locals = [locals]
+ else:
+ self.locals = list(locals)
+ if ident_func is not None:
+ self.ident_func = ident_func
+ for local in self.locals:
+ object.__setattr__(local, '__ident_func__', ident_func)
+ else:
+ self.ident_func = get_ident
+
+ def get_ident(self):
+ """Return the context identifier the local objects use internally for
+ this context. You cannot override this method to change the behavior
+ but use it to link other context local objects (such as SQLAlchemy's
+ scoped sessions) to the Werkzeug locals.
+
+ .. versionchanged:: 0.7
+ You can pass a different ident function to the local manager that
+ will then be propagated to all the locals passed to the
+ constructor.
+ """
+ return self.ident_func()
+
+ def cleanup(self):
+ """Manually clean up the data in the locals for this context. Call
+ this at the end of the request or use `make_middleware()`.
+ """
+ for local in self.locals:
+ release_local(local)
+
+ def __repr__(self):
+ return '<%s storages: %d>' % (
+ self.__class__.__name__,
+ len(self.locals)
+ )
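Local and LocalStack are ported from Werkzeug's context locals, keyed by the current greenlet when the greenlet package is available and by the thread ident otherwise. A small sketch of the isolation this provides, thread-based here; with greenlets each greenlet likewise gets its own stack:

    import threading
    from celery.local import LocalStack

    stack = LocalStack()
    stack.push("main")
    seen = []

    def worker():
        seen.append(stack.top)   # None: nothing pushed in this thread yet
        stack.push("worker")
        seen.append(stack.top)

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    assert seen == [None, "worker"]
    assert stack.top == "main"   # the main thread's stack was untouched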
celery/task/base.py (10 changed lines)
@@ -14,14 +14,15 @@
from __future__ import absolute_import
from celery import current_app
-from celery.__compat__ import reclassmethod
+from celery.__compat__ import class_property, reclassmethod
from celery.app.task import Context, TaskType, BaseTask # noqa
from celery.schedules import maybe_schedule
#: list of methods that must be classmethods in the old API.
_COMPAT_CLASSMETHODS = (
"get_logger", "establish_connection", "get_publisher", "get_consumer",
- "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask")
+ "delay", "apply_async", "retry", "apply", "AsyncResult", "subtask",
+ "push_request", "pop_request")
class Task(BaseTask):
@@ -36,6 +37,11 @@ class Task(BaseTask):
for name in _COMPAT_CLASSMETHODS:
locals()[name] = reclassmethod(getattr(BaseTask, name))
+ @classmethod
+ def _get_request(self):
+ return self.request_stack.top
+ request = class_property(_get_request)
+
class PeriodicTask(Task):
"""A periodic task is a task that adds itself to the
celery/task/trace.py (21 changed lines)
@@ -28,8 +28,8 @@
from celery import current_app
from celery import states, signals
-from celery.app.state import _tls
-from celery.app.task import BaseTask
+from celery.app.state import _task_stack
+from celery.app.task import BaseTask, Context
from celery.datastructures import ExceptionInfo
from celery.exceptions import RetryTaskError
from celery.utils.serialization import get_pickleable_exception
@@ -146,15 +146,15 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True,
task_on_success = task.on_success
task_after_return = task.after_return
- task_request = task.request
store_result = backend.store_result
backend_cleanup = backend.process_cleanup
pid = os.getpid()
- update_request = task_request.update
- clear_request = task_request.clear
+ request_stack = task.request_stack
+ push_request = request_stack.push
+ pop_request = request_stack.pop
on_chord_part_return = backend.on_chord_part_return
from celery import canvas
@@ -164,9 +164,10 @@ def trace_task(uuid, args, kwargs, request=None):
R = I = None
kwargs = kwdict(kwargs)
try:
- _tls.current_task = task
- update_request(request or {}, args=args,
- called_directly=False, kwargs=kwargs)
+ _task_stack.push(task)
+ task_request = Context(request or {}, args=args,
+ called_directly=False, kwargs=kwargs)
+ push_request(task_request)
try:
# -*- PRE -*-
send_prerun(sender=task, task_id=uuid, task=task,
@@ -220,8 +221,8 @@ def trace_task(uuid, args, kwargs, request=None):
send_postrun(sender=task, task_id=uuid, task=task,
args=args, kwargs=kwargs, retval=retval)
finally:
- _tls.current_task = None
- clear_request()
+ _task_stack.pop()
+ pop_request()
if not eager:
try:
backend_cleanup()
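The tracer now brackets every run by pushing the task onto _task_stack and a fresh Context onto the task's request_stack, popping both in a finally block. A simplified sketch of that discipline; run_in_context is a hypothetical helper name, and the real trace_task also handles signals, results and retries:

    from celery.app.state import _task_stack
    from celery.app.task import Context

    def run_in_context(task, args, kwargs, request=None):
        # Mirror what trace_task now does around a single task execution.
        _task_stack.push(task)
        task.request_stack.push(Context(request or {}, args=args,
                                        called_directly=False, kwargs=kwargs))
        try:
            return task(*args, **kwargs)
        finally:
            task.request_stack.pop()   # restore the caller's request, if any
            _task_stack.pop()          # and the previously current task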
celery/tests/app/test_builtins.py (6 changed lines)
@@ -4,7 +4,7 @@
from celery import current_app as app, group, task, chord
from celery.app import builtins
-from celery.app.state import _tls
+from celery.app.state import _task_stack
from celery.tests.utils import Case
@@ -61,13 +61,13 @@ def test_apply_async(self):
x.apply_async()
def test_apply_async_with_parent(self):
- _tls.current_task = add
+ _task_stack.push(add)
try:
x = group([add.s(4, 4), add.s(8, 8)])
x.apply_async()
self.assertTrue(add.request.children)
finally:
- _tls.current_task = None
+ _task_stack.pop()
class test_chain(Case):
celery/tests/app/test_log.py (8 changed lines)
@@ -230,12 +230,12 @@ def test_task():
pass
test_task.logger.handlers = []
self.task = test_task
- from celery.app.state import _tls
- _tls.current_task = test_task
+ from celery.app.state import _task_stack
+ _task_stack.push(test_task)
def tearDown(self):
- from celery.app.state import _tls
- _tls.current_task = None
+ from celery.app.state import _task_stack
+ _task_stack.pop()
def setup_logger(self, *args, **kwargs):
return log.setup_task_loggers(*args, **kwargs)
celery/tests/events/test_events.py (11 changed lines)
@@ -132,10 +132,11 @@ def test_process(self):
def my_handler(event):
got_event[0] = True
- r = events.EventReceiver(object(),
+ connection = Mock()
+ connection.transport_cls = "memory"
+ r = events.EventReceiver(connection,
handlers={"world-war": my_handler},
- node_id="celery.tests",
- )
+ node_id="celery.tests")
r._receive(message, object())
self.assertTrue(got_event[0])
@@ -148,7 +149,9 @@ def test_catch_all_event(self):
def my_handler(event):
got_event[0] = True
- r = events.EventReceiver(object(), node_id="celery.tests")
+ connection = Mock()
+ connection.transport_cls = "memory"
+ r = events.EventReceiver(connection, node_id="celery.tests")
events.EventReceiver.handlers["*"] = my_handler
try:
r._receive(message, object())
celery/tests/tasks/test_context.py (94 changed lines)
@@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-"
from __future__ import absolute_import
-import threading
-
from celery.task.base import Context
from celery.tests.utils import Case
@@ -22,21 +20,6 @@ def get_context_as_dict(ctx, getter=getattr):
default_context = get_context_as_dict(Context())
-# Manipulate the a context in a separate thread
-class ContextManipulator(threading.Thread):
- def __init__(self, ctx, *args):
- super(ContextManipulator, self).__init__()
- self.daemon = True
- self.ctx = ctx
- self.args = args
- self.result = None
-
- def run(self):
- for func, args in self.args:
- func(self.ctx, *args)
- self.result = get_context_as_dict(self.ctx)
-
-
class test_Context(Case):
def test_default_context(self):
@@ -45,14 +28,6 @@ def test_default_context(self):
defaults = dict(default_context, children=[])
self.assertDictEqual(get_context_as_dict(Context()), defaults)
- def test_default_context_threaded(self):
- ctx = Context()
- worker = ContextManipulator(ctx)
- worker.start()
- worker.join()
- self.assertDictEqual(worker.result, default_context)
- self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
def test_updated_context(self):
expected = dict(default_context)
changes = dict(id="unique id", args=["some", 1], wibble="wobble")
@@ -62,26 +37,6 @@ def test_updated_context(self):
self.assertDictEqual(get_context_as_dict(ctx), expected)
self.assertDictEqual(get_context_as_dict(Context()), default_context)
- def test_updated_contex_threadedt(self):
- expected_a = dict(default_context)
- changes_a = dict(id="a", args=["some", 1], wibble="wobble")
- expected_a.update(changes_a)
- expected_b = dict(default_context)
- changes_b = dict(id="b", args=["other", 2], weasel="woozle")
- expected_b.update(changes_b)
- ctx = Context()
-
- worker_a = ContextManipulator(ctx, (Context.update, [changes_a]))
- worker_b = ContextManipulator(ctx, (Context.update, [changes_b]))
- worker_a.start()
- worker_b.start()
- worker_a.join()
- worker_b.join()
-
- self.assertDictEqual(worker_a.result, expected_a)
- self.assertDictEqual(worker_b.result, expected_b)
- self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
def test_modified_context(self):
expected = dict(default_context)
ctx = Context()
@@ -92,34 +47,6 @@ def test_modified_context(self):
self.assertDictEqual(get_context_as_dict(ctx), expected)
self.assertDictEqual(get_context_as_dict(Context()), default_context)
- def test_modified_contex_threadedt(self):
- expected_a = dict(default_context)
- expected_a["id"] = "a"
- expected_a["args"] = ["some", 1]
- expected_a["wibble"] = "wobble"
- expected_b = dict(default_context)
- expected_b["id"] = "b"
- expected_b["args"] = ["other", 2]
- expected_b["weasel"] = "woozle"
- ctx = Context()
-
- worker_a = ContextManipulator(ctx,
- (setattr, ["id", "a"]),
- (setattr, ["args", ["some", 1]]),
- (setattr, ["wibble", "wobble"]))
- worker_b = ContextManipulator(ctx,
- (setattr, ["id", "b"]),
- (setattr, ["args", ["other", 2]]),
- (setattr, ["weasel", "woozle"]))
- worker_a.start()
- worker_b.start()
- worker_a.join()
- worker_b.join()
-
- self.assertDictEqual(worker_a.result, expected_a)
- self.assertDictEqual(worker_b.result, expected_b)
- self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
def test_cleared_context(self):
changes = dict(id="unique id", args=["some", 1], wibble="wobble")
ctx = Context()
@@ -129,27 +56,6 @@ def test_cleared_context(self):
self.assertDictEqual(get_context_as_dict(ctx), defaults)
self.assertDictEqual(get_context_as_dict(Context()), defaults)
- def test_cleared_context_threaded(self):
- changes_a = dict(id="a", args=["some", 1], wibble="wobble")
- expected_b = dict(default_context)
- changes_b = dict(id="b", args=["other", 2], weasel="woozle")
- expected_b.update(changes_b)
- ctx = Context()
-
- worker_a = ContextManipulator(ctx,
- (Context.update, [changes_a]),
- (Context.clear, []))
- worker_b = ContextManipulator(ctx,
- (Context.update, [changes_b]))
- worker_a.start()
- worker_b.start()
- worker_a.join()
- worker_b.join()
-
- self.assertDictEqual(worker_a.result, default_context)
- self.assertDictEqual(worker_b.result, expected_b)
- self.assertDictEqual(get_context_as_dict(ctx), default_context)
-
def test_context_get(self):
expected = dict(default_context)
changes = dict(id="unique id", args=["some", 1], wibble="wobble")
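The threaded Context tests go away because Context is no longer a threading.local: attribute changes on a shared Context are now visible across threads, and per-call isolation comes from the request stacks instead. A short sketch of the new behaviour:

    import threading
    from celery.app.task import Context

    ctx = Context()

    def worker():
        ctx.wibble = "wobble"      # plain attribute assignment, no thread-local

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    assert ctx.wibble == "wobble"  # now visible from the main thread too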
celery/tests/tasks/test_sets.py (8 changed lines)
@@ -148,13 +148,13 @@ def send(self, *args, **kwargs):
@current_app.task
def xyz():
pass
- from celery.app.state import _tls
- _tls.current_task = xyz
+ from celery.app.state import _task_stack
+ _task_stack.push(xyz)
try:
ts.apply_async(publisher=Publisher())
finally:
- _tls.current_task = None
- xyz.request.clear()
+ _task_stack.pop()
+ xyz.pop_request()
def test_apply(self):
celery/tests/worker/test_request.py (30 changed lines)
@@ -67,8 +67,9 @@ def mro(cls):
def jail(task_id, name, args, kwargs):
+ request = {"id": task_id}
return eager_trace_task(current_app.tasks[name],
- task_id, args, kwargs, eager=False)[0]
+ task_id, args, kwargs, request=request, eager=False)[0]
def on_ack(*args, **kwargs):
@@ -196,26 +197,16 @@ def store_result(self, tid, meta, state):
mytask.ignore_result = False
def test_execute_jail_failure(self):
- u = uuid()
- mytask_raising.request.update({"id": u})
- try:
- ret = jail(u, mytask_raising.name,
- [4], {})
- self.assertIsInstance(ret, ExceptionInfo)
- self.assertTupleEqual(ret.exception.args, (4, ))
- finally:
- mytask_raising.request.clear()
+ ret = jail(uuid(), mytask_raising.name,
+ [4], {})
+ self.assertIsInstance(ret, ExceptionInfo)
+ self.assertTupleEqual(ret.exception.args, (4, ))
def test_execute_ignore_result(self):
task_id = uuid()
- MyTaskIgnoreResult.request.update({"id": task_id})
- try:
- ret = jail(task_id, MyTaskIgnoreResult.name,
- [4], {})
- self.assertEqual(ret, 256)
- self.assertFalse(AsyncResult(task_id).ready())
- finally:
- MyTaskIgnoreResult.request.clear()
+ ret = jail(task_id, MyTaskIgnoreResult.name, [4], {})
+ self.assertEqual(ret, 256)
+ self.assertFalse(AsyncResult(task_id).ready())
class MockEventDispatcher(object):
@@ -557,10 +548,9 @@ def test_execute_safe_catches_exception(self):
def _error_exec(self, *args, **kwargs):
raise KeyError("baz")
- @task_dec
+ @task_dec(request=None)
def raising():
raise KeyError("baz")
- raising.request = None
with self.assertWarnsRegex(RuntimeWarning,
r'Exception raised outside'):
