Permalink
Browse files

99% overall coverage :happy:

  • Loading branch information...
1 parent 3a0bb50 commit 53b61c638b15a6cd45e685f4e81f03b3a61e979a @ask committed Apr 27, 2012
Showing with 1,071 additions and 195 deletions.
  1. +2 −1 celery/app/task.py
  2. +9 −7 celery/apps/worker.py
  3. +1 −1 celery/backends/base.py
  4. +1 −0 celery/backends/redis.py
  5. +30 −35 celery/beat.py
  6. +1 −1 celery/bin/base.py
  7. +1 −1 celery/bin/celeryd.py
  8. +1 −1 celery/events/state.py
  9. +1 −1 celery/result.py
  10. +1 −1 celery/task/http.py
  11. +6 −1 celery/tests/__init__.py
  12. +31 −0 celery/tests/app/test_amqp.py
  13. +5 −1 celery/tests/app/test_app.py
  14. +129 −18 celery/tests/app/test_beat.py
  15. +15 −0 celery/tests/app/test_control.py
  16. +12 −0 celery/tests/app/test_defaults.py
  17. +15 −3 celery/tests/app/test_loaders.py
  18. +39 −6 celery/tests/app/test_log.py
  19. +4 −0 celery/tests/backends/test_amqp.py
  20. +24 −0 celery/tests/backends/test_base.py
  21. +7 −1 celery/tests/backends/test_database.py
  22. +2 −0 celery/tests/backends/test_pyredis_compat.py
  23. +33 −0 celery/tests/backends/test_redis.py
  24. +31 −1 celery/tests/bin/test_base.py
  25. +6 −0 celery/tests/bin/test_celerybeat.py
  26. +19 −1 celery/tests/bin/test_celeryd.py
  27. +2 −4 celery/tests/bin/test_celeryd_multi.py
  28. +21 −0 celery/tests/bin/test_celeryev.py
  29. +16 −0 celery/tests/concurrency/test_concurrency.py
  30. +19 −8 celery/tests/concurrency/test_processes.py
  31. +36 −15 celery/tests/events/test_events.py
  32. +18 −3 celery/tests/events/test_snapshot.py
  33. +5 −0 celery/tests/events/test_state.py
  34. +6 −10 celery/tests/slow/test_buckets.py
  35. +8 −0 celery/tests/tasks/test_registry.py
  36. +136 −1 celery/tests/tasks/test_result.py
  37. +34 −1 celery/tests/tasks/test_tasks.py
  38. +139 −0 celery/tests/utilities/test_dispatcher.py
  39. +11 −0 celery/tests/utilities/test_imports.py
  40. +79 −0 celery/tests/utilities/test_saferef.py
  41. +29 −1 celery/tests/utilities/test_timeutils.py
  42. +16 −1 celery/tests/utilities/test_utils.py
  43. +19 −0 celery/tests/utils.py
  44. +2 −4 celery/tests/worker/test_autoreload.py
  45. +8 −13 celery/tests/worker/test_worker.py
  46. +19 −19 celery/utils/compat.py
  47. +4 −4 celery/utils/dispatch/saferef.py
  48. +2 −7 celery/utils/dispatch/signal.py
  49. +4 −4 celery/utils/functional.py
  50. +2 −2 celery/utils/log.py
  51. +2 −4 celery/utils/mail.py
  52. +1 −1 celery/utils/serialization.py
  53. +0 −1 celery/utils/timeutils.py
  54. +2 −1 contrib/release/doc4allmods
  55. +3 −6 contrib/release/py3k-run-tests
  56. +1 −1 requirements/test.txt
  57. +0 −2 setup.cfg
  58. +1 −1 setup.py
