Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,35 @@ def reset_timeout(cls):
"""
cls._timeout = socket._GLOBAL_DEFAULT_TIMEOUT

@classmethod
def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
"""
Get headers for remote request.

:Args:
- parsed_url - The parsed url
- keep_alive (Boolean) - Is this a keep-alive connection (default: False)
"""

headers = {
'Accept': 'application/json',
'Content-Type': 'application/json;charset=UTF-8',
'User-Agent': 'Python http auth'
}

if parsed_url.username:
base64string = base64.b64encode('{0.username}:{0.password}'.format(parsed_url).encode())
headers.update({
'Authorization': 'Basic {}'.format(base64string.decode())
})

if keep_alive:
headers.update({
'Connection': 'keep-alive'
})

return headers

def __init__(self, remote_server_addr, keep_alive=False, resolve_ip=True):
# Attempt to resolve the hostname and get an IP address.
self.keep_alive = keep_alive
Expand Down Expand Up @@ -429,17 +458,9 @@ def _request(self, method, url, body=None):
LOGGER.debug('%s %s %s' % (method, url, body))

parsed_url = parse.urlparse(url)
headers = self.get_remote_connection_headers(parsed_url, self.keep_alive)

if self.keep_alive:
headers = {"Connection": 'keep-alive', method: parsed_url.path,
"User-Agent": "Python http auth",
"Content-type": "application/json;charset=\"UTF-8\"",
"Accept": "application/json"}
if parsed_url.username:
auth = base64.standard_b64encode(('%s:%s' % (
parsed_url.username,
parsed_url.password)).encode('ascii')).decode('ascii').replace('\n', '')
headers["Authorization"] = "Basic %s" % auth
if body and method != 'POST' and method != 'PUT':
body = None
try:
Expand Down Expand Up @@ -472,12 +493,7 @@ def _request(self, method, url, body=None):
else:
request = Request(url, data=body.encode('utf-8'), method=method)

request.add_header('Accept', 'application/json')
request.add_header('Content-Type', 'application/json;charset=UTF-8')

if parsed_url.username:
base64string = base64.b64encode('{0.username}:{0.password}'.format(parsed_url).encode())
request.add_header('Authorization', 'Basic {}'.format(base64string).decode())
request.headers.update(headers)

if password_manager:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
Expand Down
72 changes: 61 additions & 11 deletions py/test/unit/selenium/webdriver/remote/test_remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,71 @@
# specific language governing permissions and limitations
# under the License.

import pytest

from selenium.webdriver.remote.remote_connection import RemoteConnection
try:
from urllib import parse
except ImportError: # above is available in py3+, below is py2.7
import urlparse as parse

from selenium.webdriver.remote.remote_connection import (
RemoteConnection,
)

def test_basic_auth(mocker):
def check(request, timeout):
assert request.headers['Authorization'] == 'Basic dXNlcjpwYXNz'

def test_get_remote_connection_headers_defaults():
url = 'http://remote'
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
assert 'Authorization' not in headers.keys()
assert 'Connection' not in headers.keys()
assert headers.get('Accept') == 'application/json'
assert headers.get('Content-Type') == 'application/json;charset=UTF-8'
assert headers.get('User-Agent') == 'Python http auth'


def test_get_remote_connection_headers_adds_auth_header_if_pass():
url = 'http://user:pass@remote'
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
assert headers.get('Authorization') == 'Basic dXNlcjpwYXNz'


def test_get_remote_connection_headers_adds_keep_alive_if_requested():
url = 'http://remote'
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url), keep_alive=True)
assert headers.get('Connection') == 'keep-alive'


class MockResponse:
code = 200
headers = []

def read(self):
return b"{}"

def close(self):
pass

def getheader(self, *args, **kwargs):
pass


def test_remote_connection_adds_connection_headers_from_get_remote_connection_headers(mocker):
test_headers = {'FOO': 'bar'}

# Stub out the get_remote_connection_headers method to return something testable
mocker.patch(
'selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers'
).return_value = test_headers

# Stub out response
try:
method = mocker.patch('urllib.request.OpenerDirector.open')
mock_open = mocker.patch('urllib.request.OpenerDirector.open')
except ImportError:
method = mocker.patch('urllib2.OpenerDirector.open')
method.side_effect = check
mock_open = mocker.patch('urllib2.OpenerDirector.open')

def assert_header_added(request, timeout):
assert request.headers == test_headers
return MockResponse()

mock_open.side_effect = assert_header_added

with pytest.raises(AttributeError):
RemoteConnection('http://user:pass@remote', resolve_ip=False) \
.execute('status', {})
RemoteConnection('http://remote', resolve_ip=False).execute('status', {})