Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
address review feedback on probe type, sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
yugangw-msft committed Oct 25, 2018
1 parent 3c61142 commit 13bcc41
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 78 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -11,4 +11,5 @@ autorest
coverage.xml
Pipfile.lock
build
dist
dist
env
197 changes: 123 additions & 74 deletions msrestazure/azure_local_creds_prober.py
Expand Up @@ -36,107 +36,156 @@

_LOGGER = logging.getLogger(__name__)

class AzureLocalCredentialProber(object):
'''
Probing logics:
1. Managed service identity
a. app service
b. virtual machine
2. AZURE_CONN_STR, with SDK auth code file content in json.
https://github.com/Azure/azure-sdk-for-java/wiki/Authentication
3. Individual environment variables to estabslish a service principal's creds
https://github.com/Azure/azure-sdk-for-go
4. Azure CLI, through "az account get-access-token"
'''
#pylint: disable=too-few-public-methods,missing-docstring

def __init__(self, subscription_id=None):
'''
subscription_id: if missing, prober will find one based on detected creds.
'''
self.subscription_id = subscription_id
self.creds = None
self._probe()
class CredsProber:

def __init__(self, resource):
self.enabled = True
self.resource = resource

def signed_session(self, session=None):
self.creds.signed_session(session)

def _probe(self):
subscription_id = self.subscription_id or os.environ.get('AZURE_SUBSCRIPTION_ID')
class ManagedServiceIdentityProber(CredsProber):

def probe(self, subscription_id=None):
if not self.enabled:
return None
try:
creds = MSIAuthentication()
_LOGGER.warning('Managed system identity was detected')
return creds, subscription_id or _get_subscription_id(creds)
except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError):
if os.environ.get('AZURE_CONN_STR'): # this should be java auth code file content
auth_info = json.loads(os.environ.get('AZURE_CONN_STR'))
subscription_id = auth_info.get('subscriptionId')
creds = ServicePrincipalCredentials(client_id=auth_info['clientId'],
secret=auth_info['clientSecret'],
tennt_id=auth_info['tenantId'])
elif os.environ.get('AZURE_CLIENT_ID'):
# TODO error out if other env vars are not set
creds = ServicePrincipalCredentials(client_id=os.environ.get('AZURE_CLIENT_ID'),
secret=os.environ.get('AZURE_CLIENT_SECRET'),
tennt_id=os.environ.get('AZURE_TENANT_ID'))
_LOGGER.warning('Service principal credentials was detected')
else:
try:
creds = CLICredentials()
_LOGGER.warning('Azure CLI credentials was detected')
except NotImplementedError:
raise ValueError('No credential was detected from the local machine')
self.creds = creds
if not subscription_id:
try:
from azure.mgmt.resource.subscriptions import SubscriptionClient
subscriptions = list(SubscriptionClient(creds).subscriptions.list())
if subscriptions:
subscription_ids = [s.id.split('/')[-1] for s in subscriptions]
subscription_id = subscription_ids[0]
_LOGGER.warning('Found subscription "%s" to use', subscription_ids[0])
if len(subscription_ids) > 1:
_LOGGER.warning('You also have accesses to a few other subscriptions "%S".'
' You can supply subscription_id on creating the probe object')
except ImportError: # should be rare
_LOGGER.warning('Failed to load azure.mgmt.resource.subscriptions to find the default subscription.'
' If this is expected, supply subscription_id on creating the probe object')
self.subscription_id = subscription_id
return None, None


class CLICredentials(BasicTokenAuthentication):
class ConnectionStrEnvProber(CredsProber):
'''
Detect environment variable AZURE_CONN_STR
'''
def probe(self, subscription_id=None):
creds = None
if self.enabled and os.environ.get('AZURE_CONN_STR'):
auth_info = json.loads(os.environ.get('AZURE_CONN_STR'))
creds = ServicePrincipalCredentials(client_id=auth_info['clientId'],
secret=auth_info['clientSecret'],
tennt_id=auth_info['tenantId'])
return creds, (subscription_id or auth_info.get('subscriptionId') or
_get_subscription_id(creds))
return None, None

def __init__(self, resource=None, subscription_id=None): # allow subscriptions
super(CLICredentials, self).__init__(None)

class ServicePrincipalEnvProber(CredsProber):
'''
Detect envrionment variable AZURE_CLIENT_ID, AZURE_CLIENT_SECRET and AZURE_TENANT_ID
'''
def probe(self, subscription_id=None):
creds = None
if os.environ.get('AZURE_CLIENT_ID'):
client_id, client_secret, tenant_id = (os.environ.get('AZURE_CLIENT_ID'),
os.environ.get('AZURE_CLIENT_SECRET'),
os.environ.get('AZURE_TENANT_ID'))
if not client_secret or not tenant_id:
raise ValueError('Environment variables of AZURE_CLIENT_SECRET and'
' AZURE_TENANT_ID must be set')
creds = ServicePrincipalCredentials(client_id=client_id, secret=client_secret,
tenant=tenant_id, resource=self.resource)

_LOGGER.warning('Service principal credentials was detected')
return creds, (subscription_id or os.environ.get('AZURE_SUBSCRIPTION_ID') or
_get_subscription_id(creds))
return None, None


