Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

100% coverage for .security

  • Loading branch information...
commit 8e7646c68ed98c40566dcf0ce3e662b94e4ee4d4 1 parent 98d501f
Ask Solem Hoel authored
2  celery/security/__init__.py
View
@@ -71,7 +71,7 @@ def setup_security(allowed_serializers=None, key=None, cert=None, store=None,
cert = cert or conf.CELERY_SECURITY_CERTIFICATE
store = store or conf.CELERY_SECURITY_CERT_STORE
- if any(not v for v in (key, cert, store)):
+ if not (key and cert and store):
raise ImproperlyConfigured(SETTING_MISSING)
with open(key) as kf:
24 celery/tests/test_security/__init__.py
View
@@ -13,9 +13,12 @@
"""
from __future__ import absolute_import
+from __future__ import with_statement
import __builtin__
+from mock import Mock, patch
+
from celery import current_app
from celery.exceptions import ImproperlyConfigured
from celery.security import setup_security, disable_untrusted_serializers
@@ -23,6 +26,8 @@
from .case import SecurityCase
+from celery.tests.utils import mock_open
+
KEY1 = """-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDCsmLC+eqL4z6bhtv0nzbcnNXuQrZUoh827jGfDI3kxNZ2LbEy
@@ -116,6 +121,25 @@ def test_setup_security(self):
self.assertIn('application/x-python-serialize', disabled)
disabled.clear()
+ @patch("celery.security.register_auth")
+ @patch("celery.security.disable_untrusted_serializers")
+ def test_setup_registry_complete(self, dis, reg, key="KEY", cert="CERT"):
+ calls = [0]
+ def effect(*args):
+ try:
+ m = Mock()
+ m.read.return_value = "B" if calls[0] else "A"
+ return m
+ finally:
+ calls[0] += 1
+
+ with mock_open(side_effect=effect):
+ store = Mock()
+ setup_security(["json"], key, cert, store)
+ dis.assert_called_with(["json"])
+ reg.assert_called_with("A", "B", store)
+
+
def test_security_conf(self):
current_app.conf.CELERY_TASK_SERIALIZER = 'auth'
50 celery/tests/test_security/test_certificate.py
View
@@ -8,6 +8,8 @@
from . import CERT1, CERT2, KEY1
from .case import SecurityCase
+from celery.tests.utils import mock_open
+
class test_Certificate(SecurityCase):
@@ -51,37 +53,27 @@ class test_FSCertStore(SecurityCase):
@patch("os.path.isdir")
@patch("glob.glob")
@patch("celery.security.certificate.Certificate")
- @patch("__builtin__.open")
- def test_init(self, open_, Certificate, glob, isdir):
+ def test_init(self, Certificate, glob, isdir):
cert = Certificate.return_value = Mock()
cert.has_expired.return_value = False
isdir.return_value = True
glob.return_value = ["foo.cert"]
- op = open_.return_value = Mock()
- op.__enter__ = Mock()
- def on_exit(*x):
- if x[0]:
- print(x)
- raise x[0], x[1], x[2]
- op.__exit__ = Mock()
- op.__exit__.side_effect = on_exit
- cert.get_id.return_value = 1
- x = FSCertStore("/var/certs")
- self.assertIn(1, x._certs)
- glob.assert_called_with("/var/certs/*")
- op.__enter__.assert_called_with()
- op.__exit__.assert_called_with(None, None, None)
-
- # they both end up with the same id
- glob.return_value = ["foo.cert", "bar.cert"]
- with self.assertRaises(SecurityError):
- x = FSCertStore("/var/certs")
- glob.return_value = ["foo.cert"]
-
- cert.has_expired.return_value = True
- with self.assertRaises(SecurityError):
- x = FSCertStore("/var/certs")
-
- isdir.return_value = False
- with self.assertRaises(SecurityError):
+ with mock_open():
+ cert.get_id.return_value = 1
x = FSCertStore("/var/certs")
+ self.assertIn(1, x._certs)
+ glob.assert_called_with("/var/certs/*")
+
+ # they both end up with the same id
+ glob.return_value = ["foo.cert", "bar.cert"]
+ with self.assertRaises(SecurityError):
+ x = FSCertStore("/var/certs")
+ glob.return_value = ["foo.cert"]
+
+ cert.has_expired.return_value = True
+ with self.assertRaises(SecurityError):
+ x = FSCertStore("/var/certs")
+
+ isdir.return_value = False
+ with self.assertRaises(SecurityError):
+ x = FSCertStore("/var/certs")
2  celery/tests/test_task/test_result.py
View
@@ -350,7 +350,7 @@ def test_result(self):
self.assertIsNone(self.task.result)
-class test_failed_AsyncResult(TestTaskSetResult):
+class test_failed_AsyncResult(test_TaskSetResult):
def setup(self):
self.size = 11
24 celery/tests/test_worker/test_worker_autoreload.py
View
@@ -20,7 +20,7 @@
Autoreloader,
)
-from celery.tests.utils import AppCase, Case, WhateverIO
+from celery.tests.utils import AppCase, Case, WhateverIO, mock_open
class test_WorkerComponent(AppCase):
@@ -37,19 +37,15 @@ def test_create(self):
class test_file_hash(Case):
- @patch("__builtin__.open")
- def test_hash(self, open_):
- context = open_.return_value = Mock()
- context.__enter__ = Mock()
- context.__exit__ = Mock()
- a = context.__enter__.return_value = WhateverIO()
- a.write("the quick brown fox\n")
- a.seek(0)
- A = file_hash("foo")
- b = context.__enter__.return_value = WhateverIO()
- b.write("the quick brown bar\n")
- b.seek(0)
- B = file_hash("bar")
+ def test_hash(self):
+ with mock_open() as a:
+ a.write("the quick brown fox\n")
+ a.seek(0)
+ A = file_hash("foo")
+ with mock_open() as b:
+ b.write("the quick brown bar\n")
+ b.seek(0)
+ B = file_hash("bar")
self.assertNotEqual(A, B)
26 celery/tests/utils.py
View
@@ -498,3 +498,29 @@ def __getattr__(self, attr):
for name in names:
if prev[name]:
sys.modules[name] = prev[name]
+
+
+@contextmanager
+def mock_context(mock, typ=Mock):
+ context = mock.return_value = Mock()
+ context.__enter__ = typ()
+ context.__exit__ = typ()
+
+ def on_exit(*x):
+ if x[0]:
+ raise x[0], x[1], x[2]
+ context.__exit__.side_effect = on_exit
+ context.__enter__.return_value = context
+ yield context
+ context.reset()
+
+
+@contextmanager
+def mock_open(typ=WhateverIO, side_effect=None):
+ with mock.patch("__builtin__.open") as open_:
+ with mock_context(open_) as context:
+ if side_effect is not None:
+ context.__enter__.side_effect = side_effect
+ val = context.__enter__.return_value = typ()
+ yield val
+
Please sign in to comment.
Something went wrong with that request. Please try again.