Skip to content

Commit

Permalink
Merge pull request #195 from fmigneault/httplib-port
Browse files Browse the repository at this point in the history
Settings and Improvements
  • Loading branch information
jensens committed Jun 9, 2023
2 parents 4c2384e + 176d11d commit 39c6079
Show file tree
Hide file tree
Showing 31 changed files with 148 additions and 86 deletions.
5 changes: 1 addition & 4 deletions authomatic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,10 +1328,7 @@ def __init__(
self.logging_level = logging_level
self.prefix = prefix
self._logger = logger or logging.getLogger(str(id(self)))

# Set logging level.
if logger is None:
self._logger.setLevel(logging_level)
self._logger.setLevel(logging_level)

def login(self, adapter, provider_name, callback=None,
session=None, session_saver=None, **kwargs):
Expand Down
146 changes: 98 additions & 48 deletions authomatic/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import hashlib
import logging
import random
import ssl
import sys
import traceback
import uuid
Expand All @@ -35,8 +36,7 @@
CredentialsError,
)
from authomatic import six
from authomatic.six.moves import urllib_parse as parse
from authomatic.six.moves import http_client
from authomatic.six.moves import urllib_parse as parse, http_client
from authomatic.exceptions import CancellationError

__all__ = [
Expand Down Expand Up @@ -280,11 +280,18 @@ def _kwarg(self, kwargs, kwname, default=None):
Name of the desired keyword argument.
"""

return kwargs.get(kwname) or \
self.settings.config.get(self.name, {}).get(kwname) or \
self.settings.config.get('__defaults__', {}).get(kwname) or \
default
# check against `None` instead of multiple 'or' in case default value
# is `False`, which could be considered a valid 'found' value
getters = [
lambda: kwargs.get(kwname),
lambda: self.settings.config.get(self.name, {}).get(kwname),
lambda: self.settings.config.get('__defaults__', {}).get(kwname),
]
for get in getters:
value = get()
if value is not None:
return value
return default

def _session_key(self, key):
"""
Expand Down Expand Up @@ -352,8 +359,36 @@ def _log(cls, level, msg, **kwargs):
level, ': '.join(
('authomatic', cls.__name__, msg)), **kwargs)

@classmethod
def _log_param(cls, param, value='', last=None,
level=logging.DEBUG, **kwargs):
"""
Same as :meth:`_log` but in DEBUG, and with option indicator in front
of the message according to :param:`last`.
:param str param:
Parameter name.
:param Any value:
Parameter value.
:param bool last:
"|-" like character if `False`, "|_" if `True`, None if `None`.
:param int level:
Logging level as specified in the
`login module <http://docs.python.org/2/library/logging.html>`_ of
Python standard library.
"""
info_style = u' \u251C\u2500 '
last_style = u' \u2514\u2500 '
style = u'' if last is None else last_style if last else info_style
cls._log(logging.DEBUG, u'{0}{1}: {2!s}'.format(style, param, value))

