Skip to content

Commit

Permalink
Retry for authentication(getting tokens) (#251)
Browse files Browse the repository at this point in the history
* Add retry to check_tokenn and auth function

* Added tests

* Fix test environment variables

* Upgrade requests minimum version because of CVE 2018-18074
  • Loading branch information
akharit committed Oct 29, 2018
1 parent dc7d053 commit aff7d1a
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 39 deletions.
2 changes: 1 addition & 1 deletion azure/datalake/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def concat(self, outfile, filelist, delete_source=False):
self.azure.call('MSCONCAT', outfile.as_posix(),
data=bytearray(json.dumps(sources,separators=(',', ':')), encoding="utf-8"),
deleteSourceDirectory=delete,
headers={'Content-Type': "application/json"},)
headers={'Content-Type': "application/json"})
self.invalidate_cache(outfile)

merge = concat
Expand Down
80 changes: 47 additions & 33 deletions azure/datalake/store/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
else:
import urllib

from .retry import ExponentialRetryPolicy
from .retry import ExponentialRetryPolicy, retry_decorator_for_auth

# 3rd party imports
import adal
Expand Down Expand Up @@ -74,7 +74,7 @@
def auth(tenant_id=None, username=None,
password=None, client_id=default_client,
client_secret=None, resource=DEFAULT_RESOURCE_ENDPOINT,
require_2fa=False, authority=None, **kwargs):
require_2fa=False, authority=None, retry_policy=None, **kwargs):
""" User/password authentication
Parameters
Expand Down Expand Up @@ -103,6 +103,7 @@ def auth(tenant_id=None, username=None,
-------
:type DataLakeCredential :mod: `A DataLakeCredential object`
"""

if not authority:
authority = 'https://login.microsoftonline.com/'

Expand All @@ -124,24 +125,30 @@ def auth(tenant_id=None, username=None,
if not client_secret:
client_secret = os.environ.get('azure_client_secret', None)

# You can explicitly authenticate with 2fa, or pass in nothing to the auth call and
# You can explicitly authenticate with 2fa, or pass in nothing to the auth call
# and the user will be prompted to login interactively through a browser.
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)

elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")

@retry_decorator_for_auth(retry_policy=retry_policy)
def get_token_internal():
# Internal function used so as to use retry decorator
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)

elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")
return out

out = get_token_internal()
out.update({'access': out['accessToken'], 'resource': resource,
'refresh': out.get('refreshToken', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id})
Expand All @@ -152,22 +159,22 @@ class DataLakeCredential:
def __init__(self, token):
self.token = token

def signed_session(self):
def signed_session(self, retry_policy=None):
# type: () -> requests.Session
"""Create requests session with any required auth headers applied.
:rtype: requests.Session
"""
session = requests.Session()
if time.time() - self.token['time'] > self.token['expiresIn'] - 100:
self.refresh_token()
self.refresh_token(retry_poliy=retry_policy)

scheme, token = self.token['tokenType'], self.token['access']
header = "{} {}".format(scheme, token)
session.headers['Authorization'] = header
return session

def refresh_token(self, authority=None):
def refresh_token(self, authority=None, retry_policy=None):
""" Refresh an expired authorization token
Parameters
Expand All @@ -183,15 +190,22 @@ def refresh_token(self, authority=None):

context = adal.AuthenticationContext(authority +
self.token['tenant'])
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'], self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
out.update({'refresh': out['refreshToken']})

@retry_decorator_for_auth(retry_policy=retry_policy)
def get_token_internal():
# Internal function used so as to use retry decorator
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'],
self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
return out

out = get_token_internal()
# common items to update
out.update({'access': out['accessToken'],
'time': time.time(), 'tenant': self.token['tenant'],
Expand Down Expand Up @@ -257,7 +271,7 @@ def __init__(self, store_name=default_store, token=None,
# There is a case where the user can opt to exclude an API version, in which case
# the service itself decides on the API version to use (it's default).
self.api_version = api_version or None
self.head = {'Authorization': token.signed_session().headers['Authorization']}
self.head = {'Authorization': token.signed_session(retry_policy=None).headers['Authorization']}
self.url = 'https://%s.%s/' % (store_name, url_suffix)
self.webhdfs = 'webhdfs/v1/'
self.extended_operations = 'webhdfsext/'
Expand All @@ -282,8 +296,8 @@ def session(self):
self.local.session = s
return s

def _check_token(self):
cur_session = self.token.signed_session()
def _check_token(self, retry_policy=None):
cur_session = self.token.signed_session(retry_policy=retry_policy)
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None
Expand Down
67 changes: 63 additions & 4 deletions azure/datalake/store/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import sys
import time

from functools import wraps
# local imports

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,6 +45,9 @@ def should_retry(self, response, last_exception, retry_count):
self.__backoff()
return True

if response is None:
return False

status_code = response.status_code

if(status_code == 501
Expand All @@ -58,8 +61,8 @@ def should_retry(self, response, last_exception, retry_count):
if(status_code >= 500
or status_code == 401
or status_code == 408
or status_code == 429):

or status_code == 429
or status_code == 104):
self.__backoff()
return True

Expand All @@ -70,4 +73,60 @@ def should_retry(self, response, last_exception, retry_count):

def __backoff(self):
time.sleep(self.exponential_retry_interval)
self.exponential_retry_interval *= self.exponential_factor
self.exponential_retry_interval *= self.exponential_factor


def retry_decorator_for_auth(retry_policy = None):
import adal
from requests import HTTPError
if retry_policy is None:
retry_policy = ExponentialRetryPolicy(max_retries=2)

def deco_retry(func):
@wraps(func)
def f_retry(*args, **kwargs):
retry_count = -1
last_exception = None
out = None
while True:
retry_count += 1
try:
out = func(*args, **kwargs)
except (adal.adal_error.AdalError, HTTPError) as e:
# ADAL error corresponds to everything but 429, which bubbles up HTTP error.
last_exception = e
logger.exception("Retry count " + str(retry_count) + "Exception :" + str(last_exception))

if hasattr(last_exception, 'error_response'): # ADAL exception
response = response_from_adal_exception(last_exception)
if hasattr(last_exception, 'response'): # HTTP exception i.e 429
response = last_exception.response

request_successful = last_exception is None or response.status_code == 401 # 401 = Invalid credentials
if request_successful or not retry_policy.should_retry(response, last_exception, retry_count):
break
if out is None:
raise last_exception
return out

return f_retry

return deco_retry


def response_from_adal_exception(e):
import re
from collections import namedtuple

response = e.error_response
http_code = re.search("http error: (\d+)", str(e))
if http_code is not None: # Add status_code to response object for use in should_retry
keys = list(response.keys()) + ['status_code']
status_code = int(http_code.group(1))
values = list(response.values()) + [status_code]

Response = namedtuple("Response", keys)
response = Response(
*values) # Construct response object with adal exception response and http code
return response

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
install_requires=[
'cffi',
'adal>=0.4.2',
'requests>=2.20.0'
],
extras_require={
":python_version<'3.4'": ['pathlib2'],
Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SUBSCRIPTION_ID = fake_settings.SUBSCRIPTION_ID
RESOURCE_GROUP_NAME = fake_settings.RESOURCE_GROUP_NAME
RECORD_MODE = os.environ.get('RECORD_MODE', 'all').lower()
CLIENT_ID = os.environ['azure_service_principal']
'''
RECORD_MODE = os.environ.get('RECORD_MODE', 'none').lower()
Expand Down
Loading

0 comments on commit aff7d1a

Please sign in to comment.