Skip to content

Commit

Permalink
Merge pull request #281 from mozilla-services/fix_cors_exposed_header…
Browse files Browse the repository at this point in the history
…s_by_method

Service.cors_supported_headers are now filtered by method
  • Loading branch information
almet committed Mar 17, 2015
2 parents 47c78c9 + b4a872e commit 11d7a3b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 19 deletions.
6 changes: 3 additions & 3 deletions cornice/cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_cors_preflight_view(service):
def _preflight_view(request):
response = request.response
origin = request.headers.get('Origin')
supported_headers = service.cors_supported_headers
supported_headers = service.cors_supported_headers_for()

if not origin:
request.errors.add('header', 'Origin',
Expand Down Expand Up @@ -121,13 +121,13 @@ def apply_cors_post_request(service, request, response):
response = ensure_origin(service, request, response)
method = _get_method(request)

if (service.cors_support_credentials(method) and
if (service.cors_support_credentials_for(method) and
'Access-Control-Allow-Credentials' not in response.headers):
response.headers['Access-Control-Allow-Credentials'] = 'true'

if request.method != 'OPTIONS':
# Which headers are exposed?
supported_headers = service.cors_supported_headers
supported_headers = service.cors_supported_headers_for(request.method)
if supported_headers:
response.headers['Access-Control-Expose-Headers'] = (
', '.join(supported_headers))
Expand Down
25 changes: 23 additions & 2 deletions cornice/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,15 +415,28 @@ def cors_enabled(self, value):

@property
def cors_supported_headers(self):
"""Backward compatibility for ``cors_supported_headers_for``."""
msg = "The '{0}' property is deprecated. Please start using '{1}' "\
"instead.".format('cors_supported_headers',
'cors_supported_headers_for()')
warnings.warn(msg, DeprecationWarning)
return self.cors_supported_headers_for()

def cors_supported_headers_for(self, method=None):
"""Return an iterable of supported headers for this service.
The supported headers are defined by the :param headers: argument
that is passed to services or methods, at definition time.
"""
headers = set()
for _, _, args in self.definitions:
for meth, _, args in self.definitions:
if args.get('cors_enabled', True):
headers |= set(args.get('cors_headers', ()))
exposed_headers = args.get('cors_headers', ())
if method is not None:
if meth.upper() == method.upper():
return exposed_headers
else:
headers |= set(exposed_headers)
return headers

@property
Expand Down Expand Up @@ -454,6 +467,14 @@ def cors_origins_for(self, method):
return origins

def cors_support_credentials(self, method=None):
"""Backward compatibility for ``cors_support_credentials_for``."""
msg = "The '{0}' property is deprecated. Please start using '{1}' "\
"instead.".format('cors_support_credentials',
'cors_support_credentials_for()')
warnings.warn(msg, DeprecationWarning)
return self.cors_supported_headers_for()

def cors_support_credentials_for(self, method=None):
"""Returns if the given method support credentials.
:param method:
Expand Down
4 changes: 2 additions & 2 deletions cornice/tests/test_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def post(self):
cors_klass.add_view('post', 'post')


@squirel.get(cors_origins=('notmyidea.org',))
@squirel.get(cors_origins=('notmyidea.org',), cors_headers=('X-My-Header',))
def get_squirel(request):
return "squirels"

Expand All @@ -51,7 +51,7 @@ def post_squirel(request):
return "moar squirels (take care)"


@squirel.put(cors_headers=('X-My-Header',))
@squirel.put()
def put_squirel(request):
return "squirels!"

Expand Down
34 changes: 22 additions & 12 deletions cornice/tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,37 +347,47 @@ def test_cors_headers_for_service_instanciation(self):
# it is possible to list all the headers supported by a service.
service = Service('coconuts', '/migrate',
cors_headers=('X-Header-Coconut'))
self.assertNotIn('X-Header-Coconut', service.cors_supported_headers)
self.assertNotIn('X-Header-Coconut',
service.cors_supported_headers_for())

service.add_view('POST', _stub)
self.assertIn('X-Header-Coconut', service.cors_supported_headers)
self.assertIn('X-Header-Coconut', service.cors_supported_headers_for())

def test_cors_headers_for_view_definition(self):
# defining headers in the view should work.
service = Service('coconuts', '/migrate')
service.add_view('POST', _stub, cors_headers=('X-Header-Foobar'))
self.assertIn('X-Header-Foobar', service.cors_supported_headers)
self.assertIn('X-Header-Foobar', service.cors_supported_headers_for())

def test_cors_headers_extension(self):
# definining headers in the service and in the view
service = Service('coconuts', '/migrate',
cors_headers=('X-Header-Foobar'))
service.add_view('POST', _stub, cors_headers=('X-Header-Barbaz'))
self.assertIn('X-Header-Foobar', service.cors_supported_headers)
self.assertIn('X-Header-Barbaz', service.cors_supported_headers)
self.assertIn('X-Header-Foobar', service.cors_supported_headers_for())
self.assertIn('X-Header-Barbaz', service.cors_supported_headers_for())

# check that adding the same header twice doesn't make bad things
# happen
service.add_view('POST', _stub, cors_headers=('X-Header-Foobar'),)
self.assertEqual(len(service.cors_supported_headers), 2)
self.assertEqual(len(service.cors_supported_headers_for()), 2)

# check that adding a header on a cors disabled method doesn't
# change anything
service.add_view('put', _stub,
cors_headers=('X-Another-Header',),
cors_enabled=False)

self.assertFalse('X-Another-Header' in service.cors_supported_headers)
self.assertNotIn('X-Another-Header',
service.cors_supported_headers_for())

def test_cors_headers_for_method(self):
# defining headers in the view should work.
service = Service('coconuts', '/migrate')
service.add_view('GET', _stub, cors_headers=('X-Header-Foobar'))
service.add_view('POST', _stub, cors_headers=('X-Header-Barbaz'))
get_headers = service.cors_supported_headers_for(method='GET')
self.assertNotIn('X-Header-Barbaz', get_headers)

def test_cors_supported_methods(self):
foo = Service(name='foo', path='/foo', cors_enabled=True)
Expand Down Expand Up @@ -421,24 +431,24 @@ def test_per_method_supported_origins(self):
def test_credential_support_can_be_enabled(self):
foo = Service(name='foo', path='/foo', cors_credentials=True)
foo.add_view('POST', _stub)
self.assertTrue(foo.cors_support_credentials())
self.assertTrue(foo.cors_support_credentials_for())

def test_credential_support_is_disabled_by_default(self):
foo = Service(name='foo', path='/foo')
foo.add_view('POST', _stub)
self.assertFalse(foo.cors_support_credentials())
self.assertFalse(foo.cors_support_credentials_for())

def test_per_method_credential_support(self):
foo = Service(name='foo', path='/foo')
foo.add_view('GET', _stub, cors_credentials=True)
foo.add_view('POST', _stub)
self.assertTrue(foo.cors_support_credentials('GET'))
self.assertFalse(foo.cors_support_credentials('POST'))
self.assertTrue(foo.cors_support_credentials_for('GET'))
self.assertFalse(foo.cors_support_credentials_for('POST'))

def test_method_takes_precendence_for_credential_support(self):
foo = Service(name='foo', path='/foo', cors_credentials=True)
foo.add_view('GET', _stub, cors_credentials=False)
self.assertFalse(foo.cors_support_credentials('GET'))
self.assertFalse(foo.cors_support_credentials_for('GET'))

def test_max_age_can_be_defined(self):
foo = Service(name='foo', path='/foo', cors_max_age=42)
Expand Down

0 comments on commit 11d7a3b

Please sign in to comment.