def _fetch(self, url, method='GET', params=None, headers=None,
body='', max_redirects=5, content_parser=None):
body='', max_redirects=5, content_parser=None,
certificate_file=None, ssl_verify=True):
"""
Fetches a URL.
Expand All @@ -379,6 +414,11 @@ def _fetch(self, url, method='GET', params=None, headers=None,
A callable to be used to parse the :attr:`.Response.data`
from :attr:`.Response.content`.
:param str certificate_file:
Optional certificate file to be used for HTTPS connection.
:param bool ssl_verify:
Verify SSL on HTTPS connection.
"""
# 'magic' using _kwarg method
# pylint:disable=no-member
Expand All @@ -388,7 +428,7 @@ def _fetch(self, url, method='GET', params=None, headers=None,
headers = headers or {}
headers.update(self.access_headers)

scheme, host, path, query, fragment = parse.urlsplit(url)
url_parsed = parse.urlsplit(url)
query = parse.urlencode(params)

if method in ('POST', 'PUT', 'PATCH'):
Expand All @@ -398,22 +438,30 @@ def _fetch(self, url, method='GET', params=None, headers=None,
query = ''
headers.update(
{'Content-Type': 'application/x-www-form-urlencoded'})
request_path = parse.urlunsplit(('', '', path or '', query or '', ''))
request_path = parse.urlunsplit(
('', '', url_parsed.path or '', query or '', ''))

self._log(logging.DEBUG, u' \u251C\u2500 host: {0}'.format(host))
self._log(
logging.DEBUG,
u' \u251C\u2500 path: {0}'.format(request_path))
self._log(logging.DEBUG, u' \u251C\u2500 method: {0}'.format(method))
self._log(logging.DEBUG, u' \u251C\u2500 body: {0}'.format(body))
self._log(logging.DEBUG, u' \u251C\u2500 params: {0}'.format(params))
self._log(logging.DEBUG, u' \u2514\u2500 headers: {0}'.format(headers))
self._log_param('host', url_parsed.hostname, last=False)
self._log_param('method', method, last=False)
self._log_param('body', body, last=False)
self._log_param('params', params, last=False)
self._log_param('headers', headers, last=False)
self._log_param('certificate', certificate_file, last=False)
self._log_param('SSL verify', ssl_verify, last=True)

# Connect
if scheme.lower() == 'https':
connection = http_client.HTTPSConnection(host)
if url_parsed.scheme.lower() == 'https':
context = None if ssl_verify else ssl._create_unverified_context()
cert_file = certificate_file if ssl_verify else None
connection = http_client.HTTPSConnection(
url_parsed.hostname,
port=url_parsed.port,
cert_file=cert_file,
context=context)
else:
connection = http_client.HTTPConnection(host)
connection = http_client.HTTPConnection(
url_parsed.hostname,
port=url_parsed.port)

try:
connection.request(method, request_path, body, headers)
Expand All @@ -434,32 +482,27 @@ def _fetch(self, url, method='GET', params=None, headers=None,
elif max_redirects > 0:
remaining_redirects = max_redirects - 1

self._log(logging.DEBUG, u'Redirecting to {0}'.format(url))
self._log(logging.DEBUG, u'Remaining redirects: {0}'
.format(remaining_redirects))
self._log_param('Redirecting to', url)
self._log_param('Remaining redirects', remaining_redirects)

# Call this method again.
response = self._fetch(url=location,
params=params,
method=method,
headers=headers,
max_redirects=remaining_redirects)
max_redirects=remaining_redirects,
certificate_file=certificate_file,
ssl_verify=ssl_verify)

else:
raise FetchError('Max redirects reached!',
url=location,
status=response.status)
else:
self._log(logging.DEBUG, u'Got response:')
self._log(logging.DEBUG, u' \u251C\u2500 url: {0}'.format(url))
self._log(
logging.DEBUG,
u' \u251C\u2500 status: {0}'.format(
response.status))
self._log(
logging.DEBUG,
u' \u2514\u2500 headers: {0}'.format(
response.getheaders()))
self._log_param('Got response')
self._log_param('url', url, last=False)
self._log_param('status', response.status, last=False)
self._log_param('headers', response.getheaders(), last=True)

return authomatic.core.Response(response, content_parser)

Expand Down Expand Up @@ -773,19 +816,20 @@ def type_id(self):
str(mod.PROVIDER_ID_MAP.index(cls))

def access(self, url, params=None, method='GET', headers=None,
body='', max_redirects=5, content_parser=None):
body='', max_redirects=5, content_parser=None,
certificate_file=None, ssl_verify=True):
"""
Fetches the **protected resource** of an authenticated **user**.
:param credentials:
The **user's** :class:`.Credentials` (serialized or normal).
:param str url:
The URL of the **protected resource**.
:param str method:
HTTP method of the request.
:param dict params:
Dictionary of request parameters.
:param dict headers:
HTTP headers of the request.
Expand All @@ -799,6 +843,12 @@ def access(self, url, params=None, method='GET', headers=None,
A function to be used to parse the :attr:`.Response.data`
from :attr:`.Response.content`.
:param str certificate_file:
Optional certificate file to be used for HTTPS connection.
:param bool ssl_verify:
Verify SSL on HTTPS connection.
:returns:
:class:`.Response`
Expand All @@ -809,9 +859,7 @@ def access(self, url, params=None, method='GET', headers=None,

headers = headers or {}

self._log(
logging.INFO,
u'Accessing protected resource {0}.'.format(url))
self._log_param('Accessing protected resource', url, level=logging.INFO)

request_elements = self.create_request_elements(
request_type=self.PROTECTED_RESOURCE_REQUEST_TYPE,
Expand All @@ -825,12 +873,12 @@ def access(self, url, params=None, method='GET', headers=None,

response = self._fetch(*request_elements,
max_redirects=max_redirects,
content_parser=content_parser)
content_parser=content_parser,
certificate_file=certificate_file,
ssl_verify=ssl_verify)

self._log(
logging.INFO,
u'Got response. HTTP status = {0}.'.format(
response.status))
status = response.status
self._log_param('Got response. HTTP status', status, level=logging.INFO)
return response

def async_access(self, *args, **kwargs):
Expand Down Expand Up @@ -976,7 +1024,9 @@ def _access_user_info(self):
"""
url = self.user_info_url.format(**self.user.__dict__)
return self.access(url)
cert = self._kwarg({}, 'certificate_file', None)
verify = self._kwarg({}, 'ssl_verify', True)
return self.access(url, certificate_file=cert, ssl_verify=verify)


class AuthenticationProvider(BaseProvider):
Expand Down
16 changes: 14 additions & 2 deletions authomatic/providers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def __init__(self, *args, **kwargs):
*offline access token*.
Default is ``False``.
:param str certificate_file:
Certificate file to employ for HTTPS connection where needed.
:param bool ssl_verify:
Certificate file to employ for HTTPS connection where needed.
As well as those inherited from :class:`.AuthorizationProvider`
constructor.
Expand All @@ -92,6 +98,8 @@ def __init__(self, *args, **kwargs):

self.scope = self._kwarg(kwargs, 'scope', [])
self.offline = self._kwarg(kwargs, 'offline', False)
self.cert = self._kwarg(kwargs, 'certificate_file', None)
self.verify = self._kwarg(kwargs, 'ssl_verify', True)

# ========================================================================
# Internal methods
Expand Down Expand Up @@ -312,7 +320,9 @@ def refresh_credentials(self, credentials):
)

