Skip to content

Commit

Permalink
Enable SQL tests for oauth
Browse files Browse the repository at this point in the history
Enables testing of the SQL backend for oauth.
As a result of running the tests, I noticed that this also
fixes an update_consumer test cases too.

fixes bug: #1215483
fixes bug: #1216447

Change-Id: I206d164caa66c3211cfc216d13e3d0bab0e7d54a
  • Loading branch information
Steve Martinelli committed Aug 29, 2013
1 parent 8fdfbf0 commit a746527
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 24 deletions.
2 changes: 1 addition & 1 deletion keystone/auth/plugins/oauth1.py
Expand Up @@ -44,7 +44,7 @@ def authenticate(self, context, auth_info, auth_context):
attribute='oauth_token', target='request')

acc_token = self.oauth_api.get_access_token(access_token_id)
consumer = self.oauth_api._get_consumer(consumer_id)
consumer = self.oauth_api.get_consumer_with_secret(consumer_id)

expires_at = acc_token['expires_at']
if expires_at:
Expand Down
25 changes: 11 additions & 14 deletions keystone/contrib/oauth1/backends/sql.py
Expand Up @@ -84,19 +84,20 @@ class OAuth1(sql.Base):
def db_sync(self):
migration.db_sync()

def _get_consumer(self, consumer_id):
session = self.get_session()
def _get_consumer(self, session, consumer_id):
consumer_ref = session.query(Consumer).get(consumer_id)
if consumer_ref is None:
raise exception.NotFound(_('Consumer not found'))
return consumer_ref

def get_consumer(self, consumer_id):
def get_consumer_with_secret(self, consumer_id):
session = self.get_session()
consumer_ref = session.query(Consumer).get(consumer_id)
if consumer_ref is None:
raise exception.NotFound(_('Consumer not found'))
return core.filter_consumer(consumer_ref.to_dict())
consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict()

def get_consumer(self, consumer_id):
return core.filter_consumer(
self.get_consumer_with_secret(consumer_id))

def create_consumer(self, consumer):
consumer['secret'] = uuid.uuid4().hex
Expand All @@ -110,7 +111,7 @@ def create_consumer(self, consumer):
return consumer_ref.to_dict()

def _delete_consumer(self, session, consumer_id):
consumer_ref = self._get_consumer(consumer_id)
consumer_ref = self._get_consumer(session, consumer_id)
q = session.query(Consumer)
q = q.filter_by(id=consumer_id)
q.delete(False)
Expand Down Expand Up @@ -154,15 +155,11 @@ def list_consumers(self):
def update_consumer(self, consumer_id, consumer):
session = self.get_session()
with session.begin():
consumer_ref = self._get_consumer(consumer_id)
consumer_ref = self._get_consumer(session, consumer_id)
old_consumer_dict = consumer_ref.to_dict()
old_consumer_dict.update(consumer)
new_consumer = Consumer.from_dict(old_consumer_dict)
for attr in Consumer.attributes:
if (attr != 'id' or attr != 'secret'):
setattr(consumer_ref,
attr,
getattr(new_consumer, attr))
consumer_ref.description = new_consumer.description
consumer_ref.extra = new_consumer.extra
session.flush()
return core.filter_consumer(consumer_ref.to_dict())
Expand Down
4 changes: 2 additions & 2 deletions keystone/contrib/oauth1/controllers.py
Expand Up @@ -172,7 +172,7 @@ def create_request_token(self, context):
attribute='requested_project_id', target='request')

req_role_ids = requested_role_ids.split(',')
consumer_ref = self.oauth_api._get_consumer(consumer_id)
consumer_ref = self.oauth_api.get_consumer_with_secret(consumer_id)
consumer = oauth1.Consumer(key=consumer_ref['id'],
secret=consumer_ref['secret'])

Expand Down Expand Up @@ -251,7 +251,7 @@ def create_access_token(self, context):
raise exception.ValidationError(
attribute='oauth_verifier', target='request')

consumer = self.oauth_api._get_consumer(consumer_id)
consumer = self.oauth_api.get_consumer_with_secret(consumer_id)
req_token = self.oauth_api.get_request_token(
request_token_id)

Expand Down
17 changes: 16 additions & 1 deletion keystone/contrib/oauth1/core.py
Expand Up @@ -169,7 +169,22 @@ def list_consumers(self):
raise exception.NotImplemented()

def get_consumer(self, consumer_id):
"""Get consumer.
"""Get consumer, returns the consumer id (key)
and description.
:param consumer_id: id of consumer to get
:type consumer_ref: string
:returns: consumer_ref
"""
raise exception.NotImplemented()

def get_consumer_with_secret(self, consumer_id):
"""Like get_consumer() but returned consumer_ref includes
the consumer secret.
Secrets should only be shared upon consumer creation; the
consumer secret is required to verify incoming OAuth requests.
:param consumer_id: id of consumer to get
:type consumer_ref: string
Expand Down
3 changes: 0 additions & 3 deletions keystone/tests/test_overrides.conf
Expand Up @@ -19,9 +19,6 @@ backend = dogpile.cache.memory
enabled = True
debug_cache_backend = True

[oauth1]
driver = keystone.contrib.oauth1.backends.kvs.OAuth1

[signing]
certfile = ../../examples/pki/certs/signing_cert.pem
keyfile = ../../examples/pki/private/signing_key.pem
Expand Down
17 changes: 14 additions & 3 deletions keystone/tests/test_v3_oauth1.py
Expand Up @@ -22,9 +22,12 @@
import webtest

from keystone.common import cms
from keystone.common.sql import migration
from keystone import config
from keystone import contrib
from keystone.contrib import oauth1
from keystone.contrib.oauth1 import controllers
from keystone.openstack.common import importutils
from keystone.tests import core

import test_v3
Expand All @@ -35,12 +38,22 @@


class OAuth1Tests(test_v3.RestfulTestCase):
EXTENSION_NAME = 'oauth1'

def setup_database(self):
super(OAuth1Tests, self).setup_database()
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
self.EXTENSION_NAME)
package = importutils.import_module(package_name)
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
migration.db_version_control(version=None, repo_path=self.repo_path)
migration.db_sync(version=None, repo_path=self.repo_path)

def setUp(self):
super(OAuth1Tests, self).setUp()
self.controller = controllers.OAuthControllerV3()
self.base_url = CONF.public_endpoint % CONF + "v3"
self._generate_paste_config()
self.load_backends()
self.admin_app = webtest.TestApp(
self.loadapp('v3_oauth1', name='admin'))
self.public_app = webtest.TestApp(
Expand Down Expand Up @@ -169,7 +182,6 @@ def test_consumer_update(self):
consumer = self._create_single_consumer()
original_id = consumer.get('id')
original_description = consumer.get('description')
original_secret = consumer.get('secret')
update_description = original_description + "_new"

update_ref = {'description': update_description}
Expand All @@ -179,7 +191,6 @@ def test_consumer_update(self):
consumer = update_resp.result.get('consumer')
self.assertEqual(consumer.get('description'), update_description)
self.assertEqual(consumer.get('id'), original_id)
self.assertEqual(consumer.get('secret'), original_secret)

def test_consumer_update_bad_secret(self):
consumer = self._create_single_consumer()
Expand Down

0 comments on commit a746527

Please sign in to comment.