Permalink
Browse files

Py3k fixes

  • Loading branch information...
ask committed Jun 1, 2012
1 parent 7c3f9b4 commit d4bb75ef4b172f11d4c27d169bd8f95c60256ee1
View
@@ -356,6 +356,7 @@ def get_key_for_chord(self, taskset_id):
def _strip_prefix(self, key):
"""Takes bytes, emits string."""
+ key = ensure_bytes(key)
for prefix in self.task_keyprefix, self.taskset_keyprefix:
if key.startswith(prefix):
return bytes_to_str(key[len(prefix):])
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
-from UserDict import UserDict
+from celery.utils.compat import UserDict
from .base import apply_target, BasePool
View
@@ -36,7 +36,7 @@ class QueueNotFound(KeyError):
"""Task routed to a queue not in CELERY_QUEUES."""
-class ImproperlyConfigured(Exception):
+class ImproperlyConfigured(ImportError):
"""Celery is somehow improperly configured."""
@@ -6,6 +6,7 @@
from tempfile import mktemp
from mock import patch, Mock
+from nose import SkipTest
from celery import current_app
from celery import signals
@@ -31,6 +32,7 @@ class Record(object):
msg = "hello world"
levelname = "info"
exc_text = exc_info = None
+ stack_info = None
def getMessage(self):
return self.msg
@@ -59,7 +61,8 @@ def test_formatException_string(self, safe_str, fe, value="HELLO"):
x = ColorFormatter(value)
fe.return_value = value
self.assertTrue(x.formatException(value))
- self.assertTrue(safe_str.called)
+ if sys.version_info[0] == 2:
+ self.assertTrue(safe_str.called)
@patch("celery.utils.log.safe_str")
def test_format_raises(self, safe_str):
@@ -72,10 +75,19 @@ def on_safe_str(s):
safe_str.side_effect = None
safe_str.side_effect = on_safe_str
- record = Mock()
- record.levelname = "ERROR"
- record.msg = "HELLO"
- record.exc_text = "error text"
+ class Record(object):
+ levelname = "ERROR"
+ msg = "HELLO"
+ exc_text = "error text"
+ stack_info = None
+
+ def __str__(self):
+ return on_safe_str("")
+
+ def getMessage(self):
+ return self.msg
+
+ record = Record()
safe_str.return_value = record
x.format(record)
@@ -84,6 +96,8 @@ def on_safe_str(s):
@patch("celery.utils.log.safe_str")
def test_format_raises_no_color(self, safe_str):
+ if sys.version_info[0] == 3:
+ raise SkipTest("py3k")
x = ColorFormatter("HELLO", False)
record = Mock()
record.levelname = "ERROR"
@@ -16,6 +16,31 @@ class Object(object):
pass
+def install_exceptions(mod):
+ # py3k: cannot catch exceptions not ineheriting from BaseException.
+
+ class NotFoundException(Exception):
+ pass
+
+ class TException(Exception):
+ pass
+
+ class InvalidRequestException(Exception):
+ pass
+
+ class UnavailableException(Exception):
+ pass
+
+ class TimedOutException(Exception):
+ pass
+
+ mod.NotFoundException = NotFoundException
+ mod.TException = TException
+ mod.InvalidRequestException = InvalidRequestException
+ mod.TimedOutException = TimedOutException
+ mod.UnavailableException = UnavailableException
+
+
class test_CassandraBackend(AppCase):
def test_init_no_pycassa(self):
@@ -39,6 +64,7 @@ def test_init_with_and_without_LOCAL_QUROM(self):
with mock_module("pycassa"):
from celery.backends import cassandra as mod
mod.pycassa = Mock()
+ install_exceptions(mod.pycassa)
cons = mod.pycassa.ConsistencyLevel = Object()
cons.LOCAL_QUORUM = "foo"
@@ -65,7 +91,9 @@ def test_get_task_meta_for(self):
with mock_module("pycassa"):
from celery.backends import cassandra as mod
mod.pycassa = Mock()
+ install_exceptions(mod.pycassa)
mod.Thrift = Mock()
+ install_exceptions(mod.Thrift)
app = self.get_app()
x = mod.CassandraBackend(app=app)
Get_Column = x._get_column_family = Mock()
@@ -120,7 +148,9 @@ def test_store_result(self):
with mock_module("pycassa"):
from celery.backends import cassandra as mod
mod.pycassa = Mock()
+ install_exceptions(mod.pycassa)
mod.Thrift = Mock()
+ install_exceptions(mod.Thrift)
app = self.get_app()
x = mod.CassandraBackend(app=app)
Get_Column = x._get_column_family = Mock()
@@ -150,6 +180,7 @@ def test_get_column_family(self):
with mock_module("pycassa"):
from celery.backends import cassandra as mod
mod.pycassa = Mock()
+ install_exceptions(mod.pycassa)
app = self.get_app()
x = mod.CassandraBackend(app=app)
self.assertTrue(x._get_column_family())
@@ -10,7 +10,7 @@
migrate_task,
migrate_tasks,
)
-from celery.utils.encoding import bytes_t
+from celery.utils.encoding import bytes_t, ensure_bytes
from celery.tests.utils import AppCase, Case, Mock
@@ -71,9 +71,9 @@ def test_migrate(self, name="testcelery"):
migrate_tasks(x, y)
yq = q(y.default_channel)
- self.assertEqual(yq.get().body, "foo")
- self.assertEqual(yq.get().body, "bar")
- self.assertEqual(yq.get().body, "baz")
+ self.assertEqual(yq.get().body, ensure_bytes("foo"))
+ self.assertEqual(yq.get().body, ensure_bytes("bar"))
+ self.assertEqual(yq.get().body, ensure_bytes("baz"))
Producer(x).publish("foo", exchange=name, routing_key=name)
callback = Mock()
@@ -93,7 +93,7 @@ def incr(self, key, delta=1):
def is_list(l):
- return hasattr(l, "__iter__") and not isinstance(l, dict)
+ return hasattr(l, "__iter__") and not isinstance(l, (dict, basestring))
def maybe_list(l):
@@ -1,3 +1,4 @@
+billiard>=2.7.3.7
python-dateutil>=2.0
pytz
kombu>=2.1.8

0 comments on commit d4bb75e

Please sign in to comment.