diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index f66d343da5932..a37a6174d8711 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -162,6 +162,35 @@ def reset_timeout(cls): """ cls._timeout = socket._GLOBAL_DEFAULT_TIMEOUT + @classmethod + def get_remote_connection_headers(cls, parsed_url, keep_alive=False): + """ + Get headers for remote request. + + :Args: + - parsed_url - The parsed url + - keep_alive (Boolean) - Is this a keep-alive connection (default: False) + """ + + headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json;charset=UTF-8', + 'User-Agent': 'Python http auth' + } + + if parsed_url.username: + base64string = base64.b64encode('{0.username}:{0.password}'.format(parsed_url).encode()) + headers.update({ + 'Authorization': 'Basic {}'.format(base64string.decode()) + }) + + if keep_alive: + headers.update({ + 'Connection': 'keep-alive' + }) + + return headers + def __init__(self, remote_server_addr, keep_alive=False, resolve_ip=True): # Attempt to resolve the hostname and get an IP address. self.keep_alive = keep_alive @@ -429,17 +458,9 @@ def _request(self, method, url, body=None): LOGGER.debug('%s %s %s' % (method, url, body)) parsed_url = parse.urlparse(url) + headers = self.get_remote_connection_headers(parsed_url, self.keep_alive) if self.keep_alive: - headers = {"Connection": 'keep-alive', method: parsed_url.path, - "User-Agent": "Python http auth", - "Content-type": "application/json;charset=\"UTF-8\"", - "Accept": "application/json"} - if parsed_url.username: - auth = base64.standard_b64encode(('%s:%s' % ( - parsed_url.username, - parsed_url.password)).encode('ascii')).decode('ascii').replace('\n', '') - headers["Authorization"] = "Basic %s" % auth if body and method != 'POST' and method != 'PUT': body = None try: @@ -472,12 +493,7 @@ def _request(self, method, url, body=None): else: request = Request(url, data=body.encode('utf-8'), method=method) - request.add_header('Accept', 'application/json') - request.add_header('Content-Type', 'application/json;charset=UTF-8') - - if parsed_url.username: - base64string = base64.b64encode('{0.username}:{0.password}'.format(parsed_url).encode()) - request.add_header('Authorization', 'Basic {}'.format(base64string).decode()) + request.headers.update(headers) if password_manager: opener = url_request.build_opener(url_request.HTTPRedirectHandler(), diff --git a/py/test/unit/selenium/webdriver/remote/test_remote_connection.py b/py/test/unit/selenium/webdriver/remote/test_remote_connection.py index 9d5bbb54dcfdc..ddf30da858064 100644 --- a/py/test/unit/selenium/webdriver/remote/test_remote_connection.py +++ b/py/test/unit/selenium/webdriver/remote/test_remote_connection.py @@ -15,21 +15,71 @@ # specific language governing permissions and limitations # under the License. -import pytest -from selenium.webdriver.remote.remote_connection import RemoteConnection +try: + from urllib import parse +except ImportError: # above is available in py3+, below is py2.7 + import urlparse as parse +from selenium.webdriver.remote.remote_connection import ( + RemoteConnection, +) -def test_basic_auth(mocker): - def check(request, timeout): - assert request.headers['Authorization'] == 'Basic dXNlcjpwYXNz' +def test_get_remote_connection_headers_defaults(): + url = 'http://remote' + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url)) + assert 'Authorization' not in headers.keys() + assert 'Connection' not in headers.keys() + assert headers.get('Accept') == 'application/json' + assert headers.get('Content-Type') == 'application/json;charset=UTF-8' + assert headers.get('User-Agent') == 'Python http auth' + + +def test_get_remote_connection_headers_adds_auth_header_if_pass(): + url = 'http://user:pass@remote' + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url)) + assert headers.get('Authorization') == 'Basic dXNlcjpwYXNz' + + +def test_get_remote_connection_headers_adds_keep_alive_if_requested(): + url = 'http://remote' + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url), keep_alive=True) + assert headers.get('Connection') == 'keep-alive' + + +class MockResponse: + code = 200 + headers = [] + + def read(self): + return b"{}" + + def close(self): + pass + + def getheader(self, *args, **kwargs): + pass + + +def test_remote_connection_adds_connection_headers_from_get_remote_connection_headers(mocker): + test_headers = {'FOO': 'bar'} + + # Stub out the get_remote_connection_headers method to return something testable + mocker.patch( + 'selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers' + ).return_value = test_headers + + # Stub out response try: - method = mocker.patch('urllib.request.OpenerDirector.open') + mock_open = mocker.patch('urllib.request.OpenerDirector.open') except ImportError: - method = mocker.patch('urllib2.OpenerDirector.open') - method.side_effect = check + mock_open = mocker.patch('urllib2.OpenerDirector.open') + + def assert_header_added(request, timeout): + assert request.headers == test_headers + return MockResponse() + + mock_open.side_effect = assert_header_added - with pytest.raises(AttributeError): - RemoteConnection('http://user:pass@remote', resolve_ip=False) \ - .execute('status', {}) + RemoteConnection('http://remote', resolve_ip=False).execute('status', {})