View
3 celery/app/task.py
@@ -817,7 +817,8 @@ def execute(self, request, pool, loglevel, logfile, **kwargs):
def annotate(self):
for d in resolve_all_annotations(self.app.annotations, self):
- self.__dict__.update(d)
+ for key, value in d.iteritems():
+ setattr(self, key, value)
def __repr__(self):
"""`repr(task)`"""
View
16 celery/apps/worker.py
@@ -25,7 +25,7 @@
try:
from greenlet import GreenletExit
IGNORE_ERRORS = (GreenletExit, )
-except ImportError:
+except ImportError: # pragma: no cover
IGNORE_ERRORS = ()
logger = get_logger(__name__)
@@ -302,15 +302,17 @@ def install_cry_handler():
# Jython/PyPy does not have sys._current_frames
is_jython = sys.platform.startswith("java")
is_pypy = hasattr(sys, "pypy_version_info")
- if not (is_jython or is_pypy):
+ if is_jython or is_pypy: # pragma: no cover
+ return
- def cry_handler(signum, frame):
- """Signal handler logging the stacktrace of all active threads."""
- logger.error("\n" + cry())
- platforms.signals["SIGUSR1"] = cry_handler
+ def cry_handler(signum, frame):
+ """Signal handler logging the stacktrace of all active threads."""
+ logger.error("\n" + cry())
+ platforms.signals["SIGUSR1"] = cry_handler
-def install_rdb_handler(envvar="CELERY_RDBSIG", sig="SIGUSR2"):
+def install_rdb_handler(envvar="CELERY_RDBSIG",
+ sig="SIGUSR2"): # pragma: no cover
def rdb_handler(signum, frame):
"""Signal handler setting a rdb breakpoint at the current frame."""
View
2 celery/backends/base.py
@@ -120,7 +120,7 @@ def exception_to_python(self, exc):
if self.serializer in EXCEPTION_ABLE_CODECS:
return get_pickled_exception(exc)
return create_exception_cls(from_utf8(exc["exc_type"]),
- sys.modules[__name__])
+ sys.modules[__name__])(exc["exc_message"])
def prepare_value(self, result):
"""Prepare value for storage."""
View
1 celery/backends/redis.py
@@ -62,6 +62,7 @@ def _get(key):
uhost = uport = upass = udb = None
if url:
_, uhost, uport, _, upass, udb, _ = _parse_url(url)
+ udb = udb.strip("/")
self.host = uhost or host or _get("HOST") or self.host
self.port = int(uport or port or _get("PORT") or self.port)
self.db = udb or db or _get("DB") or self.db
View
65 celery/beat.py
@@ -16,10 +16,9 @@
import time
import shelve
import sys
-import threading
import traceback
-from billiard import Process
+from billiard import Process, ensure_multiprocessing
from kombu.utils import reprcall
from kombu.utils.functional import maybe_promise
@@ -31,6 +30,7 @@
from .schedules import maybe_schedule, crontab
from .utils import cached_property
from .utils.imports import instantiate
+from .utils.threads import Event, Thread
from .utils.timeutils import humanize_seconds
from .utils.log import get_logger
@@ -229,12 +229,12 @@ def apply_async(self, entry, publisher=None, **kwargs):
raise SchedulingError, SchedulingError(
"Couldn't apply scheduled task %s: %s" % (
entry.name, exc)), sys.exc_info()[2]
-
- if self.should_sync():
- self._do_sync()
+ finally:
+ if self.should_sync():
+ self._do_sync()
return result
- def send_task(self, *args, **kwargs): # pragma: no cover
+ def send_task(self, *args, **kwargs):
return self.app.send_task(*args, **kwargs)
def setup_schedule(self):
@@ -283,12 +283,6 @@ def merge_inplace(self, b):
else:
schedule[key] = entry
- def get_schedule(self):
- return self.data
-
- def set_schedule(self, schedule):
- self.data = schedule
-
def _ensure_connected(self):
# callback called for each retry while the connection
# can't be established.
@@ -299,6 +293,13 @@ def _error_handler(exc, interval):
return self.connection.ensure_connection(_error_handler,
self.app.conf.BROKER_CONNECTION_MAX_RETRIES)
+ def get_schedule(self):
+ return self.data
+
+ def set_schedule(self, schedule):
+ self.data = schedule
+ schedule = property(get_schedule, set_schedule)
+
@cached_property
def connection(self):
return self.app.broker_connection()
@@ -308,16 +309,13 @@ def publisher(self):
return self.Publisher(connection=self._ensure_connected())
@property
- def schedule(self):
- return self.get_schedule()
-
- @property
def info(self):
return ""
class PersistentScheduler(Scheduler):
persistence = shelve
+ known_suffixes = ("", ".db", ".dat", ".bak", ".dir")
_store = None
@@ -326,7 +324,7 @@ def __init__(self, *args, **kwargs):
Scheduler.__init__(self, *args, **kwargs)
def _remove_db(self):
- for suffix in "", ".db", ".dat", ".bak", ".dir":
+ for suffix in self.known_suffixes:
try:
os.remove(self.schedule_filename + suffix)
except OSError, exc:
@@ -358,6 +356,10 @@ def setup_schedule(self):
def get_schedule(self):
return self._store["entries"]
+ def set_schedule(self, schedule):
+ self._store["entries"] = schedule
+ schedule = property(get_schedule, set_schedule)
+
def sync(self):
if self._store is not None:
self._store.sync()
@@ -383,8 +385,8 @@ def __init__(self, max_interval=None, schedule_filename=None,
self.schedule_filename = schedule_filename or \
app.conf.CELERYBEAT_SCHEDULE_FILENAME
- self._is_shutdown = threading.Event()
- self._is_stopped = threading.Event()
+ self._is_shutdown = Event()
+ self._is_stopped = Event()
def start(self, embedded_process=False):
info("Celerybeat: Starting...")
@@ -397,7 +399,7 @@ def start(self, embedded_process=False):
platforms.set_process_title("celerybeat")
try:
- while not self._is_shutdown.isSet():
+ while not self._is_shutdown.is_set():
interval = self.scheduler.tick()
debug("Celerybeat: Waking up %s.",
humanize_seconds(interval, prefix="in "))
@@ -430,14 +432,14 @@ def scheduler(self):
return self.get_scheduler()
-class _Threaded(threading.Thread):
+class _Threaded(Thread):
"""Embedded task scheduler using threading."""
def __init__(self, *args, **kwargs):
super(_Threaded, self).__init__()
self.service = Service(*args, **kwargs)
- self.setDaemon(True)
- self.setName("Beat")
+ self.daemon = True
+ self.name = "Beat"
def run(self):
self.service.start()
@@ -446,16 +448,12 @@ def stop(self):
self.service.stop(wait=True)
-supports_fork = True
try:
- from billiard._ext import _billiard
- supports_fork = True if _billiard else False
-except ImportError:
- supports_fork = False
-
-if supports_fork:
- class _Process(Process):
- """Embedded task scheduler using multiprocessing."""
+ ensure_multiprocessing()
+except NotImplementedError: # pragma: no cover
+ _Process = None
+else:
+ class _Process(Process): # noqa
def __init__(self, *args, **kwargs):
super(_Process, self).__init__()
@@ -469,8 +467,6 @@ def run(self):
def stop(self):
self.service.stop()
self.terminate()
-else:
- _Process = None
def EmbeddedService(*args, **kwargs):
@@ -485,5 +481,4 @@ def EmbeddedService(*args, **kwargs):
# in reasonable time.
kwargs.setdefault("max_interval", 1)
return _Threaded(*args, **kwargs)
-
return _Process(*args, **kwargs)
View
2 celery/bin/base.py
@@ -139,7 +139,7 @@ def parse_options(self, prog_name, arguments):
# Don't want to load configuration to just print the version,
# so we handle --version manually here.
if "--version" in arguments:
- print(self.version)
+ sys.stdout.write("%s\n" % self.version)
sys.exit(0)
parser = self.create_parser(prog_name)
return parser.parse_args(arguments)
View
2 celery/bin/celeryd.py
@@ -188,7 +188,7 @@ def main():
# Fix for setuptools generated scripts, so that it will
# work with multiprocessing fork emulation.
# (see multiprocessing.forking.get_preparation_data())
- if __name__ != "__main__":
+ if __name__ != "__main__": # pragma: no cover
sys.modules["__main__"] = sys.modules[__name__]
freeze_support()
worker = WorkerCommand()
View
2 celery/events/state.py
@@ -299,7 +299,7 @@ def _dispatch_event(self, event):
def itertasks(self, limit=None):
for index, row in enumerate(self.tasks.iteritems()):
yield row
- if limit and index >= limit:
+ if limit and index + 1 >= limit:
break
def tasks_by_timestamp(self, limit=None):
View
2 celery/result.py
@@ -632,7 +632,7 @@ def _get_taskset_id(self):
return self.id
def _set_taskset_id(self, id):
- self.taskset_id = id
+ self.id = id
taskset_id = property(_get_taskset_id, _set_taskset_id)
View
2 celery/task/http.py
@@ -47,7 +47,7 @@ def maybe_utf8(value):
return value
-if sys.version_info >= (3, 0):
+if sys.version_info[0] == 3: # pragma: no cover
def utf8dict(tup):
if not isinstance(tup, dict):
View
7 celery/tests/__init__.py
@@ -1,8 +1,10 @@
from __future__ import absolute_import
+from __future__ import with_statement
import logging
import os
import sys
+import warnings
from importlib import import_module
@@ -77,4 +79,7 @@ def import_all_modules(name=__name__, file=__file__,
if os.environ.get("COVER_ALL_MODULES") or "--with-coverage3" in sys.argv:
- import_all_modules()
+ from celery.tests.utils import catch_warnings
+ with catch_warnings(record=True):
+ import_all_modules()
+ warnings.resetwarnings()
View
31 celery/tests/app/test_amqp.py
@@ -29,6 +29,21 @@ def test__exit__(self):
pass
publisher.release.assert_called_with()
+ def test_declare(self):
+ publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+ publisher.exchange.name = "foo"
+ publisher.declare()
+ publisher.exchange.name = None
+ publisher.declare()
+
+ def test_exit_AttributeError(self):
+ publisher = self.app.amqp.TaskPublisher(self.app.broker_connection())
+ publisher.close = Mock()
+ publisher.release = Mock()
+ publisher.release.side_effect = AttributeError()
+ publisher.__exit__()
+ publisher.close.assert_called_with()
+
def test_ensure_declare_queue(self, q="x1242112"):
publisher = self.app.amqp.TaskPublisher(Mock())
self.app.amqp.queues.add(q, q, q)
@@ -103,3 +118,19 @@ def test_setup(self):
r2.release()
finally:
self.app.conf.BROKER_POOL_LIMIT = L
+
+
+class test_Queues(AppCase):
+
+ def test_queues_format(self):
+ prev, self.app.amqp.queues._consume_from = \
+ self.app.amqp.queues._consume_from, {}
+ try:
+ self.assertEqual(self.app.amqp.queues.format(), "")
+ finally:
+ self.app.amqp.queues._consume_from = prev
+
+ def test_with_defaults(self):
+ self.assertEqual(
+ self.app.amqp.queues.with_defaults(None, "celery", "direct"),
+ {})
View
6 celery/tests/app/test_app.py
@@ -164,7 +164,11 @@ def test_compat_setting_CARROT_BACKEND(self):
self.assertEqual(self.app.conf.BROKER_TRANSPORT, "set_by_us")
def test_WorkController(self):
- x = self.app.Worker()
+ x = self.app.WorkController
+ self.assertIs(x.app, self.app)
+
+ def test_Worker(self):
+ x = self.app.Worker
self.assertIs(x.app, self.app)
def test_AsyncResult(self):
View
147 celery/tests/app/test_beat.py
@@ -1,15 +1,19 @@
from __future__ import absolute_import
+from __future__ import with_statement
+
+import errno
from datetime import datetime, timedelta
-from mock import patch
+from mock import Mock, call, patch
from nose import SkipTest
from celery import beat
+from celery import task
from celery.result import AsyncResult
from celery.schedules import schedule
from celery.task.base import Task
from celery.utils import uuid
-from celery.tests.utils import Case
+from celery.tests.utils import Case, patch_settings
class Object(object):
@@ -159,10 +163,69 @@ def apply_async(cls, *args, **kwargs):
scheduler.apply_async(scheduler.Entry(task=MockTask.name))
self.assertTrue(through_task[0])
+ def test_apply_async_should_not_sync(self):
+
+ @task
+ def not_sync():
+ pass
+ not_sync.apply_async = Mock()
+
+ s = mScheduler()
+ s._do_sync = Mock()
+ s.should_sync = Mock()
+ s.should_sync.return_value = True
+ s.apply_async(s.Entry(task=not_sync.name))
+ s._do_sync.assert_called_with()
+
+ s._do_sync = Mock()
+ s.should_sync.return_value = False
+ s.apply_async(s.Entry(task=not_sync.name))
+ self.assertFalse(s._do_sync.called)
+
+ @patch("celery.app.base.Celery.send_task")
+ def test_send_task(self, send_task):
+ b = beat.Scheduler()
+ b.send_task("tasks.add", countdown=10)
+ send_task.assert_called_with("tasks.add", countdown=10)
+
def test_info(self):
scheduler = mScheduler()
self.assertIsInstance(scheduler.info, basestring)
+ def test_maybe_entry(self):
+ s = mScheduler()
+ entry = s.Entry(name="add every", task="tasks.add")
+ self.assertIs(s._maybe_entry(entry.name, entry), entry)
+ self.assertTrue(s._maybe_entry("add every", {
+ "task": "tasks.add",
+ }))
+
+ def test_set_schedule(self):
+ s = mScheduler()
+ s.schedule = {"foo": "bar"}
+ self.assertEqual(s.data, {"foo": "bar"})
+
+ @patch("kombu.connection.Connection.ensure_connection")
+ def test_ensure_connection_error_handler(self, ensure):
+ s = mScheduler()
+ self.assertTrue(s._ensure_connected())
+ self.assertTrue(ensure.called)
+ callback = ensure.call_args[0][0]
+
+ callback(KeyError(), 5)
+
+ def test_install_default_entries(self):
+ with patch_settings(CELERY_TASK_RESULT_EXPIRES=None,
+ CELERYBEAT_SCHEDULE={}):
+ s = mScheduler()
+ s.install_default_entries({})
+ self.assertNotIn("celery.backend_cleanup", s.data)
+ with patch_settings(CELERY_TASK_RESULT_EXPIRES=30,
+ CELERYBEAT_SCHEDULE={}):
+ s = mScheduler()
+ s.install_default_entries({})
+ self.assertIn("celery.backend_cleanup", s.data)
+
def test_due_tick(self):
scheduler = mScheduler()
scheduler.add(name="test_due_tick",
@@ -233,25 +296,73 @@ def test_merge_inplace(self):
self.assertEqual(a.schedule["bar"].schedule._next_run_at, 40)
+def create_persistent_scheduler(shelv=None):
+ if shelv is None:
+ shelv = MockShelve()
+
+ class MockPersistentScheduler(beat.PersistentScheduler):
+ sh = shelv
+ persistence = Object()
+ persistence.open = lambda *a, **kw: shelv
+ tick_raises_exit = False
+ shutdown_service = None
+
+ def tick(self):
+ if self.tick_raises_exit:
+ raise SystemExit()
+ if self.shutdown_service:
+ self.shutdown_service._is_shutdown.set()
+ return 0.0
+
+ return MockPersistentScheduler, shelv
+
+
+class test_PersistentScheduler(Case):
+
+ @patch("os.remove")
+ def test_remove_db(self, remove):
+ s = create_persistent_scheduler()[0](schedule_filename="schedule")
+ s._remove_db()
+ remove.assert_has_calls(
+ [call("schedule" + suffix) for suffix in s.known_suffixes]
+ )
+ err = OSError()
+ err.errno = errno.ENOENT
+ remove.side_effect = err
+ s._remove_db()
+ err.errno = errno.EPERM
+ with self.assertRaises(OSError):
+ s._remove_db()
+
+ def test_setup_schedule(self):
+ s = create_persistent_scheduler()[0](schedule_filename="schedule")
+ opens = s.persistence.open = Mock()
+ s._remove_db = Mock()
+
+ def effect(*args, **kwargs):
+ if opens.call_count > 1:
+ return s.sh
+ raise OSError()
+ opens.side_effect = effect
+ s.setup_schedule()
+ s._remove_db.assert_called_with()
+
+ s._store = {"__version__": 1}
+ s.setup_schedule()
+
+ def test_get_schedule(self):
+ s = create_persistent_scheduler()[0](schedule_filename="schedule")
+ s._store = {"entries": {}}
+ s.schedule = {"foo": "bar"}
+ self.assertDictEqual(s.schedule, {"foo": "bar"})
+ self.assertDictEqual(s._store["entries"], s.schedule)
+
+
class test_Service(Case):
def get_service(self):
- sh = MockShelve()
-
- class PersistentScheduler(beat.PersistentScheduler):
- persistence = Object()
- persistence.open = lambda *a, **kw: sh
- tick_raises_exit = False
- shutdown_service = None
-
- def tick(self):
- if self.tick_raises_exit:
- raise SystemExit()
- if self.shutdown_service:
- self.shutdown_service._is_shutdown.set()
- return 0.0
-
- return beat.Service(scheduler_cls=PersistentScheduler), sh
+ Scheduler, mock_shelve = create_persistent_scheduler()
+ return beat.Service(scheduler_cls=Scheduler), mock_shelve
def test_start(self):
s, sh = self.get_service()
View
15 celery/tests/app/test_control.py
@@ -117,6 +117,16 @@ def test_cancel_consumer(self):
self.i.cancel_consumer("foo")
self.assertIn("cancel_consumer", MockMailbox.sent)
+ @with_mock_broadcast
+ def test_active_queues(self):
+ self.i.active_queues()
+ self.assertIn("active_queues", MockMailbox.sent)
+
+ @with_mock_broadcast
+ def test_report(self):
+ self.i.report()
+ self.assertIn("report", MockMailbox.sent)
+
class test_Broadcast(Case):
@@ -154,6 +164,11 @@ def test_rate_limit(self):
self.assertIn("rate_limit", MockMailbox.sent)
@with_mock_broadcast
+ def test_time_limit(self):
+ self.control.time_limit(mytask.name, soft=10, hard=20)
+ self.assertIn("time_limit", MockMailbox.sent)
+
+ @with_mock_broadcast
def test_revoke(self):
self.control.revoke("foozbaaz")
self.assertIn("revoke", MockMailbox.sent)
View
12 celery/tests/app/test_defaults.py
@@ -4,6 +4,7 @@
import sys
from importlib import import_module
+from mock import Mock, patch
from celery.tests.utils import Case, pypy_version, sys_platform
@@ -17,6 +18,10 @@ def tearDown(self):
if self._prev:
sys.modules["celery.app.defaults"] = self._prev
+ def test_any(self):
+ val = object()
+ self.assertIs(self.defaults.Option.typemap["any"](val), val)
+
def test_default_pool_pypy_14(self):
with sys_platform("darwin"):
with pypy_version((1, 4, 0)):
@@ -27,6 +32,13 @@ def test_default_pool_pypy_15(self):
with pypy_version((1, 5, 0)):
self.assertEqual(self.defaults.DEFAULT_POOL, "processes")
+ def test_deprecated(self):
+ source = Mock()
+ source.BROKER_INSIST = True
+ with patch("celery.utils.warn_deprecated") as warn:
+ self.defaults.find_deprecated_settings(source)
+ self.assertTrue(warn.called)
+
def test_default_pool_jython(self):
with sys_platform("java 1.6.51"):
self.assertEqual(self.defaults.DEFAULT_POOL, "threads")
View
18 celery/tests/app/test_loaders.py
@@ -4,7 +4,7 @@
import os
import sys
-from mock import patch
+from mock import Mock, patch
from celery import loaders
from celery.app import app_or_default
@@ -83,6 +83,17 @@ def test_handlers_pass(self):
def test_import_task_module(self):
self.assertEqual(sys, self.loader.import_task_module("sys"))
+ def test_init_worker_process(self):
+ self.loader.on_worker_process_init()
+ m = self.loader.on_worker_process_init = Mock()
+ self.loader.init_worker_process()
+ m.assert_called_with()
+
+ def test_config_from_object_module(self):
+ self.loader.import_from_cwd = Mock()
+ self.loader.config_from_object("module_name")
+ self.loader.import_from_cwd.assert_called_with("module_name")
+
def test_conf_property(self):
self.assertEqual(self.loader.conf["foo"], "bar")
self.assertEqual(self.loader._conf["foo"], "bar")
@@ -181,7 +192,7 @@ class ConfigModule(ModuleType):
celeryconfig.CELERY_IMPORTS = ("os", "sys")
configname = os.environ.get("CELERY_CONFIG_MODULE") or "celeryconfig"
- prevconfig = sys.modules[configname]
+ prevconfig = sys.modules.get(configname)
sys.modules[configname] = celeryconfig
try:
l = default.Loader()
@@ -191,7 +202,8 @@ class ConfigModule(ModuleType):
self.assertTupleEqual(settings.CELERY_IMPORTS, ("os", "sys"))
l.on_worker_init()
finally:
- sys.modules[configname] = prevconfig
+ if prevconfig:
+ sys.modules[configname] = prevconfig
def test_import_from_cwd(self):
l = default.Loader()
View
45 celery/tests/app/test_log.py
@@ -8,17 +8,39 @@
from mock import patch, Mock
from celery import current_app
-from celery.app.log import Logging
+from celery import signals
+from celery.app.log import Logging, TaskFormatter
from celery.utils.log import LoggingProxy
from celery.utils import uuid
-from celery.utils.log import get_logger, ColorFormatter, logger as base_logger
+from celery.utils.log import (
+ get_logger,
+ ColorFormatter,
+ logger as base_logger,
+)
from celery.tests.utils import (
- Case, override_stdouts, wrap_logger, get_handlers,
+ AppCase, Case, override_stdouts, wrap_logger, get_handlers,
)
log = current_app.log
+class test_TaskFormatter(Case):
+
+ def test_no_task(self):
+ class Record(object):
+ msg = "hello world"
+ levelname = "info"
+ exc_text = exc_info = None
+
+ def getMessage(self):
+ return self.msg
+ record = Record()
+ x = TaskFormatter()
+ x.format(record)
+ self.assertEqual(record.task_name, "???")
+ self.assertEqual(record.task_id, "???")
+
+
class test_ColorFormatter(Case):
@patch("celery.utils.log.safe_str")
@@ -71,11 +93,12 @@ def test_format_raises_no_color(self, safe_str):
self.assertEqual(safe_str.call_count, 1)
-class test_default_logger(Case):
+class test_default_logger(AppCase):
- def setUp(self):
+ def setup(self):
self.setup_logger = log.setup_logger
self.get_logger = lambda n=None: get_logger(n) if n else logging.root
+ signals.setup_logging.receivers[:] = []
Logging._setup = False
def test_get_logger_sets_parent(self):
@@ -86,6 +109,14 @@ def test_get_logger_root(self):
logger = get_logger(base_logger.name)
self.assertIs(logger.parent, logging.root)
+ def test_setup_logging_subsystem_misc(self):
+ log.setup_logging_subsystem(loglevel=None)
+ self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = True
+ try:
+ log.setup_logging_subsystem()
+ finally:
+ self.app.conf.CELERYD_HIJACK_ROOT_LOGGER = False
+
def test_setup_logging_subsystem_colorize(self):
log.setup_logging_subsystem(colorize=None)
log.setup_logging_subsystem(colorize=True)
@@ -149,6 +180,8 @@ def test_redirect_stdouts(self):
log.redirect_stdouts_to_logger(logger, loglevel=logging.ERROR)
logger.error("foo")
self.assertIn("foo", sio.getvalue())
+ log.redirect_stdouts_to_logger(logger, stdout=False,
+ stderr=False)
finally:
sys.stdout, sys.stderr = sys.__stdout__, sys.__stderr__
@@ -186,7 +219,7 @@ def test_logging_proxy_recurse_protection(self):
class test_task_logger(test_default_logger):
- def setUp(self):
+ def setup(self):
logger = self.logger = get_logger("celery.task")
logger.handlers = []
logging.root.manager.loggerDict.pop(logger.name, None)
View
4 celery/tests/backends/test_amqp.py
@@ -42,6 +42,10 @@ def test_mark_as_done(self):
self.assertTrue(tb2._cache.get(tid))
self.assertTrue(tb2.get_result(tid), 42)
+ def test_revive(self):
+ tb = self.create_backend()
+ tb.revive(None)
+
def test_is_pickled(self):
tb1 = self.create_backend()
tb2 = self.create_backend()
View
24 celery/tests/backends/test_base.py
@@ -58,6 +58,10 @@ def test__forget(self):
with self.assertRaises(NotImplementedError):
b.forget("SOMExx-N0Nex1stant-IDxx-")
+ def test_get_children(self):
+ with self.assertRaises(NotImplementedError):
+ b.get_children("SOMExx-N0Nex1stant-IDxx-")
+
def test_store_result(self):
with self.assertRaises(NotImplementedError):
b.store_result("SOMExx-N0nex1stant-IDxx-", 42, states.SUCCESS)
@@ -98,6 +102,9 @@ def test_forget(self):
with self.assertRaises(NotImplementedError):
b.forget("SOMExx-N0nex1stant-IDxx-")
+ def test_on_chord_part_return(self):
+ b.on_chord_part_return(None)
+
def test_on_chord_apply(self, unlock="celery.chord_unlock"):
p, current_app.tasks[unlock] = current_app.tasks.get(unlock), Mock()
try:
@@ -138,6 +145,7 @@ def test_unpickleable(self):
def test_impossible(self):
x = b.prepare_exception(Impossible())
self.assertIsInstance(x, UnpickleableExceptionWrapper)
+ self.assertTrue(str(x))
y = b.exception_to_python(x)
self.assertEqual(y.__class__.__name__, "Impossible")
if sys.version_info < (2, 5):
@@ -202,6 +210,14 @@ def test_delete_taskset(self):
self.b.delete_taskset("can-delete")
self.assertNotIn("can-delete", self.b._data)
+ def test_prepare_exception_json(self):
+ x = DictBackend(serializer="json")
+ e = x.prepare_exception(KeyError("foo"))
+ self.assertIn("exc_type", e)
+ e = x.exception_to_python(e)
+ self.assertEqual(e.__class__.__name__, "KeyError")
+ self.assertEqual(str(e), "'foo'")
+
def test_save_taskset(self):
b = BaseDictBackend()
b._save_taskset = Mock()
@@ -237,6 +253,10 @@ class test_KeyValueStoreBackend(Case):
def setUp(self):
self.b = KVBackend()
+ def test_on_chord_part_return(self):
+ assert not self.b.implements_incr
+ self.b.on_chord_part_return(None)
+
def test_get_store_delete_result(self):
tid = uuid()
self.b.mark_as_done(tid, "Hello world")
@@ -290,6 +310,10 @@ def test_set(self):
with self.assertRaises(NotImplementedError):
KeyValueStoreBackend().set("a", 1)
+ def test_incr(self):
+ with self.assertRaises(NotImplementedError):
+ KeyValueStoreBackend().incr("a")
+
def test_cleanup(self):
self.assertFalse(KeyValueStoreBackend().cleanup())
View
8 celery/tests/backends/test_database.py
@@ -6,6 +6,7 @@
from datetime import datetime
from nose import SkipTest
+from pickle import loads, dumps
from celery import states
from celery.app import app_or_default
@@ -151,14 +152,19 @@ def test_forget(self):
tb = DatabaseBackend(backend="memory://")
tid = uuid()
tb.mark_as_done(tid, {"foo": "bar"})
- x = AsyncResult(tid)
+ tb.mark_as_done(tid, {"foo": "bar"})
+ x = AsyncResult(tid, backend=tb)
x.forget()
self.assertIsNone(x.result)
def test_process_cleanup(self):
tb = DatabaseBackend()
tb.process_cleanup()
+ def test_reduce(self):
+ tb = DatabaseBackend()
+ self.assertTrue(loads(dumps(tb)))
+
def test_save__restore__delete_taskset(self):
tb = DatabaseBackend()
View
2 celery/tests/backends/test_pyredis_compat.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from nose import SkipTest
+from pickle import loads, dumps
from celery.exceptions import ImproperlyConfigured
from celery.tests.utils import Case
@@ -19,3 +20,4 @@ def test_constructor(self):
self.assertEqual(x.redis_port, 312)
self.assertEqual(x.redis_db, 1)
self.assertEqual(x.redis_password, "foo")
+ self.assertTrue(loads(dumps(x)))
View
33 celery/tests/backends/test_redis.py
@@ -1,11 +1,16 @@
from __future__ import absolute_import
+from __future__ import with_statement
from datetime import timedelta
from mock import Mock, patch
+from nose import SkipTest
+from pickle import loads, dumps
from celery import current_app
from celery import states
+from celery.datastructures import AttributeDict
+from celery.exceptions import ImproperlyConfigured
from celery.result import AsyncResult
from celery.task import subtask
from celery.utils import cached_property, uuid
@@ -81,6 +86,34 @@ def client(self):
self.MockBackend = MockBackend
+ def test_reduce(self):
+ try:
+ from celery.backends.redis import RedisBackend
+ x = RedisBackend()
+ self.assertTrue(loads(dumps(x)))
+ except ImportError:
+ raise SkipTest("redis not installed")
+
+ def test_no_redis(self):
+ self.MockBackend.redis = None
+ with self.assertRaises(ImproperlyConfigured):
+ self.MockBackend()
+
+ def test_url(self):
+ x = self.MockBackend("redis://foobar//1")
+ self.assertEqual(x.host, "foobar")
+ self.assertEqual(x.db, "1")
+
+ def test_conf_raises_KeyError(self):
+ conf = AttributeDict({"CELERY_RESULT_SERIALIZER": "json",
+ "CELERY_MAX_CACHED_RESULTS": 1,
+ "CELERY_TASK_RESULT_EXPIRES": None})
+ prev, current_app.conf = current_app.conf, conf
+ try:
+ self.MockBackend()
+ finally:
+ current_app.conf = prev
+
def test_expires_defaults_to_config(self):
conf = current_app.conf
prev = conf.CELERY_TASK_RESULT_EXPIRES
View
32 celery/tests/bin/test_base.py
@@ -3,7 +3,9 @@
import os
-from celery.bin.base import Command
+from mock import patch
+
+from celery.bin.base import Command, Option
from celery.tests.utils import AppCase, override_stdouts
@@ -41,6 +43,13 @@ def test_run_interface(self):
with self.assertRaises(NotImplementedError):
Command().run()
+ @patch("sys.stdout")
+ def test_parse_options_version_only(self, stdout):
+ cmd = Command()
+ with self.assertRaises(SystemExit):
+ cmd.parse_options("prog", ["--version"])
+ stdout.write.assert_called_with(cmd.version + "\n")
+
def test_execute_from_commandline(self):
cmd = MockCommand()
args1, kwargs1 = cmd.execute_from_commandline() # sys.argv
@@ -71,6 +80,21 @@ def test_with_custom_config_module(self):
finally:
if prev:
os.environ["CELERY_CONFIG_MODULE"] = prev
+ else:
+ os.environ.pop("CELERY_CONFIG_MODULE", None)
+
+ def test_with_custom_broker(self):
+ prev = os.environ.pop("CELERY_BROKER_URL", None)
+ try:
+ cmd = MockCommand()
+ cmd.setup_app_from_commandline(["--broker=xyzza://"])
+ self.assertEqual(os.environ.get("CELERY_BROKER_URL"),
+ "xyzza://")
+ finally:
+ if prev:
+ os.environ["CELERY_BROKER_URL"] = prev
+ else:
+ os.environ.pop("CELERY_BROKER_URL", None)
def test_with_custom_app(self):
cmd = MockCommand()
@@ -89,3 +113,9 @@ def test_with_cmdline_config(self):
self.assertEqual(cmd.app.conf.BROKER_HOST, "broker.example.com")
self.assertEqual(cmd.app.conf.CELERYD_PREFETCH_MULTIPLIER, 100)
self.assertListEqual(rest, ["--loglevel=INFO"])
+
+ def test_parse_preload_options_shortopt(self):
+ cmd = Command()
+ cmd.preload_options = (Option("-s", action="store", dest="silent"), )
+ acc = cmd.parse_preload_options(["-s", "yes"])
+ self.assertEqual(acc.get("silent"), "yes")
View
6 celery/tests/bin/test_celerybeat.py
@@ -182,6 +182,12 @@ def test_detach(self):
self.assertTrue(MockDaemonContext.opened)
self.assertTrue(MockDaemonContext.closed)
+ @patch("os.chdir")
+ def test_prepare_preload_options(self, chdir):
+ cmd = celerybeat_bin.BeatCommand()
+ cmd.prepare_preload_options({"working_directory": "/opt/Project"})
+ chdir.assert_called_with("/opt/Project")
+
def test_parse_options(self):
cmd = celerybeat_bin.BeatCommand()
cmd.app = app_or_default()
View
20 celery/tests/bin/test_celeryd.py
@@ -7,7 +7,7 @@
from functools import wraps
-from mock import patch
+from mock import Mock, patch
from nose import SkipTest
from billiard import current_process
@@ -65,6 +65,16 @@ def test_queues_string(self):
self.assertTrue("foo" in celery.amqp.queues)
@disable_stdouts
+ def test_cpu_count(self):
+ celery = Celery(set_as_current=False)
+ with patch("celery.apps.worker.cpu_count") as cpu_count:
+ cpu_count.side_effect = NotImplementedError()
+ worker = celery.Worker(concurrency=None)
+ self.assertEqual(worker.concurrency, 2)
+ worker = celery.Worker(concurrency=5)
+ self.assertEqual(worker.concurrency, 5)
+
+ @disable_stdouts
def test_windows_B_option(self):
celery = Celery(set_as_current=False)
celery.IS_WINDOWS = True
@@ -139,6 +149,14 @@ def test_run(self):
worker.init_loader()
worker.run()
+ prev, cd.IGNORE_ERRORS = cd.IGNORE_ERRORS, (KeyError, )
+ try:
+ worker.run_worker = Mock()
+ worker.run_worker.side_effect = KeyError()
+ worker.run()
+ finally:
+ cd.IGNORE_ERRORS = prev
+
@disable_stdouts
def test_purge_messages(self):
self.Worker().purge_messages()
View
6 celery/tests/bin/test_celeryd_multi.py
@@ -312,7 +312,7 @@ def test_shutdown_nodes(self, slepp, gethostname, PIDFile):
self.prepare_pidfile_for_getpids(PIDFile)
self.assertIsNone(self.t.shutdown_nodes([]))
self.t.signal_node = Mock()
- self.t.node_alive = Mock()
+ node_alive = self.t.node_alive = Mock()
self.t.node_alive.return_value = False
callback = Mock()
@@ -324,11 +324,9 @@ def test_shutdown_nodes(self, slepp, gethostname, PIDFile):
self.t.signal_node.return_value = False
self.assertTrue(callback.called)
self.t.stop(["foo", "bar", "baz"], "celeryd", callback=None)
- calls = [0]
def on_node_alive(pid):
- calls[0] += 1
- if calls[0] > 3:
+ if node_alive.call_count > 4:
return True
return False
self.t.signal_node.return_value = True
View
21 celery/tests/bin/test_celeryev.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import
+from __future__ import with_statement
from nose import SkipTest
+from mock import patch as mpatch
from celery.app import app_or_default
from celery.bin import celeryev
@@ -32,6 +34,14 @@ def test_run_dump(self):
self.assertEqual(self.ev.run(dump=True), "me dumper, you?")
self.assertIn("celeryev:dump", proctitle.last[0])
+ @mpatch("os.chdir")
+ def test_prepare_preload_options(self, chdir):
+ self.ev.prepare_preload_options({"working_directory": "/opt/Project"})
+ chdir.assert_called_with("/opt/Project")
+ chdir.called = False
+ self.ev.prepare_preload_options({})
+ self.assertFalse(chdir.called)
+
def test_run_top(self):
try:
import curses # noqa
@@ -56,6 +66,17 @@ def test_run_cam(self):
self.assertEqual(kw["logfile"], "logfile")
self.assertIn("celeryev:cam", proctitle.last[0])
+ @mpatch("celery.events.snapshot.evcam")
+ @mpatch("celery.bin.celeryev.detached")
+ def test_run_cam_detached(self, detached, evcam):
+ self.ev.prog_name = "celeryev"
+ self.ev.run_evcam("myapp.Camera", detach=True)
+ self.assertTrue(detached.called)
+ self.assertTrue(evcam.called)
+
+ def test_get_options(self):
+ self.assertTrue(self.ev.get_options())
+
@patch("celery.bin.celeryev", "EvCommand", MockCommand)
def test_main(self):
MockCommand.executed = []
View
16 celery/tests/concurrency/test_concurrency.py
@@ -47,6 +47,14 @@ def callback(*args):
{"target": (3, (8, 16)),
"callback": (4, (42, ))})
+ def test_does_not_debug(self):
+ x = BasePool(10)
+ x._does_debug = False
+ x.apply_async(object)
+
+ def test_num_processes(self):
+ self.assertEqual(BasePool(7).num_processes, 7)
+
def test_interface_on_start(self):
BasePool(10).on_start()
@@ -69,3 +77,11 @@ def test_restart(self):
p = BasePool(10)
with self.assertRaises(NotImplementedError):
p.restart()
+
+ def test_interface_on_terminate(self):
+ p = BasePool(10)
+ p.on_terminate()
+
+ def test_interface_terminate_job(self):
+ with self.assertRaises(NotImplementedError):
+ BasePool(10).terminate_job(101)
View
27 celery/tests/concurrency/test_processes.py
@@ -76,9 +76,9 @@ class MockPool(object):
def __init__(self, *args, **kwargs):
self.started = True
self._state = mp.RUN
- self.processes = kwargs.get("processes")
- self._pool = [Object(pid=i) for i in range(self.processes)]
- self._current_proc = cycle(xrange(self.processes)).next
+ self._processes = kwargs.get("processes")
+ self._pool = [Object(pid=i) for i in range(self._processes)]
+ self._current_proc = cycle(xrange(self._processes)).next
def close(self):
self.closed = True
@@ -91,10 +91,10 @@ def terminate(self):
self.terminated = True
def grow(self, n=1):
- self.processes += n
+ self._processes += n
def shrink(self, n=1):
- self.processes -= n
+ self._processes -= n
def apply_async(self, *args, **kwargs):
pass
@@ -179,11 +179,11 @@ def _do_test(_kill):
def test_grow_shrink(self):
pool = TaskPool(10)
pool.start()
- self.assertEqual(pool._pool.processes, 10)
+ self.assertEqual(pool._pool._processes, 10)
pool.grow()
- self.assertEqual(pool._pool.processes, 11)
+ self.assertEqual(pool._pool._processes, 11)
pool.shrink(2)
- self.assertEqual(pool._pool.processes, 9)
+ self.assertEqual(pool._pool._processes, 9)
def test_info(self):
pool = TaskPool(10)
@@ -197,6 +197,17 @@ def test_info(self):
self.assertIsNone(info["max-tasks-per-child"])
self.assertEqual(info["timeouts"], (5, 10))
+ def test_num_processes(self):
+ pool = TaskPool(7)
+ pool.start()
+ self.assertEqual(pool.num_processes, 7)
+
+ def test_restart_pool(self):
+ pool = TaskPool()
+ pool._pool = Mock()
+ pool.restart()
+ pool._pool.restart.assert_called_with()
+
def test_restart(self):
raise SkipTest("functional test")
View
51 celery/tests/events/test_events.py
@@ -3,9 +3,10 @@
import socket
+from mock import Mock
+
from celery import events
-from celery.app import app_or_default
-from celery.tests.utils import Case
+from celery.tests.utils import AppCase
class MockProducer(object):
@@ -29,18 +30,15 @@ def has_event(self, kind):
return False
-class test_Event(Case):
+class test_Event(AppCase):
def test_constructor(self):
event = events.Event("world war II")
self.assertEqual(event["type"], "world war II")
self.assertTrue(event["timestamp"])
-class test_EventDispatcher(Case):
-
- def setUp(self):
- self.app = app_or_default()
+class test_EventDispatcher(AppCase):
def test_send(self):
producer = MockProducer()
@@ -67,6 +65,30 @@ def test_send(self):
for ev in evs:
self.assertTrue(producer.has_event(ev))
+ buf = eventer._outbound_buffer = Mock()
+ buf.popleft.side_effect = IndexError()
+ eventer.flush()
+
+ def test_enter_exit(self):
+ with self.app.broker_connection() as conn:
+ d = self.app.events.Dispatcher(conn)
+ d.close = Mock()
+ with d as _d:
+ self.assertTrue(_d)
+ d.close.assert_called_with()
+
+ def test_enable_disable_callbacks(self):
+ on_enable = Mock()
+ on_disable = Mock()
+ with self.app.broker_connection() as conn:
+ with self.app.events.Dispatcher(conn, enabled=False) as d:
+ d.on_enabled.add(on_enable)
+ d.on_disabled.add(on_disable)
+ d.enable()
+ on_enable.assert_called_with()
+ d.disable()
+ on_disable.assert_called_with()
+
def test_enabled_disable(self):
connection = self.app.broker_connection()
channel = connection.channel()
@@ -99,10 +121,7 @@ def test_enabled_disable(self):
connection.close()
-class test_EventReceiver(Case):
-
- def setUp(self):
- self.app = app_or_default()
+class test_EventReceiver(AppCase):
def test_process(self):
@@ -181,11 +200,13 @@ def handler(event):
connection.close()
-class test_misc(Case):
-
- def setUp(self):
- self.app = app_or_default()
+class test_misc(AppCase):
def test_State(self):
state = self.app.events.State()
self.assertDictEqual(dict(state.workers), {})
+
+ def test_default_dispatcher(self):
+ with self.app.events.default_dispatcher() as d:
+ self.assertTrue(d)
+ self.assertTrue(d.connection)
View
21 celery/tests/events/test_snapshot.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import
from __future__ import with_statement
+from mock import patch
+
from celery.app import app_or_default
from celery.events import Events
from celery.events.snapshot import Polaroid, evcam
@@ -114,11 +116,24 @@ def Receiver(self, *args, **kwargs):
def setUp(self):
self.app = app_or_default()
- self.app.events = self.MockEvents()
+ self.prev, self.app.events = self.app.events, self.MockEvents()
+
+ def tearDown(self):
+ self.app.events = self.prev
def test_evcam(self):
evcam(Polaroid, timer=timer)
evcam(Polaroid, timer=timer, loglevel="CRITICAL")
self.MockReceiver.raise_keyboard_interrupt = True
- with self.assertRaises(SystemExit):
- evcam(Polaroid, timer=timer)
+ try:
+ with self.assertRaises(SystemExit):
+ evcam(Polaroid, timer=timer)
+ finally:
+ self.MockReceiver.raise_keyboard_interrupt = False
+
+ @patch("atexit.register")
+ @patch("celery.platforms.create_pidlock")
+ def test_evcam_pidfile(self, create_pidlock, atexit):
+ evcam(Polaroid, timer=timer, pidfile="/var/pid")
+ self.assertTrue(atexit.called)
+ create_pidlock.assert_called_with("/var/pid")
View
5 celery/tests/events/test_state.py
@@ -172,6 +172,11 @@ def test_worker_online_offline(self):
self.assertFalse(r.state.alive_workers())
self.assertFalse(r.state.workers["utest1"].alive)
+ def test_itertasks(self):
+ s = State()
+ s.tasks = {"a": "a", "b": "b", "c": "c", "d": "d"}
+ self.assertEqual(len(list(s.itertasks(limit=2))), 2)
+
def test_worker_heartbeat_expire(self):
r = ev_worker_heartbeats(State())
r.next()
View
16 celery/tests/slow/test_buckets.py
@@ -148,26 +148,22 @@ def test_get_block(self, sleep):
x = buckets.TaskBucket(task_registry=self.registry)
x.not_empty = Mock()
get = x._get = Mock()
- calls = [0]
remaining = [0]
def effect():
- try:
- if not calls[0]:
- raise Empty()
- rem = remaining[0]
- remaining[0] = 0
- return rem, Mock()
- finally:
- calls[0] += 1
+ if get.call_count == 1:
+ raise Empty()
+ rem = remaining[0]
+ remaining[0] = 0
+ return rem, Mock()
get.side_effect = effect
with mock_context(Mock()) as context:
x.not_empty = context
x.wait = Mock()
x.get(block=True)
- calls[0] = 0
+ get.reset()
remaining[0] = 1
x.get(block=True)
View
8 celery/tests/tasks/test_registry.py
@@ -23,6 +23,9 @@ def run(self, **kwargs):
class test_TaskRegistry(Case):
+ def test_NotRegistered_str(self):
+ self.assertTrue(repr(TaskRegistry.NotRegistered("tasks.add")))
+
def assertRegisterUnregisterCls(self, r, task):
with self.assertRaises(r.NotRegistered):
r.unregister(task)
@@ -64,3 +67,8 @@ def test_task_registry(self):
self.assertTrue(MockTask().run())
self.assertTrue(MockPeriodicTask().run())
+
+ def test_compat(self):
+ r = TaskRegistry()
+ r.regular()
+ r.periodic()
View
137 celery/tests/tasks/test_result.py
@@ -1,11 +1,21 @@
from __future__ import absolute_import
from __future__ import with_statement
+from pickle import loads, dumps
+from mock import Mock
+
from celery import states
from celery.app import app_or_default
+from celery.exceptions import IncompleteStream
from celery.utils import uuid
from celery.utils.serialization import pickle
-from celery.result import AsyncResult, EagerResult, TaskSetResult, ResultSet
+from celery.result import (
+ AsyncResult,
+ EagerResult,
+ TaskSetResult,
+ ResultSet,
+ from_serializable,
+)
from celery.exceptions import TimeoutError
from celery.task import task
from celery.task.base import Task
@@ -53,6 +63,71 @@ def setup(self):
for task in (self.task1, self.task2, self.task3, self.task4):
save_result(task)
+ def test_compat_properties(self):
+ x = AsyncResult("1")
+ self.assertEqual(x.task_id, x.id)
+ x.task_id = "2"
+ self.assertEqual(x.id, "2")
+
+ def test_children(self):
+ x = AsyncResult("1")
+ children = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+ x.backend = Mock()
+ x.backend.get_children.return_value = children
+ x.backend.READY_STATES = states.READY_STATES
+ self.assertTrue(x.children)
+ self.assertEqual(len(x.children), 3)
+
+ def test_build_graph_get_leaf_collect(self):
+ x = AsyncResult("1")
+ x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+ c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+ x.iterdeps = Mock()
+ x.iterdeps.return_value = (
+ (None, x),
+ (x, c[0]),
+ (c[0], c[1]),
+ (c[1], c[2])
+ )
+ x.backend.READY_STATES = states.READY_STATES
+ self.assertTrue(x.graph)
+
+ self.assertIs(x.get_leaf(), 2)
+
+ it = x.collect()
+ self.assertListEqual(list(it), [
+ (x, None),
+ (c[0], 0),
+ (c[1], 1),
+ (c[2], 2),
+ ])
+
+ def test_iterdeps(self):
+ x = AsyncResult("1")
+ x.backend._cache["1"] = {"status": states.SUCCESS, "result": None}
+ c = [EagerResult(str(i), i, states.SUCCESS) for i in range(3)]
+ for child in c:
+ child.backend = Mock()
+ child.backend.get_children.return_value = []
+ x.backend.get_children = Mock()
+ x.backend.get_children.return_value = c
+ it = x.iterdeps()
+ self.assertListEqual(list(it), [
+ (None, x),
+ (x, c[0]),
+ (x, c[1]),
+ (x, c[2]),
+ ])
+ x.backend._cache.pop("1")
+ x.ready = Mock()
+ x.ready.return_value = False
+ with self.assertRaises(IncompleteStream):
+ list(x.iterdeps())
+ list(x.iterdeps(intermediate=True))
+
+ def test_eq_not_implemented(self):
+ self.assertFalse(AsyncResult("1") == object())
+
def test_reduce(self):
a1 = AsyncResult("uuid", task_name=mytask.name)
restored = pickle.loads(pickle.dumps(a1))
@@ -129,6 +204,7 @@ def test_get(self):
self.assertEqual(ok2_res.get(), "quick")
with self.assertRaises(KeyError):
nok_res.get()
+ self.assertTrue(nok_res.get(propagate=False))
self.assertIsInstance(nok2_res.result, KeyError)
self.assertEqual(ok_res.info, "the")
@@ -159,6 +235,32 @@ def test_ready(self):
class test_ResultSet(AppCase):
+ def test_resultset_repr(self):
+ self.assertTrue(repr(ResultSet(map(AsyncResult, [1, 2, 3]))))
+
+ def test_eq_other(self):
+ self.assertFalse(ResultSet([1, 3, 3]) == 1)
+ self.assertTrue(ResultSet([1]) == ResultSet([1]))
+
+ def test_get(self):
+ x = ResultSet(map(AsyncResult, [1, 2, 3]))
+ b = x.results[0].backend = Mock()
+ b.supports_native_join = False
+ x.join_native = Mock()
+ x.join = Mock()
+ x.get()
+ self.assertTrue(x.join.called)
+ b.supports_native_join = True
+ x.get()
+ self.assertTrue(x.join_native.called)
+
+ def test_add(self):
+ x = ResultSet([1])
+ x.add(2)
+ self.assertEqual(len(x), 2)
+ x.add(2)
+ self.assertEqual(len(x), 2)
+
def test_add_discard(self):
x = ResultSet([])
x.add(AsyncResult("1"))
@@ -231,6 +333,21 @@ def test_total(self):
self.assertEqual(len(self.ts), self.size)
self.assertEqual(self.ts.total, self.size)
+ def test_compat_properties(self):
+ self.assertEqual(self.ts.taskset_id, self.ts.id)
+ self.ts.taskset_id = "foo"
+ self.assertEqual(self.ts.taskset_id, "foo")
+
+ def test_eq_other(self):
+ self.assertFalse(self.ts == 1)
+
+ def test_reduce(self):
+ self.assertTrue(loads(dumps(self.ts)))
+
+ def test_compat_subtasks_kwarg(self):
+ x = TaskSetResult(uuid(), subtasks=[1, 2, 3])
+ self.assertEqual(x.results, [1, 2, 3])
+
def test_iterate_raises(self):
ar = MockAsyncResultFailure(uuid())
ts = TaskSetResult(uuid(), [ar])
@@ -432,13 +549,31 @@ def test_wait_raises(self):
res = RaisingTask.apply(args=[3, 3])
with self.assertRaises(KeyError):
res.wait()
+ self.assertTrue(res.wait(propagate=False))
def test_wait(self):
res = EagerResult("x", "x", states.RETRY)
res.wait()
self.assertEqual(res.state, states.RETRY)
self.assertEqual(res.status, states.RETRY)
+ def test_forget(self):
+ res = EagerResult("x", "x", states.RETRY)
+ res.forget()
+
def test_revoke(self):
res = RaisingTask.apply(args=[3, 3])
self.assertFalse(res.revoke())
+
+
+class test_serializable(AppCase):
+
+ def test_AsyncResult(self):
+ x = AsyncResult(uuid())
+ self.assertEqual(x, from_serializable(x.serializable()))
+ self.assertEqual(x, from_serializable(x))
+
+ def test_TaskSetResult(self):
+ x = TaskSetResult(uuid(), [AsyncResult(uuid()) for _ in range(10)])
+ self.assertEqual(x, from_serializable(x.serializable()))
+ self.assertEqual(x, from_serializable(x))
View
35 celery/tests/tasks/test_tasks.py
@@ -3,9 +3,11 @@
from datetime import datetime, timedelta
from functools import wraps
+from mock import patch
+from pickle import loads, dumps
from celery import task
-from celery.task import current
+from celery.task import current, Task
from celery.app import app_or_default
from celery.task import task as task_dec
from celery.exceptions import RetryTaskError
@@ -57,6 +59,7 @@ def retry_task(arg1, arg2, kwarg=1, max_retries=None, care=True):
current.iterations += 1
rmax = current.max_retries if max_retries is None else max_retries
+ assert repr(current.request)
retries = current.request.retries
if care and retries >= rmax:
return arg1
@@ -301,6 +304,22 @@ def test_context_get(self):
def test_task_class_repr(self):
task = self.createTask("c.unittest.t.repr")
self.assertIn("class Task of", repr(task.app.Task))
+ prev, task.app.Task._app = task.app.Task._app, None
+ try:
+ self.assertIn("unbound", repr(task.app.Task, ))
+ finally:
+ task.app.Task._app = prev
+
+ def test_bind_no_magic_kwargs(self):
+ task = self.createTask("c.unittest.t.magic_kwargs")
+ task.__class__.accept_magic_kwargs = None
+ task.bind(task.app)
+
+ def test_annotate(self):
+ with patch("celery.app.task.resolve_all_annotations") as anno:
+ anno.return_value = [{"FOO": "BAR"}]
+ Task.annotate()
+ self.assertEqual(Task.FOO, "BAR")
def test_after_return(self):
task = self.createTask("c.unittest.t.after_return")
@@ -436,6 +455,13 @@ def test_apply_throw(self):
with self.assertRaises(KeyError):
raising.apply(throw=True)
+ def test_apply_no_magic_kwargs(self):
+ increment_counter.accept_magic_kwargs = False
+ try:
+ increment_counter.apply()
+ finally:
+ increment_counter.accept_magic_kwargs = True
+
def test_apply_with_CELERY_EAGER_PROPAGATES_EXCEPTIONS(self):
raising.app.conf.CELERY_EAGER_PROPAGATES_EXCEPTIONS = True
try:
@@ -551,6 +577,13 @@ def __inner(*args, **kwargs):
class test_crontab_parser(Case):
+ def test_crontab_reduce(self):
+ self.assertTrue(loads(dumps(crontab("*"))))
+
+ def test_range_steps_not_enough(self):
+ with self.assertRaises(crontab_parser.ParseException):
+ crontab_parser(24)._range_steps([1])
+
def test_parse_star(self):
self.assertEqual(crontab_parser(24).parse('*'), set(range(24)))
self.assertEqual(crontab_parser(60).parse('*'), set(range(60)))
View
139 celery/tests/utilities/test_dispatcher.py
@@ -0,0 +1,139 @@
+from __future__ import absolute_import
+
+
+import gc
+import sys
+import time
+
+from celery.utils.dispatch import Signal
+from celery.tests.utils import Case
+
+
+if sys.platform.startswith('java'):
+
+ def garbage_collect():
+ # Some JVM GCs will execute finalizers in a different thread, meaning
+ # we need to wait for that to complete before we go on looking for the
+ # effects of that.
+ gc.collect()
+ time.sleep(0.1)
+
+elif hasattr(sys, "pypy_version_info"):
+
+ def garbage_collect(): # noqa
+ # Collecting weakreferences can take two collections on PyPy.
+ gc.collect()
+ gc.collect()
+else:
+
+ def garbage_collect(): # noqa
+ gc.collect()
+
+
+def receiver_1_arg(val, **kwargs):
+ return val
+
+
+class Callable(object):
+
+ def __call__(self, val, **kwargs):
+ return val
+
+ def a(self, val, **kwargs):
+ return val
+
+a_signal = Signal(providing_args=["val"])
+
+
+class DispatcherTests(Case):
+ """Test suite for dispatcher (barely started)"""
+
+ def _testIsClean(self, signal):
+ """Assert that everything has been cleaned up automatically"""
+ self.assertEqual(signal.receivers, [])
+
+ # force cleanup just in case
+ signal.receivers = []
+
+ def testExact(self):
+ a_signal.connect(receiver_1_arg, sender=self)
+ expected = [(receiver_1_arg, "test")]
+ result = a_signal.send(sender=self, val="test")
+ self.assertEqual(result, expected)
+ a_signal.disconnect(receiver_1_arg, sender=self)
+ self._testIsClean(a_signal)
+
+ def testIgnoredSender(self):
+ a_signal.connect(receiver_1_arg)
+ expected = [(receiver_1_arg, "test")]
+ result = a_signal.send(sender=self, val="test")
+ self.assertEqual(result, expected)
+ a_signal.disconnect(receiver_1_arg)
+ self._testIsClean(a_signal)
+
+ def testGarbageCollected(self):
+ a = Callable()
+ a_signal.connect(a.a, sender=self)
+ expected = []
+ del a
+ garbage_collect()
+ result = a_signal.send(sender=self, val="test")
+ self.assertEqual(result, expected)
+ self._testIsClean(a_signal)
+
+ def testMultipleRegistration(self):
+ a = Callable()
+ a_signal.connect(a)
+ a_signal.connect(a)
+ a_signal.connect(a)
+ a_signal.connect(a)
+ a_signal.connect(a)
+ a_signal.connect(a)
+ result = a_signal.send(sender=self, val="test")
+ self.assertEqual(len(result), 1)
+ self.assertEqual(len(a_signal.receivers), 1)
+ del a
+ del result
+ garbage_collect()
+ self._testIsClean(a_signal)
+
+ def testUidRegistration(self):
+
+ def uid_based_receiver_1(**kwargs):
+ pass
+
+ def uid_based_receiver_2(**kwargs):
+ pass
+
+ a_signal.connect(uid_based_receiver_1, dispatch_uid="uid")
+ a_signal.connect(uid_based_receiver_2, dispatch_uid="uid")
+ self.assertEqual(len(a_signal.receivers), 1)
+ a_signal.disconnect(dispatch_uid="uid")
+ self._testIsClean(a_signal)
+
+ def testRobust(self):
+ """Test the sendRobust function"""
+
+ def fails(val, **kwargs):
+ raise ValueError('this')
+
+ a_signal.connect(fails)
+ result = a_signal.send_robust(sender=self, val="test")
+ err = result[0][1]
+ self.assertTrue(isinstance(err, ValueError))
+ self.assertEqual(err.args, ('this',))
+ a_signal.disconnect(fails)
+ self._testIsClean(a_signal)
+
+ def testDisconnection(self):
+ receiver_1 = Callable()
+ receiver_2 = Callable()
+ receiver_3 = Callable()
+ a_signal.connect(receiver_1)
+ a_signal.connect(receiver_2)
+ a_signal.connect(receiver_3)
+ a_signal.disconnect(receiver_1)
+ del receiver_2
+ garbage_collect()
+ a_signal.disconnect(receiver_3)
+ self._testIsClean(a_signal)
View
11 celery/tests/utilities/test_imports.py
@@ -1,4 +1,5 @@
from __future__ import absolute_import
+from __future__ import with_statement
from mock import Mock, patch
@@ -7,13 +8,22 @@
symbol_by_name,
reload_from_cwd,
module_file,
+ find_module,
+ NotAPackage,
)
from celery.tests.utils import Case
class test_import_utils(Case):
+ def test_find_module(self):
+ self.assertTrue(find_module("celery"))
+ imp = Mock()
+ imp.return_value = None
+ with self.assertRaises(NotAPackage):
+ find_module("foo.bar.baz", imp=imp)
+
def test_qualname(self):
Class = type("Fox", (object, ), {"__module__": "quick.brown"})
self.assertEqual(qualname(Class), "quick.brown.Fox")
@@ -32,6 +42,7 @@ def test_symbol_by_name_package(self):
from celery.worker import WorkController
self.assertIs(symbol_by_name(".worker:WorkController",
package="celery"), WorkController)
+ self.assertTrue(symbol_by_name(":group", package="celery"))
@patch("celery.utils.imports.reload")
def test_reload_from_cwd(self, reload):
View
79 celery/tests/utilities/test_saferef.py
@@ -0,0 +1,79 @@
+from __future__ import absolute_import
+
+from celery.utils.dispatch.saferef import safe_ref