diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 974e88a513e..5a74259eb3a 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -2,6 +2,7 @@ import asyncio import base64 +import binascii import datetime import functools import io @@ -42,6 +43,27 @@ def __new__(cls, login, password='', encoding='latin1'): return super().__new__(cls, login, password, encoding) + @classmethod + def decode(cls, auth_header, encoding='latin1'): + """Create a :class:`BasicAuth` object from an ``Authorization`` HTTP + header.""" + split = auth_header.strip().split(' ') + if len(split) == 2: + if split[0].strip().lower() != 'basic': + raise ValueError('Unknown authorization method %s' % split[0]) + to_decode = split[1] + else: + raise ValueError('Could not parse authorization header.') + + try: + username, _, password = base64.b64decode( + to_decode.encode('ascii') + ).decode(encoding).partition(':') + except binascii.Error: + raise ValueError('Invalid base64 encoding.') + + return cls(username, password, encoding=encoding) + def encode(self): """Encode credentials.""" creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index df70d396b96..6c1de527535 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1291,6 +1291,16 @@ BasicAuth e.g. *auth* parameter for :meth:`ClientSession.request`. + .. classmethod:: decode(auth_header, encoding='latin1') + + Decode HTTP basic authentication credentials. + + :param str auth_header: The ``Authorization`` header to decode. + :param str encoding: (optional) encoding ('latin1' by default) + + :return: decoded authentication data, :class:`BasicAuth`. + + .. method:: encode() Encode credentials into string suitable for ``Authorization`` diff --git a/tests/test_helpers.py b/tests/test_helpers.py index dbf954bdb43..4287cf0a3cd 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -71,6 +71,27 @@ def test_basic_auth4(): assert auth.encode() == 'Basic bmtpbTpwd2Q=' +def test_basic_auth_decode(): + auth = helpers.BasicAuth.decode('Basic bmtpbTpwd2Q=') + assert auth.login == 'nkim' + assert auth.password == 'pwd' + + +def test_basic_auth_invalid(): + with pytest.raises(ValueError): + helpers.BasicAuth.decode('bmtpbTpwd2Q=') + + +def test_basic_auth_decode_not_basic(): + with pytest.raises(ValueError): + helpers.BasicAuth.decode('Complex bmtpbTpwd2Q=') + + +def test_basic_auth_decode_bad_base64(): + with pytest.raises(ValueError): + helpers.BasicAuth.decode('Basic bmtpbTpwd2Q') + + def test_invalid_formdata_params(): with pytest.raises(TypeError): helpers.FormData('asdasf')