self._log(logging.INFO, u'Refreshing credentials.')
response = self._fetch(*request_elements)
response = self._fetch(*request_elements,
certificate_file=self.cert,
ssl_verify=self.verify)

# We no longer need consumer info.
credentials.consumer_key = None
Expand Down Expand Up @@ -410,7 +420,9 @@ def login(self):
headers=self.access_token_headers
)

response = self._fetch(*request_elements)
response = self._fetch(*request_elements,
certificate_file=self.cert,
ssl_verify=self.verify)
self.access_token_response = response

access_token = response.data.get('access_token', '')
Expand Down
2 changes: 1 addition & 1 deletion doc/source/examples/simple.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Yo will need the ``consumer_key`` and ``consumer_secret`` which you can get

Facebook and other |oauth2| providers require a **redirect URI**
which should be the URL of the *login request handler*
which we will create in this tutorial and whose walue in our case will be
which we will create in this tutorial and whose value in our case will be
``https://[hostname]:[port]/login/fb`` for Facebook.

.. literalinclude:: ../../../examples/gae/simple/config-template.py
Expand Down
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from setuptools import find_packages
from setuptools import setup

oid_reqs = [
"python-openid ; python_version<'3'",
"python3-openid ; python_version>='3'",
]

setup(
packages=find_packages(),
package_data={'': ['*.txt', '*.rst']},
extras_require={
'OpenID': [
"python-openid ; python_version<'3'",
"python3-openid ; python_version>='3'",
],
'OpenID: python_version < "3"': ['python-openid'],
'OpenID: python_version >= "3"': ['python3-openid'],
},
)
2 changes: 1 addition & 1 deletion tests/functional_tests/expected_values/amazon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
conf.no_phone +
conf.no_timezone +
conf.no_username,
# True means that any thruthy value is expected
# True means that any truthy value is expected
'credentials': {
'token_type': 'Bearer',
'provider_type_id': '2-18',
Expand Down
2 changes: 1 addition & 1 deletion tests/functional_tests/expected_values/bitbucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
conf.no_nickname +
conf.no_phone +
conf.no_timezone,
# True means that any thruthy value is expected
# True means that any truthy value is expected
'credentials': {
'_expiration_time': None,
'_expire_in': True,
Expand Down
2 changes: 1 addition & 1 deletion tests/functional_tests/expected_values/bitly.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
conf.no_email + conf.no_location +
conf.no_gender + conf.no_locale +
conf.no_first_name + conf.no_last_name,
# True means that any thruthy value is expected
# True means that any truthy value is expected
'credentials': {
'token_type': None,
'provider_type_id': '2-2',
Expand Down
2 changes: 1 addition & 1 deletion tests/functional_tests/expected_values/deviantart.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
conf.no_email + conf.no_location +
conf.no_gender + conf.no_locale +
conf.no_first_name + conf.no_last_name,
# True means that any thruthy value is expected
# True means that any truthy value is expected
'credentials': {
'token_type': 'Bearer',
'provider_type_id': '2-4',
Expand Down
2 changes: 1 addition & 1 deletion tests/functional_tests/expected_values/eventbrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
conf.no_nickname +
conf.no_phone +
conf.no_timezone,
# True means that any thruthy value is expected
# True means that any truthy value is expected
'credentials': {
'token_type': 'Bearer',
'provider_type_id': '2-17',
Expand Down
Loading

0 comments on commit 39c6079

Please sign in to comment.