From 13bcc4181f782b83ee803b1bc38628699cdcb628 Mon Sep 17 00:00:00 2001 From: yugangw-msft Date: Wed, 24 Oct 2018 19:29:19 -0700 Subject: [PATCH] address review feedback on probe type, sequences --- .gitignore | 3 +- msrestazure/azure_local_creds_prober.py | 197 +++++++++++++++--------- tests/test_local_creds_prober.py | 5 +- 3 files changed, 127 insertions(+), 78 deletions(-) diff --git a/.gitignore b/.gitignore index 0aefe2d..438b873 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ autorest coverage.xml Pipfile.lock build -dist \ No newline at end of file +dist +env \ No newline at end of file diff --git a/msrestazure/azure_local_creds_prober.py b/msrestazure/azure_local_creds_prober.py index a21f854..9138591 100644 --- a/msrestazure/azure_local_creds_prober.py +++ b/msrestazure/azure_local_creds_prober.py @@ -36,107 +36,156 @@ _LOGGER = logging.getLogger(__name__) -class AzureLocalCredentialProber(object): - ''' - Probing logics: - 1. Managed service identity - a. app service - b. virtual machine - 2. AZURE_CONN_STR, with SDK auth code file content in json. - https://github.com/Azure/azure-sdk-for-java/wiki/Authentication - 3. Individual environment variables to estabslish a service principal's creds - https://github.com/Azure/azure-sdk-for-go - 4. Azure CLI, through "az account get-access-token" - ''' +#pylint: disable=too-few-public-methods,missing-docstring - def __init__(self, subscription_id=None): - ''' - subscription_id: if missing, prober will find one based on detected creds. - ''' - self.subscription_id = subscription_id - self.creds = None - self._probe() +class CredsProber: + + def __init__(self, resource): + self.enabled = True + self.resource = resource - def signed_session(self, session=None): - self.creds.signed_session(session) - def _probe(self): - subscription_id = self.subscription_id or os.environ.get('AZURE_SUBSCRIPTION_ID') +class ManagedServiceIdentityProber(CredsProber): + + def probe(self, subscription_id=None): + if not self.enabled: + return None try: creds = MSIAuthentication() _LOGGER.warning('Managed system identity was detected') + return creds, subscription_id or _get_subscription_id(creds) except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError): - if os.environ.get('AZURE_CONN_STR'): # this should be java auth code file content - auth_info = json.loads(os.environ.get('AZURE_CONN_STR')) - subscription_id = auth_info.get('subscriptionId') - creds = ServicePrincipalCredentials(client_id=auth_info['clientId'], - secret=auth_info['clientSecret'], - tennt_id=auth_info['tenantId']) - elif os.environ.get('AZURE_CLIENT_ID'): - # TODO error out if other env vars are not set - creds = ServicePrincipalCredentials(client_id=os.environ.get('AZURE_CLIENT_ID'), - secret=os.environ.get('AZURE_CLIENT_SECRET'), - tennt_id=os.environ.get('AZURE_TENANT_ID')) - _LOGGER.warning('Service principal credentials was detected') - else: - try: - creds = CLICredentials() - _LOGGER.warning('Azure CLI credentials was detected') - except NotImplementedError: - raise ValueError('No credential was detected from the local machine') - self.creds = creds - if not subscription_id: - try: - from azure.mgmt.resource.subscriptions import SubscriptionClient - subscriptions = list(SubscriptionClient(creds).subscriptions.list()) - if subscriptions: - subscription_ids = [s.id.split('/')[-1] for s in subscriptions] - subscription_id = subscription_ids[0] - _LOGGER.warning('Found subscription "%s" to use', subscription_ids[0]) - if len(subscription_ids) > 1: - _LOGGER.warning('You also have accesses to a few other subscriptions "%S".' - ' You can supply subscription_id on creating the probe object') - except ImportError: # should be rare - _LOGGER.warning('Failed to load azure.mgmt.resource.subscriptions to find the default subscription.' - ' If this is expected, supply subscription_id on creating the probe object') - self.subscription_id = subscription_id + return None, None -class CLICredentials(BasicTokenAuthentication): +class ConnectionStrEnvProber(CredsProber): + ''' + Detect environment variable AZURE_CONN_STR + ''' + def probe(self, subscription_id=None): + creds = None + if self.enabled and os.environ.get('AZURE_CONN_STR'): + auth_info = json.loads(os.environ.get('AZURE_CONN_STR')) + creds = ServicePrincipalCredentials(client_id=auth_info['clientId'], + secret=auth_info['clientSecret'], + tennt_id=auth_info['tenantId']) + return creds, (subscription_id or auth_info.get('subscriptionId') or + _get_subscription_id(creds)) + return None, None - def __init__(self, resource=None, subscription_id=None): # allow subscriptions - super(CLICredentials, self).__init__(None) + +class ServicePrincipalEnvProber(CredsProber): + ''' + Detect envrionment variable AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID + ''' + def probe(self, subscription_id=None): + creds = None + if os.environ.get('AZURE_CLIENT_ID'): + client_id, client_secret, tenant_id = (os.environ.get('AZURE_CLIENT_ID'), + os.environ.get('AZURE_CLIENT_SECRET'), + os.environ.get('AZURE_TENANT_ID')) + if not client_secret or not tenant_id: + raise ValueError('Environment variables of AZURE_CLIENT_SECRET and' + ' AZURE_TENANT_ID must be set') + creds = ServicePrincipalCredentials(client_id=client_id, secret=client_secret, + tenant=tenant_id, resource=self.resource) + + _LOGGER.warning('Service principal credentials was detected') + return creds, (subscription_id or os.environ.get('AZURE_SUBSCRIPTION_ID') or + _get_subscription_id(creds)) + return None, None + + +class AzureCLIProber(CredsProber): + ''' + Detect CLI installations + ''' + def probe(self, subscription_id=None): # pylint: disable=no-self-use uname = platform.uname() - # python 2, `platform.uname()` returns: tuple(system, node, release, version, machine, processor) platform_name = getattr(uname, 'system', None) or uname[0] platform_name = platform_name.lower() if platform_name == 'windows': - program_files_folder = os.environ.get('ProgramFiles(x86)') or os.environ.get('ProgramFiles') - probing_paths = [os.path.join(program_files_folder, 'Microsoft SDKs', 'Azure', 'CLI2', 'wbin', 'az.cmd')] + program_files_folder = (os.environ.get('ProgramFiles(x86)') or + os.environ.get('ProgramFiles')) + probing_paths = [os.path.join(program_files_folder, 'Microsoft SDKs', + 'Azure', 'CLI2', 'wbin', 'az.cmd')] else: probing_paths = ['/usr/bin/az', '/usr/local/bin/az'] - + cli_path = next((p for p in probing_paths if os.path.isfile(p)), None) - if cli_path is None: - raise NotImplementedError('Azure CLI is not installed') + if cli_path: + creds = CLICredentials(cli_path) + if subscription_id is None: + subscription_id = creds.invoke_cli_token_command()['subscription'] + return creds, subscription_id + return None, None + + +class CLICredentials(BasicTokenAuthentication): + + def __init__(self, cli_path, subscription_id=None): # allow subscriptions + super(CLICredentials, self).__init__(None) self.cli_path = cli_path - self.resource = resource self.subscription_id = subscription_id - self.token = None def set_token(self): + info = self.invoke_cli_token_command() + self.scheme, self.token = info['tokenType'], {'access_token': info['accessToken']} + + def invoke_cli_token_command(self): args = [self.cli_path, 'account', 'get-access-token'] if self.subscription_id: args.extend(['--subscription', self.subscription_id]) - p = Popen(args, stdout=PIPE, stderr=PIPE) - stdout, stderr = p.communicate() - p.wait() + process = Popen(args, stdout=PIPE, stderr=PIPE) + stdout, stderr = process.communicate() + process.wait() if stderr: raise ValueError('Retrieving acccess token failed: ' + stderr) - info = json.loads(stdout) - self.scheme, self.token = info['tokenType'], {'access_token': info['accessToken']} + return json.loads(stdout) def signed_session(self, session=None): - # Token cache is handled by the VM extension, call each time to avoid expiration self.set_token() return super(CLICredentials, self).signed_session(session) + +def _get_subscription_id(creds): + subscription_id = None + try: + from azure.mgmt.resource.subscriptions import SubscriptionClient + subscriptions = list(SubscriptionClient(creds).subscriptions.list()) + if subscriptions: + subscription_ids = [s.id.split('/')[-1] for s in subscriptions] + subscription_id = subscription_ids[0] + _LOGGER.warning('Found subscription "%s" to use', subscription_ids[0]) + if len(subscription_ids) > 1: + _LOGGER.warning('You also have accesses to a few other subscriptions "%S".' + ' You can supply subscription_id on creating the probe object') + except ImportError: # should be rare + _LOGGER.warning('Failed to load azure.mgmt.resource.subscriptions to find the default' + ' subscription. If this is expected, supply subscription_id on creating' + ' the probe object') + return subscription_id + + +def get_client_through_local_creds_probing(client_class, **kwargs): + ''' + Probing logics: + 1. AZURE_CONN_STR, with SDK auth code file content in json. + https://github.com/Azure/azure-sdk-for-java/wiki/Authentication + 2. Individual environment variables to estabslish a service principal's creds + https://github.com/Azure/azure-sdk-for-go + 3. Managed service identity + a. app service + b. virtual machine + 4. Azure CLI, through "az account get-access-token" + ''' + resource = kwargs.get('resource') + if not resource: #TODO figure out the right way + from .azure_cloud import AZURE_PUBLIC_CLOUD + resource = AZURE_PUBLIC_CLOUD.endpoints.resource_manager + probers = [ConnectionStrEnvProber(resource), ServicePrincipalEnvProber(resource), + ManagedServiceIdentityProber(resource), AzureCLIProber(resource)] + for prober in probers: + creds, subscription_id = prober.probe(subscription_id=kwargs.get('subscription_id')) + if creds: + return client_class(creds, subscription_id) + raise ValueError('No credential was detected from the local machine') diff --git a/tests/test_local_creds_prober.py b/tests/test_local_creds_prober.py index 9c34e81..5e7cefd 100644 --- a/tests/test_local_creds_prober.py +++ b/tests/test_local_creds_prober.py @@ -26,9 +26,8 @@ from azure.mgmt.storage import StorageManagementClient -from msrestazure.azure_local_creds_prober import AzureLocalCredentialProber +from msrestazure.azure_local_creds_prober import get_client_through_local_creds_probing -prober = AzureLocalCredentialProber() -client = StorageManagementClient(prober, prober.subscription_id) +client = get_client_through_local_creds_probing(StorageManagementClient) accounts = list(client.storage_accounts.list()) print('Found {} accounts'.format(len(accounts)))