Skip to content

Commit

Permalink
Make Origin scheme-aware (#388)
Browse files Browse the repository at this point in the history
Regarding #379.
  • Loading branch information
wgonczaronek authored and adamchainz committed May 10, 2019
1 parent 424a7e2 commit 3a1c92d
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 26 deletions.
3 changes: 3 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ Pending

.. Insert new release notes below this line
* Origin is now scheme-aware. Deprecation warning has been added when origin
without scheme is included.

2.5.3 (2019-04-28)
------------------

Expand Down
17 changes: 10 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,20 @@ A list of origin hostnames that are authorized to make cross-site HTTP
requests. The value ``'null'`` can also appear in this list, and will match the
``Origin: null`` header that is used in `"privacy-sensitive contexts"
<https://tools.ietf.org/html/rfc6454#section-6>`_, such as when the client is
running from a ``file://`` domain. Defaults to ``[]``.
running from a ``file://`` domain. Defaults to ``[]``. Proper origin should consist of
scheme, host and port (which could be given implicitly, e.g. for http it is assumed that the port is
80). Skipping scheme is allowed only for backward compatibility, deprecation warning will be raised
if this is discovered.

Example:

.. code-block:: python
CORS_ORIGIN_WHITELIST = (
'google.com',
'hostname.example.com',
'localhost:8000',
'127.0.0.1:9000'
'https://google.com',
'http://hostname.example.com',
'http://localhost:8000',
'http://127.0.0.1:9000'
)
Expand Down Expand Up @@ -260,8 +263,8 @@ For example:
.. code-block:: python
CORS_ORIGIN_WHITELIST = (
'read.only.com',
'change.allowed.com',
'http://read.only.com',
'http://change.allowed.com',
)
CSRF_TRUSTED_ORIGINS = (
Expand Down
40 changes: 38 additions & 2 deletions corsheaders/middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import re
import warnings

from django import http
from django.apps import apps
Expand Down Expand Up @@ -144,8 +145,10 @@ def process_response(self, request, response):
return response

def origin_found_in_white_lists(self, origin, url):
whitelisted_origins = self._get_parsed_whitelisted_origins(conf.CORS_ORIGIN_WHITELIST)
self._check_for_origins_without_scheme(whitelisted_origins)
return (
url.netloc in conf.CORS_ORIGIN_WHITELIST
self._url_in_whitelist(url, whitelisted_origins)
or (origin == 'null' and origin in conf.CORS_ORIGIN_WHITELIST)
or self.regex_domain_match(origin)
)
Expand All @@ -159,7 +162,11 @@ def origin_found_in_model(self, url):
if conf.CORS_MODEL is None:
return False
model = apps.get_model(*conf.CORS_MODEL.split('.'))
return model.objects.filter(cors=url.netloc).exists()
queryset = model.objects.filter(cors__icontains=url.netloc).values_list('cors', flat=True)

whitelisted_origins = self._get_parsed_whitelisted_origins(queryset)
self._check_for_origins_without_scheme(whitelisted_origins)
return self._url_in_whitelist(url, whitelisted_origins)

def is_enabled(self, request):
return (
Expand All @@ -176,3 +183,32 @@ def check_signal(self, request):
return_value for
function, return_value in signal_responses
)

def _get_parsed_whitelisted_origins(self, origins):
whitelisted_origins = []
for origin in origins:
# Note that when port is defined explicitly, it's part of netloc/path
parsed_origin = urlparse(origin)
whitelisted_origins.append(
{
'scheme': parsed_origin.scheme,
'host': parsed_origin.netloc or parsed_origin.path
}
)
return whitelisted_origins

def _check_for_origins_without_scheme(self, origins):
if any((origin['scheme'] == '' and origin['host'] != 'null' for origin in origins)):
warnings.warn('Passing origins without scheme will be deprecated.', DeprecationWarning)

def _url_in_whitelist(self, url, origins_whitelist):
possible_matching_origins = [
origin for origin in origins_whitelist if origin['host'] == url.netloc
]
if not possible_matching_origins:
return False
else:
for origin in possible_matching_origins:
if origin['scheme'] == '' or origin['scheme'] == url.scheme:
return True
return False
67 changes: 50 additions & 17 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import

import warnings

from django.http import HttpResponse
from django.test import TestCase
from django.test.utils import override_settings
Expand Down Expand Up @@ -30,17 +32,31 @@ def test_get_origin_vary_by_default(self):
resp = self.client.get('/')
assert resp['Vary'] == 'Origin'

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
def test_get_not_in_whitelist(self):
resp = self.client.get('/', HTTP_ORIGIN='http://example.org')
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(CORS_ORIGIN_WHITELIST=['example.com', 'example.org'])
@override_settings(CORS_ORIGIN_WHITELIST=['https://example.org'])
def test_get_not_in_whitelist_due_to_wrong_scheme(self):
resp = self.client.get('/', HTTP_ORIGIN='http://example.org')
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(CORS_ORIGIN_WHITELIST=['example.org'])
def test_get_without_scheme_in_whitelist_raises_warning(self):
with warnings.catch_warnings(record=True) as warn:
resp = self.client.get('/', HTTP_ORIGIN='http://example.org')
assert ACCESS_CONTROL_ALLOW_ORIGIN in resp
assert len(warn) == 1
assert issubclass(warn[-1].category, DeprecationWarning)
assert 'Passing origins without scheme will be deprecated.' in str(warn[-1].message)

@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com', 'http://example.org'])
def test_get_in_whitelist(self):
resp = self.client.get('/', HTTP_ORIGIN='http://example.org')
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.org'

@override_settings(CORS_ORIGIN_WHITELIST=['example.com', 'null'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com', 'null'])
def test_null_in_whitelist(self):
resp = self.client.get('/', HTTP_ORIGIN='null')
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'null'
Expand Down Expand Up @@ -101,7 +117,7 @@ def test_options_no_max_age(self):
@override_settings(
CORS_ALLOW_METHODS=['OPTIONS'],
CORS_ALLOW_CREDENTIALS=True,
CORS_ORIGIN_WHITELIST=('localhost:9000',),
CORS_ORIGIN_WHITELIST=('http://localhost:9000',),
)
def test_options_whitelist_with_port(self):
resp = self.client.options('/', HTTP_ORIGIN='http://localhost:9000')
Expand All @@ -127,14 +143,31 @@ def test_options_will_not_add_origin_when_domain_not_found_in_origin_regex_white

@override_settings(CORS_MODEL='testapp.CorsModel')
def test_get_when_custom_model_enabled(self):
CorsModel.objects.create(cors='example.com')
CorsModel.objects.create(cors='http://example.com')
resp = self.client.get('/', HTTP_ORIGIN='http://example.com')
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com'
assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in resp

@override_settings(CORS_MODEL='testapp.CorsModel')
def test_get_when_custom_model_enabled_without_scheme(self):
with warnings.catch_warnings(record=True) as warn:
CorsModel.objects.create(cors='example.com')
resp = self.client.get('/', HTTP_ORIGIN='http://example.com')

assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com'
assert len(warn) == 1
assert issubclass(warn[-1].category, DeprecationWarning)
assert 'Passing origins without scheme will be deprecated.' in str(warn[-1].message)

@override_settings(CORS_MODEL='testapp.CorsModel')
def test_get_when_custom_model_enabled_with_different_scheme(self):
CorsModel.objects.create(cors='https://example.com')
resp = self.client.get('/', HTTP_ORIGIN='http://example.com')
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(CORS_MODEL='testapp.CorsModel', CORS_ALLOW_CREDENTIALS=True)
def test_get_when_custom_model_enabled_and_allow_credentials(self):
CorsModel.objects.create(cors='example.com')
CorsModel.objects.create(cors='http://example.com')
resp = self.client.get('/', HTTP_ORIGIN='http://example.com')
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com'
assert resp[ACCESS_CONTROL_ALLOW_CREDENTIALS] == 'true'
Expand All @@ -159,7 +192,7 @@ def test_options_no_header(self):

@override_settings(CORS_MODEL='testapp.CorsModel')
def test_options_when_custom_model_enabled(self):
CorsModel.objects.create(cors='example.com')
CorsModel.objects.create(cors='http://example.com')
resp = self.client.options(
'/',
HTTP_ORIGIN='http://example.com',
Expand All @@ -169,7 +202,7 @@ def test_options_when_custom_model_enabled(self):

@override_settings(CORS_MODEL='testapp.CorsModel')
def test_process_response_when_custom_model_enabled(self):
CorsModel.objects.create(cors='foo.google.com')
CorsModel.objects.create(cors='http://foo.google.com')
response = self.client.get('/', HTTP_ORIGIN='http://foo.google.com')
assert response.get(ACCESS_CONTROL_ALLOW_ORIGIN, None) == 'http://foo.google.com'

Expand Down Expand Up @@ -259,7 +292,7 @@ def handler(*args, **kwargs):
assert resp.status_code == 200
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com'

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
def test_signal_handler_allow_some_urls_to_everyone(self):
def allow_api_to_all(sender, request, **kwargs):
return request.path.startswith('/api/')
Expand All @@ -281,7 +314,7 @@ def allow_api_to_all(sender, request, **kwargs):
assert resp.status_code == 200
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.org'

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
def test_signal_called_once_during_normal_flow(self):
def allow_all(sender, request, **kwargs):
allow_all.calls += 1
Expand All @@ -293,7 +326,7 @@ def allow_all(sender, request, **kwargs):

assert allow_all.calls == 1

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
@prepend_middleware('tests.test_middleware.ShortCircuitMiddleware')
def test_get_short_circuit(self):
"""
Expand All @@ -306,7 +339,7 @@ def test_get_short_circuit(self):
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(
CORS_ORIGIN_WHITELIST=['example.com'],
CORS_ORIGIN_WHITELIST=['http://example.com'],
CORS_URLS_REGEX=r'^/foo/$',
)
@prepend_middleware(__name__ + '.ShortCircuitMiddleware')
Expand All @@ -315,30 +348,30 @@ def test_get_short_circuit_should_be_ignored(self):
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(
CORS_ORIGIN_WHITELIST=['example.com'],
CORS_ORIGIN_WHITELIST=['http://example.com'],
CORS_URLS_REGEX=r'^/foo/$',
)
def test_get_regex_matches(self):
resp = self.client.get('/foo/', HTTP_ORIGIN='http://example.com')
assert ACCESS_CONTROL_ALLOW_ORIGIN in resp

@override_settings(
CORS_ORIGIN_WHITELIST=['example.com'],
CORS_ORIGIN_WHITELIST=['http://example.com'],
CORS_URLS_REGEX=r'^/not-foo/$',
)
def test_get_regex_doesnt_match(self):
resp = self.client.get('/foo/', HTTP_ORIGIN='http://example.com')
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp

@override_settings(
CORS_ORIGIN_WHITELIST=['example.com'],
CORS_ORIGIN_WHITELIST=['http://example.com'],
CORS_URLS_REGEX=r'^/foo/$',
)
def test_get_regex_matches_path_info(self):
resp = self.client.get('/foo/', HTTP_ORIGIN='http://example.com', SCRIPT_NAME='/prefix/')
assert ACCESS_CONTROL_ALLOW_ORIGIN in resp

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
def test_cors_enabled_is_attached_and_bool(self):
"""
Ensure that request._cors_enabled is available - although a private API
Expand All @@ -349,7 +382,7 @@ def test_cors_enabled_is_attached_and_bool(self):
assert isinstance(request._cors_enabled, bool)
assert request._cors_enabled

@override_settings(CORS_ORIGIN_WHITELIST=['example.com'])
@override_settings(CORS_ORIGIN_WHITELIST=['http://example.com'])
def test_works_if_view_deletes_cors_enabled(self):
"""
Just in case something crazy happens in the view or other middleware,
Expand Down

0 comments on commit 3a1c92d

Please sign in to comment.