From 3a1c92d3537db0d0187cedaae7174d1b883d65ac Mon Sep 17 00:00:00 2001 From: wgonczaronek <40790247+wgonczaronek@users.noreply.github.com> Date: Fri, 10 May 2019 11:12:05 +0200 Subject: [PATCH] Make Origin scheme-aware (#388) Regarding #379. --- HISTORY.rst | 3 ++ README.rst | 17 ++++++---- corsheaders/middleware.py | 40 +++++++++++++++++++++-- tests/test_middleware.py | 67 +++++++++++++++++++++++++++++---------- 4 files changed, 101 insertions(+), 26 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 1d81d3c0..08e38b51 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -6,6 +6,9 @@ Pending .. Insert new release notes below this line +* Origin is now scheme-aware. A deprecation warning is issued when an origin + without a scheme is included. + 2.5.3 (2019-04-28) ------------------ diff --git a/README.rst b/README.rst index ba9f9824..956722b9 100644 --- a/README.rst +++ b/README.rst @@ -99,17 +99,20 @@ A list of origin hostnames that are authorized to make cross-site HTTP requests. The value ``'null'`` can also appear in this list, and will match the ``Origin: null`` header that is used in `"privacy-sensitive contexts" `_, such as when the client is -running from a ``file://`` domain. Defaults to ``[]``. +running from a ``file://`` domain. Defaults to ``[]``. A proper origin consists of a +scheme, host and port (the port may be given implicitly, e.g. for http it is assumed to be +80). Omitting the scheme is allowed only for backward compatibility; a deprecation warning is issued +if an origin without a scheme is found. Example: .. code-block:: python CORS_ORIGIN_WHITELIST = ( - 'google.com', - 'hostname.example.com', - 'localhost:8000', - '127.0.0.1:9000' + 'https://google.com', + 'http://hostname.example.com', + 'http://localhost:8000', + 'http://127.0.0.1:9000' ) @@ -260,8 +263,8 @@ For example: ..
code-block:: python CORS_ORIGIN_WHITELIST = ( - 'read.only.com', - 'change.allowed.com', + 'http://read.only.com', + 'http://change.allowed.com', ) CSRF_TRUSTED_ORIGINS = ( diff --git a/corsheaders/middleware.py b/corsheaders/middleware.py index 8213f136..f2e24c32 100644 --- a/corsheaders/middleware.py +++ b/corsheaders/middleware.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import re +import warnings from django import http from django.apps import apps @@ -144,8 +145,10 @@ def process_response(self, request, response): return response def origin_found_in_white_lists(self, origin, url): + whitelisted_origins = self._get_parsed_whitelisted_origins(conf.CORS_ORIGIN_WHITELIST) + self._check_for_origins_without_scheme(whitelisted_origins) return ( - url.netloc in conf.CORS_ORIGIN_WHITELIST + self._url_in_whitelist(url, whitelisted_origins) or (origin == 'null' and origin in conf.CORS_ORIGIN_WHITELIST) or self.regex_domain_match(origin) ) @@ -159,7 +162,11 @@ def origin_found_in_model(self, url): if conf.CORS_MODEL is None: return False model = apps.get_model(*conf.CORS_MODEL.split('.')) - return model.objects.filter(cors=url.netloc).exists() + queryset = model.objects.filter(cors__icontains=url.netloc).values_list('cors', flat=True) + + whitelisted_origins = self._get_parsed_whitelisted_origins(queryset) + self._check_for_origins_without_scheme(whitelisted_origins) + return self._url_in_whitelist(url, whitelisted_origins) def is_enabled(self, request): return ( @@ -176,3 +183,32 @@ def check_signal(self, request): return_value for function, return_value in signal_responses ) + + def _get_parsed_whitelisted_origins(self, origins): + whitelisted_origins = [] + for origin in origins: + # Note that when port is defined explicitly, it's part of netloc/path + parsed_origin = urlparse(origin) + whitelisted_origins.append( + { + 'scheme': parsed_origin.scheme, + 'host': parsed_origin.netloc or parsed_origin.path + } + ) + return whitelisted_origins + + def 
_check_for_origins_without_scheme(self, origins): + if any((origin['scheme'] == '' and origin['host'] != 'null' for origin in origins)): + warnings.warn('Passing origins without scheme will be deprecated.', DeprecationWarning) + + def _url_in_whitelist(self, url, origins_whitelist): + possible_matching_origins = [ + origin for origin in origins_whitelist if origin['host'] == url.netloc + ] + if not possible_matching_origins: + return False + else: + for origin in possible_matching_origins: + if origin['scheme'] == '' or origin['scheme'] == url.scheme: + return True + return False diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 17b5d616..edff1b44 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,5 +1,7 @@ from __future__ import absolute_import +import warnings + from django.http import HttpResponse from django.test import TestCase from django.test.utils import override_settings @@ -30,17 +32,31 @@ def test_get_origin_vary_by_default(self): resp = self.client.get('/') assert resp['Vary'] == 'Origin' - @override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) def test_get_not_in_whitelist(self): resp = self.client.get('/', HTTP_ORIGIN='http://example.org') assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp - @override_settings(CORS_ORIGIN_WHITELIST=['example.com', 'example.org']) + @override_settings(CORS_ORIGIN_WHITELIST=['https://example.org']) + def test_get_not_in_whitelist_due_to_wrong_scheme(self): + resp = self.client.get('/', HTTP_ORIGIN='http://example.org') + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + + @override_settings(CORS_ORIGIN_WHITELIST=['example.org']) + def test_get_without_scheme_in_whitelist_raises_warning(self): + with warnings.catch_warnings(record=True) as warn: + resp = self.client.get('/', HTTP_ORIGIN='http://example.org') + assert ACCESS_CONTROL_ALLOW_ORIGIN in resp + assert len(warn) == 1 + assert 
issubclass(warn[-1].category, DeprecationWarning) + assert 'Passing origins without scheme will be deprecated.' in str(warn[-1].message) + + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com', 'http://example.org']) def test_get_in_whitelist(self): resp = self.client.get('/', HTTP_ORIGIN='http://example.org') assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.org' - @override_settings(CORS_ORIGIN_WHITELIST=['example.com', 'null']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com', 'null']) def test_null_in_whitelist(self): resp = self.client.get('/', HTTP_ORIGIN='null') assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'null' @@ -101,7 +117,7 @@ def test_options_no_max_age(self): @override_settings( CORS_ALLOW_METHODS=['OPTIONS'], CORS_ALLOW_CREDENTIALS=True, - CORS_ORIGIN_WHITELIST=('localhost:9000',), + CORS_ORIGIN_WHITELIST=('http://localhost:9000',), ) def test_options_whitelist_with_port(self): resp = self.client.options('/', HTTP_ORIGIN='http://localhost:9000') @@ -127,14 +143,31 @@ def test_options_will_not_add_origin_when_domain_not_found_in_origin_regex_white @override_settings(CORS_MODEL='testapp.CorsModel') def test_get_when_custom_model_enabled(self): - CorsModel.objects.create(cors='example.com') + CorsModel.objects.create(cors='http://example.com') resp = self.client.get('/', HTTP_ORIGIN='http://example.com') assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com' assert ACCESS_CONTROL_ALLOW_CREDENTIALS not in resp + @override_settings(CORS_MODEL='testapp.CorsModel') + def test_get_when_custom_model_enabled_without_scheme(self): + with warnings.catch_warnings(record=True) as warn: + CorsModel.objects.create(cors='example.com') + resp = self.client.get('/', HTTP_ORIGIN='http://example.com') + + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com' + assert len(warn) == 1 + assert issubclass(warn[-1].category, DeprecationWarning) + assert 'Passing origins without scheme will be deprecated.' 
in str(warn[-1].message) + + @override_settings(CORS_MODEL='testapp.CorsModel') + def test_get_when_custom_model_enabled_with_different_scheme(self): + CorsModel.objects.create(cors='https://example.com') + resp = self.client.get('/', HTTP_ORIGIN='http://example.com') + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + @override_settings(CORS_MODEL='testapp.CorsModel', CORS_ALLOW_CREDENTIALS=True) def test_get_when_custom_model_enabled_and_allow_credentials(self): - CorsModel.objects.create(cors='example.com') + CorsModel.objects.create(cors='http://example.com') resp = self.client.get('/', HTTP_ORIGIN='http://example.com') assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com' assert resp[ACCESS_CONTROL_ALLOW_CREDENTIALS] == 'true' @@ -159,7 +192,7 @@ def test_options_no_header(self): @override_settings(CORS_MODEL='testapp.CorsModel') def test_options_when_custom_model_enabled(self): - CorsModel.objects.create(cors='example.com') + CorsModel.objects.create(cors='http://example.com') resp = self.client.options( '/', HTTP_ORIGIN='http://example.com', @@ -169,7 +202,7 @@ def test_options_when_custom_model_enabled(self): @override_settings(CORS_MODEL='testapp.CorsModel') def test_process_response_when_custom_model_enabled(self): - CorsModel.objects.create(cors='foo.google.com') + CorsModel.objects.create(cors='http://foo.google.com') response = self.client.get('/', HTTP_ORIGIN='http://foo.google.com') assert response.get(ACCESS_CONTROL_ALLOW_ORIGIN, None) == 'http://foo.google.com' @@ -259,7 +292,7 @@ def handler(*args, **kwargs): assert resp.status_code == 200 assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.com' - @override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) def test_signal_handler_allow_some_urls_to_everyone(self): def allow_api_to_all(sender, request, **kwargs): return request.path.startswith('/api/') @@ -281,7 +314,7 @@ def allow_api_to_all(sender, request, 
**kwargs): assert resp.status_code == 200 assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == 'http://example.org' - @override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) def test_signal_called_once_during_normal_flow(self): def allow_all(sender, request, **kwargs): allow_all.calls += 1 @@ -293,7 +326,7 @@ def allow_all(sender, request, **kwargs): assert allow_all.calls == 1 - @override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) @prepend_middleware('tests.test_middleware.ShortCircuitMiddleware') def test_get_short_circuit(self): """ @@ -306,7 +339,7 @@ def test_get_short_circuit(self): assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp @override_settings( - CORS_ORIGIN_WHITELIST=['example.com'], + CORS_ORIGIN_WHITELIST=['http://example.com'], CORS_URLS_REGEX=r'^/foo/$', ) @prepend_middleware(__name__ + '.ShortCircuitMiddleware') @@ -315,7 +348,7 @@ def test_get_short_circuit_should_be_ignored(self): assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp @override_settings( - CORS_ORIGIN_WHITELIST=['example.com'], + CORS_ORIGIN_WHITELIST=['http://example.com'], CORS_URLS_REGEX=r'^/foo/$', ) def test_get_regex_matches(self): @@ -323,7 +356,7 @@ def test_get_regex_matches(self): assert ACCESS_CONTROL_ALLOW_ORIGIN in resp @override_settings( - CORS_ORIGIN_WHITELIST=['example.com'], + CORS_ORIGIN_WHITELIST=['http://example.com'], CORS_URLS_REGEX=r'^/not-foo/$', ) def test_get_regex_doesnt_match(self): @@ -331,14 +364,14 @@ def test_get_regex_doesnt_match(self): assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp @override_settings( - CORS_ORIGIN_WHITELIST=['example.com'], + CORS_ORIGIN_WHITELIST=['http://example.com'], CORS_URLS_REGEX=r'^/foo/$', ) def test_get_regex_matches_path_info(self): resp = self.client.get('/foo/', HTTP_ORIGIN='http://example.com', SCRIPT_NAME='/prefix/') assert ACCESS_CONTROL_ALLOW_ORIGIN in resp - 
@override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) def test_cors_enabled_is_attached_and_bool(self): """ Ensure that request._cors_enabled is available - although a private API @@ -349,7 +382,7 @@ def test_cors_enabled_is_attached_and_bool(self): assert isinstance(request._cors_enabled, bool) assert request._cors_enabled - @override_settings(CORS_ORIGIN_WHITELIST=['example.com']) + @override_settings(CORS_ORIGIN_WHITELIST=['http://example.com']) def test_works_if_view_deletes_cors_enabled(self): """ Just in case something crazy happens in the view or other middleware,