class AzureCLIProber(CredsProber):
'''
Detect CLI installations
'''
def probe(self, subscription_id=None): # pylint: disable=no-self-use
uname = platform.uname()
# python 2, `platform.uname()` returns: tuple(system, node, release, version, machine, processor)
platform_name = getattr(uname, 'system', None) or uname[0]
platform_name = platform_name.lower()
if platform_name == 'windows':
program_files_folder = os.environ.get('ProgramFiles(x86)') or os.environ.get('ProgramFiles')
probing_paths = [os.path.join(program_files_folder, 'Microsoft SDKs', 'Azure', 'CLI2', 'wbin', 'az.cmd')]
program_files_folder = (os.environ.get('ProgramFiles(x86)') or
os.environ.get('ProgramFiles'))
probing_paths = [os.path.join(program_files_folder, 'Microsoft SDKs',
'Azure', 'CLI2', 'wbin', 'az.cmd')]
else:
probing_paths = ['/usr/bin/az', '/usr/local/bin/az']

cli_path = next((p for p in probing_paths if os.path.isfile(p)), None)
if cli_path is None:
raise NotImplementedError('Azure CLI is not installed')
if cli_path:
creds = CLICredentials(cli_path)
if subscription_id is None:
subscription_id = creds.invoke_cli_token_command()['subscription']
return creds, subscription_id
return None, None


class CLICredentials(BasicTokenAuthentication):

def __init__(self, cli_path, subscription_id=None): # allow subscriptions
super(CLICredentials, self).__init__(None)
self.cli_path = cli_path
self.resource = resource
self.subscription_id = subscription_id
self.token = None

def set_token(self):
info = self.invoke_cli_token_command()
self.scheme, self.token = info['tokenType'], {'access_token': info['accessToken']}

def invoke_cli_token_command(self):
args = [self.cli_path, 'account', 'get-access-token']
if self.subscription_id:
args.extend(['--subscription', self.subscription_id])
p = Popen(args, stdout=PIPE, stderr=PIPE)
stdout, stderr = p.communicate()
p.wait()
process = Popen(args, stdout=PIPE, stderr=PIPE)
stdout, stderr = process.communicate()
process.wait()
if stderr:
raise ValueError('Retrieving acccess token failed: ' + stderr)
info = json.loads(stdout)
self.scheme, self.token = info['tokenType'], {'access_token': info['accessToken']}
return json.loads(stdout)

def signed_session(self, session=None):
# Token cache is handled by the VM extension, call each time to avoid expiration
self.set_token()
return super(CLICredentials, self).signed_session(session)

def _get_subscription_id(creds):
subscription_id = None
try:
from azure.mgmt.resource.subscriptions import SubscriptionClient
subscriptions = list(SubscriptionClient(creds).subscriptions.list())
if subscriptions:
subscription_ids = [s.id.split('/')[-1] for s in subscriptions]
subscription_id = subscription_ids[0]
_LOGGER.warning('Found subscription "%s" to use', subscription_ids[0])
if len(subscription_ids) > 1:
_LOGGER.warning('You also have accesses to a few other subscriptions "%S".'
' You can supply subscription_id on creating the probe object')
except ImportError: # should be rare
_LOGGER.warning('Failed to load azure.mgmt.resource.subscriptions to find the default'
' subscription. If this is expected, supply subscription_id on creating'
' the probe object')
return subscription_id


def get_client_through_local_creds_probing(client_class, **kwargs):
'''
Probing logics:
1. AZURE_CONN_STR, with SDK auth code file content in json.
https://github.com/Azure/azure-sdk-for-java/wiki/Authentication
2. Individual environment variables to estabslish a service principal's creds
https://github.com/Azure/azure-sdk-for-go
3. Managed service identity
a. app service
b. virtual machine
4. Azure CLI, through "az account get-access-token"
'''
resource = kwargs.get('resource')
if not resource: #TODO figure out the right way
from .azure_cloud import AZURE_PUBLIC_CLOUD
resource = AZURE_PUBLIC_CLOUD.endpoints.resource_manager
probers = [ConnectionStrEnvProber(resource), ServicePrincipalEnvProber(resource),
ManagedServiceIdentityProber(resource), AzureCLIProber(resource)]
for prober in probers:
creds, subscription_id = prober.probe(subscription_id=kwargs.get('subscription_id'))
if creds:
return client_class(creds, subscription_id)
raise ValueError('No credential was detected from the local machine')
5 changes: 2 additions & 3 deletions tests/test_local_creds_prober.py
Expand Up @@ -26,9 +26,8 @@


from azure.mgmt.storage import StorageManagementClient
from msrestazure.azure_local_creds_prober import AzureLocalCredentialProber
from msrestazure.azure_local_creds_prober import get_client_through_local_creds_probing

prober = AzureLocalCredentialProber()
client = StorageManagementClient(prober, prober.subscription_id)
client = get_client_through_local_creds_probing(StorageManagementClient)
accounts = list(client.storage_accounts.list())
print('Found {} accounts'.format(len(accounts)))

0 comments on commit 13bcc41

Please sign in to comment.