Permalink
Browse files

Make return values of pyramid_ldap3 compatible with pyramid_ldap.

  • Loading branch information...
1 parent 19ac972 commit 7fe0253991ae7e8a7937bb06cce4bdede6208fb4 @Cito committed Jul 3, 2014
Showing with 108 additions and 65 deletions.
  1. +1 −1 docs/index.rst
  2. +46 −30 pyramid_ldap3/__init__.py
  3. +60 −33 pyramid_ldap3/tests.py
  4. +1 −1 sampleapp/views.py
View
@@ -166,7 +166,7 @@ Here's a small application which uses the ``pyramid_ldap3`` API:
connector = get_ldap_connector(request)
data = connector.authenticate(login, password)
if data is not None:
- dn = data['dn']
+ dn = data[0]
headers = remember(request, dn)
return HTTPFound('/', headers=headers)
else:
View
@@ -25,45 +25,46 @@ class _LDAPQuery(object):
Provides rudimentary in-RAM caching of query results.
"""
- def __init__(self, base_dn, filter_tmpl, scope, cache_period):
+ def __init__(self, base_dn, filter_tmpl, scope, attributes, cache_period):
self.base_dn = base_dn
self.filter_tmpl = filter_tmpl
self.scope = scope
+ self.attributes = attributes
self.cache_period = cache_period
self.last_timeslice = 0
self.cache = {}
def __str__(self):
return ('base_dn={base_dn}, filter_tmpl={filter_tmpl}, '
- 'scope={scope}, cache_period={cache_period}'.format(**self.__dict__))
+ 'scope={scope}, attributes={attributes}, '
+ 'cache_period={cache_period}'.format(**self.__dict__))
def query_cache(self, cache_key):
now = time()
ts = _timeslice(self.cache_period, now)
if ts > self.last_timeslice:
- logger.debug('dumping cache; now ts: %r, last_ts: %r', ts, self.last_timeslice)
+ logger.debug('dumping cache; now ts: %r, last_ts: %r',
+ ts, self.last_timeslice)
self.cache = {}
self.last_timeslice = ts
return self.cache.get(cache_key)
def execute(self, conn, **kw):
- cache_key = (self.base_dn % kw, self.filter_tmpl % kw, self.scope)
+ cache_key = (self.base_dn % kw, self.filter_tmpl % kw)
logger.debug('searching for %r', cache_key)
- if self.cache_period:
- result = self.query_cache(cache_key)
- if result is None:
- ret = conn.search(*cache_key)
- result, ret = conn.get_response(ret)
- self.cache[cache_key] = result
- else:
- logger.debug('result for %r retrieved from cache', cache_key)
- else:
- ret = conn.search(*cache_key)
+ result = self.query_cache(cache_key) if self.cache_period else None
+ if result is None:
+ ret = conn.search(search_scope=self.scope,
+ attributes=self.attributes, *cache_key)
result, ret = conn.get_response(ret)
+ result = [(r['dn'], r['attributes']) for r in result]
+ self.cache[cache_key] = result
+ else:
+ logger.debug('result for %r retrieved from cache', cache_key)
logger.debug('search result: %r', result)
@@ -160,10 +161,13 @@ def authenticate(self, login, password):
'ldap_set_login_query was not called during setup')
result = search.execute(conn, login=login, password=password)
- try:
- login_dn = result[0]['dn']
- except (IndexError, KeyError, TypeError):
+
+ if not result:
return None
+ if len(result) > 1:
+ logger.debug('Non-unique result for login %r', login)
+ result = result[0]
+ login_dn = result[0]
try:
conn = self.manager.connection(login_dn, password)
@@ -175,7 +179,7 @@ def authenticate(self, login, password):
login, exc_info=True)
return None
- return result[0]
+ return result
def user_groups(self, userdn):
"""Get the groups the user belongs to.
@@ -209,14 +213,17 @@ def user_groups(self, userdn):
def ldap_set_login_query(config, base_dn, filter_tmpl,
- scope=ldap3.SEARCH_SCOPE_SINGLE_LEVEL, cache_period=0):
+ scope=ldap3.SEARCH_SCOPE_SINGLE_LEVEL, attributes=None,
+ cache_period=0):
"""Configurator method to set the LDAP login search.
``base_dn`` is the DN at which to begin the search.
``filter_tmpl`` is a string which can be used as an LDAP filter:
it should contain the replacement value ``%(login)s``.
- Scope is any valid LDAP scope value
+ ``scope`` is any valid LDAP scope value
(e.g. ``ldap3.SEARCH_SCOPE_SINGLE_LEVEL``).
+ ``attributes`` is a list of attributes that shall be returned
+ (can also be set to None or ``ldap3.ALL_ATTRIBUTES``).
``cache_period`` is the number of seconds to cache login search results;
if it is 0, login search results will not be cached.
@@ -231,7 +238,7 @@ def ldap_set_login_query(config, base_dn, filter_tmpl,
a valid login.
"""
- query = _LDAPQuery(base_dn, filter_tmpl, scope, cache_period)
+ query = _LDAPQuery(base_dn, filter_tmpl, scope, attributes, cache_period)
def register():
config.registry.ldap_login_query = query
@@ -246,14 +253,17 @@ def register():
def ldap_set_groups_query(config, base_dn, filter_tmpl,
- scope=ldap3.SEARCH_SCOPE_WHOLE_SUBTREE, cache_period=0):
+ scope=ldap3.SEARCH_SCOPE_WHOLE_SUBTREE, attributes=None,
+ cache_period=0):
""" Configurator method to set the LDAP groups search.
``base_dn`` is the DN at which to begin the search.
``filter_tmpl`` is a string which can be used as an LDAP filter:
it should contain the replacement value ``%(userdn)s``.
- Scope is any valid LDAP scope value
- (e.g. ``ldap3.SEARCH_SCOPE_WHOLE_SUBTREE``).
+ ``scope`` is any valid LDAP scope value
+ (e.g. ``ldap3.SEARCH_SCOPE_SINGLE_LEVEL``).
+ ``attributes`` is a list of attributes that shall be returned
+ (can also be set to None or ``ldap3.ALL_ATTRIBUTES``).
``cache_period`` is the number of seconds to cache groups search results;
if it is 0, groups search results will not be cached.
@@ -266,7 +276,7 @@ def ldap_set_groups_query(config, base_dn, filter_tmpl,
"""
- query = _LDAPQuery(base_dn, filter_tmpl, scope, cache_period)
+ query = _LDAPQuery(base_dn, filter_tmpl, scope, attributes, cache_period)
def register():
config.registry.ldap_groups_query = query
@@ -276,6 +286,7 @@ def register():
None,
str(query),
'pyramid_ldap3 groups query')
+
config.action('ldap-set-groups-query', register, introspectables=(intr,))
@@ -327,6 +338,12 @@ def get_ldap_connector(request):
return connector
+def get_groups(userdn, request):
+ """Raw groupfinder function returning the complete group query result."""
+ connector = get_ldap_connector(request)
+ return connector.user_groups(userdn)
+
+
def groupfinder(userdn, request):
"""Groupfinder function for Pyramid.
@@ -335,11 +352,10 @@ def groupfinder(userdn, request):
belonging to the user specified by ``userdn`` to as a principal
in the list of results; if the user does not exist, it returns None.
"""
- connector = get_ldap_connector(request)
- group_list = connector.user_groups(userdn)
- if group_list is None:
- return None
- return [group['dn'] for group in group_list]
+ groups = get_groups(userdn, request)
+ if groups:
+ groups = [r[0] for r in groups]
+ return groups
def includeme(config):
View
@@ -17,6 +17,26 @@ def test_it(self):
['ldap_setup', 'ldap_set_login_query', 'ldap_set_groups_query'])
+class Test_get_groups(unittest.TestCase):
+
+ def _callFUT(self, dn, request):
+ from pyramid_ldap3 import get_groups
+ return get_groups(dn, request)
+
+ def test_no_group_list(self):
+ request = testing.DummyRequest()
+ request.ldap_connector = DummyLDAPConnector('testdn', None)
+ result = self._callFUT('testdn', request)
+ self.assertTrue(result is None)
+
+ def test_with_group_list(self):
+ request = testing.DummyRequest()
+ request.ldap_connector = DummyLDAPConnector(
+ 'testdn', [('a', 'b')])
+ result = self._callFUT('testdn', request)
+ self.assertEqual(result, [('a', 'b')])
+
+
class Test_groupfinder(unittest.TestCase):
def _callFUT(self, dn, request):
@@ -27,12 +47,12 @@ def test_no_group_list(self):
request = testing.DummyRequest()
request.ldap_connector = DummyLDAPConnector('testdn', None)
result = self._callFUT('testdn', request)
- self.assertEqual(result, None)
+ self.assertTrue(result is None)
def test_with_group_list(self):
request = testing.DummyRequest()
request.ldap_connector = DummyLDAPConnector(
- 'testdn', [{'dn': 'groupdn', 'more': None}])
+ 'testdn', [('groupdn', None)])
result = self._callFUT('testdn', request)
self.assertEqual(result, ['groupdn'])
@@ -182,31 +202,29 @@ def test_authenticate_search_returns_non_one_result(self):
registry = Dummy()
registry.ldap_login_query = DummySearch([])
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.authenticate(None, None), None)
+ self.assertTrue(inst.authenticate(None, None) is None)
def test_authenticate_empty_password(self):
manager = DummyManager()
registry = Dummy()
- registry.ldap_login_query = DummySearch([{'dn': 'a', 'more': 'b'}])
+ registry.ldap_login_query = DummySearch([('a', 'b')])
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.authenticate('foo', ''), None)
+ self.assertTrue(inst.authenticate('foo', '') is None)
def test_authenticate_search_returns_one_result(self):
manager = DummyManager()
registry = Dummy()
- registry.ldap_login_query = DummySearch(
- [{'dn': 'a', 'more': 'b'}])
+ registry.ldap_login_query = DummySearch([('a', 'b')])
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.authenticate(None, None),
- {'dn': 'a', 'more': 'b'})
+ self.assertEqual(inst.authenticate(None, None), ('a', 'b'))
def test_authenticate_search_bind_raises(self):
from pyramid_ldap3 import ldap3
manager = DummyManager([None, ldap3.LDAPException])
registry = Dummy()
- registry.ldap_login_query = DummySearch([{'dn': 'a', 'more': 'b'}])
+ registry.ldap_login_query = DummySearch([('a', 'b')])
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.authenticate(None, None), None)
+ self.assertTrue(inst.authenticate(None, None) is None)
def test_user_groups_no_ldap_groups_query(self):
manager = DummyManager()
@@ -216,60 +234,67 @@ def test_user_groups_no_ldap_groups_query(self):
def test_user_groups_search_returns_result(self):
manager = DummyManager()
registry = Dummy()
- registry.ldap_groups_query = DummySearch([{'dn': 'a', 'more': 'b'}])
+ registry.ldap_groups_query = DummySearch([('a', 'b')])
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.user_groups(None), [{'dn': 'a', 'more': 'b'}])
+ self.assertEqual(inst.user_groups(None), [('a', 'b')])
def test_user_groups_execute_raises(self):
from pyramid_ldap3 import ldap3
manager = DummyManager()
registry = Dummy()
registry.ldap_groups_query = DummySearch(
- [{'dn': 'a', 'more': 'b'}], ldap3.LDAPException)
+ [('a', 'b')], ldap3.LDAPException)
inst = self._makeOne(registry, manager)
- self.assertEqual(inst.user_groups(None), None)
+ self.assertTrue(inst.user_groups(None) is None)
class Test_LDAPQuery(unittest.TestCase):
- def _makeOne(self, base_dn, filter_tmpl, scope, cache_period):
+ def _makeOne(self, base_dn, filter_tmpl, scope, attributes, cache_period):
from pyramid_ldap3 import _LDAPQuery
- return _LDAPQuery(base_dn, filter_tmpl, scope, cache_period)
+ return _LDAPQuery(
+ base_dn, filter_tmpl, scope, attributes, cache_period)
def test_query_cache_no_rollover(self):
- inst = self._makeOne(None, None, None, 1)
+ inst = self._makeOne(None, None, None, None, 1)
inst.last_timeslice = 1 << 31
inst.cache['foo'] = 'bar'
self.assertEqual(inst.query_cache('foo'), 'bar')
def test_query_cache_with_rollover(self):
- inst = self._makeOne(None, None, None, 1)
+ inst = self._makeOne(None, None, None, None, 1)
inst.cache['foo'] = 'bar'
- self.assertEqual(inst.query_cache('foo'), None)
+ self.assertTrue(inst.query_cache('foo') is None)
self.assertEqual(inst.cache, {})
self.assertNotEqual(inst.last_timeslice, 0)
def test_execute_no_cache_period(self):
- inst = self._makeOne('%(login)s', '%(login)s', None, 0)
- conn = DummyConnection({'dn': 'abc'})
+ inst = self._makeOne('DN=Org', '(cn=%(login)s)', 'scope', 'attrs', 0)
+ conn = DummyConnection([{'dn': 'a', 'attributes': {'b': 'c'}}])
result = inst.execute(conn, login='foo')
- self.assertEqual(result, {'dn': 'abc'})
- self.assertEqual(conn.args, ('foo', 'foo', None))
+ self.assertEqual(result, [('a', {'b': 'c'})])
+ self.assertEqual(conn.args, ('DN=Org', '(cn=foo)'))
+ self.assertEqual(conn.kwargs,
+ {'attributes': 'attrs', 'search_scope': 'scope'})
def test_execute_with_cache_period_miss(self):
- inst = self._makeOne('%(login)s', '%(login)s', None, 1)
- conn = DummyConnection({'dn': 'abc'})
+ inst = self._makeOne('DN=Org', '(cn=%(login)s)', 'scope', 'attrs', 1)
+ conn = DummyConnection([{'dn': 'a', 'attributes': {'b': 'c'}}])
result = inst.execute(conn, login='foo')
- self.assertEqual(result, {'dn': 'abc'})
- self.assertEqual(conn.args, ('foo', 'foo', None))
+ self.assertEqual(result, [('a', {'b': 'c'})])
+ self.assertEqual(conn.args, ('DN=Org', '(cn=foo)'))
+ self.assertEqual(conn.kwargs, {
+ 'attributes': 'attrs', 'search_scope': 'scope'})
def test_execute_with_cache_period_hit(self):
- inst = self._makeOne('%(login)s', '%(login)s', None, 1)
+ inst = self._makeOne('DN=Org', '(cn=%(login)s)', 'scope', 'attrs', 1)
inst.last_timeslice = 1 << 31
- inst.cache[('foo', 'foo', None)] = {'dn': 'def'}
- conn = DummyConnection({'dn': 'abc'})
+ inst.cache[('DN=Org', '(cn=foo)')] = ('d', {'e': 'f'})
+ conn = DummyConnection([{'dn': 'a', 'attributes': {'b': 'c'}}])
result = inst.execute(conn, login='foo')
- self.assertEqual(result, {'dn': 'def'})
+ self.assertEqual(result, ('d', {'e': 'f'}))
+ self.assertTrue(conn.args is None)
+ self.assertTrue(conn.kwargs is None)
class DummyLDAPConnector(object):
@@ -358,9 +383,11 @@ def __init__(self, result):
self.result = result
self.result_id = 0
self.args = None
+ self.kwargs = None
- def search(self, *args):
+ def search(self, *args, **kwargs):
self.args = args
+ self.kwargs = kwargs
self.result_id += 1
return self.result_id
View
@@ -32,7 +32,7 @@ def login(request):
connector = get_ldap_connector(request)
data = connector.authenticate(login, password)
if data is not None:
- dn = data['dn']
+ dn = data[0]
headers = remember(request, dn)
return HTTPFound('/', headers=headers)
error = 'Invalid credentials'

0 comments on commit 7fe0253

Please sign in to comment.