diff --git a/keystone/token/backends/memcache.py b/keystone/token/backends/memcache.py index 91d547d8e6..ae4ecc0c17 100644 --- a/keystone/token/backends/memcache.py +++ b/keystone/token/backends/memcache.py @@ -63,10 +63,31 @@ def create_token(self, token_id, data): expires_ts = utils.unixtime(data_copy['expires']) kwargs['time'] = expires_ts self.client.set(ptk, data_copy, **kwargs) + if 'id' in data['user']: + token_data = token_id + user_id = data['user']['id'] + user_key = 'usertokens-%s' % user_id + if not self.client.append(user_key, ',%s' % token_data): + if not self.client.add(user_key, token_data): + if not self.client.append(user_key, ',%s' % token_data): + msg = _('Unable to add token user list.') + raise exception.UnexpectedError(msg) return copy.deepcopy(data_copy) def delete_token(self, token_id): # Test for existence self.get_token(token_id) ptk = self._prefix_token_id(token_id) - return self.client.delete(ptk) + result = self.client.delete(ptk) + return result + + def list_tokens(self, user_id): + tokens = [] + user_record = self.client.get('usertokens-%s' % user_id) or "" + token_list = user_record.split(',') + for token_id in token_list: + ptk = self._prefix_token_id(token_id) + token = self.client.get(ptk) + if token: + tokens.append(token_id) + return tokens diff --git a/tests/test_backend.py b/tests/test_backend.py index d738e0b47d..9a58f40c93 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -316,7 +316,8 @@ def test_add_user_to_tenant(self): class TokenTests(object): def test_token_crud(self): token_id = uuid.uuid4().hex - data = {'id': token_id, 'a': 'b'} + data = {'id': token_id, 'a': 'b', + 'user': {'id': 'testuserid'}} data_ref = self.token_api.create_token(token_id, data) expires = data_ref.pop('expires') self.assertTrue(isinstance(expires, datetime.datetime)) @@ -329,22 +330,52 @@ def test_token_crud(self): self.token_api.delete_token(token_id) self.assertRaises(exception.TokenNotFound, - self.token_api.delete_token, token_id) + self.token_api.get_token, token_id) self.assertRaises(exception.TokenNotFound, - self.token_api.get_token, token_id) + self.token_api.delete_token, token_id) def test_expired_token(self): token_id = uuid.uuid4().hex expire_time = datetime.datetime.utcnow() - datetime.timedelta(minutes=1) - data = {'id': token_id, 'a': 'b', 'expires': expire_time} + data = {'id': token_id, 'a': 'b', 'expires': expire_time, + 'user': {'id': 'testuserid'}} data_ref = self.token_api.create_token(token_id, data) self.assertDictEquals(data_ref, data) self.assertRaises(exception.TokenNotFound, self.token_api.get_token, token_id) + def create_token_sample_data(self): + token_id = uuid.uuid4().hex + data = {'id': token_id, 'a': 'b', + 'user': {'id': 'testuserid'}} + self.token_api.create_token(token_id, data) + return token_id + + def test_token_list(self): + tokens = self.token_api.list_tokens('testuserid') + self.assertEquals(len(tokens), 0) + token_id1 = self.create_token_sample_data() + tokens = self.token_api.list_tokens('testuserid') + self.assertEquals(len(tokens), 1) + self.assertIn(token_id1, tokens) + token_id2 = self.create_token_sample_data() + tokens = self.token_api.list_tokens('testuserid') + self.assertEquals(len(tokens), 2) + self.assertIn(token_id2, tokens) + self.assertIn(token_id1, tokens) + self.token_api.delete_token(token_id1) + tokens = self.token_api.list_tokens('testuserid') + self.assertIn(token_id2, tokens) + self.assertNotIn(token_id1, tokens) + self.token_api.delete_token(token_id2) + tokens = self.token_api.list_tokens('testuserid') + self.assertNotIn(token_id2, tokens) + self.assertNotIn(token_id1, tokens) + def test_null_expires_token(self): token_id = uuid.uuid4().hex - data = {'id': token_id, 'a': 'b', 'expires': None} + data = {'id': token_id, 'id_hash': token_id, 'a': 'b', 'expires': None, + 'user': {'id': 'testuserid'}} data_ref = self.token_api.create_token(token_id, data) self.assertDictEquals(data_ref, data) new_data_ref = self.token_api.get_token(token_id) diff --git a/tests/test_backend_memcache.py b/tests/test_backend_memcache.py index 2c07580b86..06f1c310c3 100644 --- a/tests/test_backend_memcache.py +++ b/tests/test_backend_memcache.py @@ -34,6 +34,18 @@ def __init__(self, *args, **kwargs): """Ignores the passed in args.""" self.cache = {} + def add(self, key, value): + if self.get(key): + return False + return self.set(key, value) + + def append(self, key, value): + existing_value = self.get(key) + if existing_value: + self.set(key, existing_value + value) + return True + return False + def check_key(self, key): if not isinstance(key, str): raise memcache.Client.MemcachedStringEncodingError() @@ -45,8 +57,6 @@ def get(self, key): now = time.mktime(datetime.datetime.utcnow().utctimetuple()) if obj and (obj[1] == 0 or obj[1] > now): return obj[0] - else: - raise exception.TokenNotFound(token_id=key) def set(self, key, value, time=0): """Sets the value for a key.""" @@ -71,6 +81,7 @@ def setUp(self): def test_get_unicode(self): token_id = unicode(uuid.uuid4().hex) - data = {'id': token_id, 'a': 'b'} + data = {'id': token_id, 'a': 'b', + 'user': {'id': 'testuserid'}} self.token_api.create_token(token_id, data) self.token_api.get_token(token_id)