diff --git a/src/k8s-extension/HISTORY.rst b/src/k8s-extension/HISTORY.rst index 0acd01b180c..bb9b20019ff 100644 --- a/src/k8s-extension/HISTORY.rst +++ b/src/k8s-extension/HISTORY.rst @@ -3,6 +3,10 @@ Release History =============== +1.5.3 +++++++++++++++++++ +* Add WorkloadIAM extension support and tests. + 1.5.2 ++++++++++++++++++ * Update help text on configuration-settings and configuration-protected-settings properties. diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index c88461bb48e..42b44227764 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -30,6 +30,7 @@ from .partner_extensions.AzureMLKubernetes import AzureMLKubernetes from .partner_extensions.DataProtectionKubernetes import DataProtectionKubernetes from .partner_extensions.Dapr import Dapr +from .partner_extensions.WorkloadIAM import WorkloadIAM from .partner_extensions.DefaultExtension import ( DefaultExtension, user_confirmation_factory, @@ -51,6 +52,7 @@ def ExtensionFactory(extension_name): "microsoft.azureml.kubernetes": AzureMLKubernetes, "microsoft.dapr": Dapr, "microsoft.dataprotection.kubernetes": DataProtectionKubernetes, + "microsoft.workloadiam": WorkloadIAM, } # Return the extension if we find it in the map, else return the default diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py new file mode 100644 index 00000000000..1a2f4e0e7cd --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py @@ -0,0 +1,152 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import subprocess + +from knack.log import get_logger +from knack.util import CLIError + +from azure.cli.core.azclierror import InvalidArgumentValueError + +from ..vendored_sdks.models import (Extension, Scope, ScopeCluster) + +from .DefaultExtension import DefaultExtension + +logger = get_logger(__name__) + +# The user settings are case-insensitive +CONFIG_SETTINGS_USER_TRUST_DOMAIN = 'trustdomain' +CONFIG_SETTINGS_USER_LOCAL_AUTHORITY = 'localauthority' +CONFIG_SETTINGS_USER_TENANT_ID = 'tenantid' +CONFIG_SETTINGS_USER_JOIN_TOKEN = 'jointoken' + +CONFIG_SETTINGS_HELM_TRUST_DOMAIN = 'global.workload-iam.trustDomain' +CONFIG_SETTINGS_HELM_TENANT_ID = 'global.workload-iam.tenantID' +CONFIG_SETTINGS_HELM_JOIN_TOKEN = 'workload-iam-local-authority.localAuthorityArgs.joinToken' + + +class WorkloadIAM(DefaultExtension): + + def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_type, cluster_rp, + extension_type, scope, auto_upgrade_minor_version, release_train, version, target_namespace, + release_namespace, configuration_settings, configuration_protected_settings, + configuration_settings_file, configuration_protected_settings_file, + plan_name, plan_publisher, plan_product): + """ + Create method for ExtensionType 'microsoft.workloadiam'. + """ + + # Ensure that the values provided by the user for generic values of Arc extensions are + # valid, set sensible default values if not. + if release_train is None: + # TODO - Set this to 'stable' when the extension is ready + release_train = 'preview' + + scope = scope.lower() + if scope is None: + scope = 'cluster' + elif scope != 'cluster': + raise InvalidArgumentValueError( + f"Invalid scope '{scope}'. This extension can only be installed at 'cluster' scope.") + + # Scope is always cluster + scope_cluster = ScopeCluster(release_namespace=release_namespace) + ext_scope = Scope(cluster=scope_cluster, namespace=None) + + # Create new dictionary where the keys of the user settings are all lowercase (but leave the + # others alone in case they are specific settings that have to be passed to the Helm chart). + validated_settings = dict() + all_user_settings = [CONFIG_SETTINGS_USER_TRUST_DOMAIN, CONFIG_SETTINGS_USER_TENANT_ID, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY, CONFIG_SETTINGS_USER_JOIN_TOKEN] + for key, value in configuration_settings.items(): + if key.lower() in all_user_settings: + validated_settings[key.lower()] = value + else: + validated_settings[key] = value + config_settings = validated_settings + + # Get user configuration values and remove them from the dictionary so that they aren't + # passed to the Helm chart + trust_domain = config_settings.pop(CONFIG_SETTINGS_USER_TRUST_DOMAIN, None) + tenant_id = config_settings.pop(CONFIG_SETTINGS_USER_TENANT_ID, None) + local_authority = config_settings.pop(CONFIG_SETTINGS_USER_LOCAL_AUTHORITY, None) + join_token = config_settings.pop(CONFIG_SETTINGS_USER_JOIN_TOKEN, None) + + # A trust domain name is always required + if trust_domain is None: + raise InvalidArgumentValueError( + "Invalid configuration settings. Please provide a trust domain name.") + + if tenant_id is None: + raise InvalidArgumentValueError( + "Invalid configuration settings. Please provide a tenant ID.") + + # If the user hasn't provided a join token, create one + if join_token is None: + if local_authority is None: + raise InvalidArgumentValueError( + "Invalid configuration settings. Either a join token or a local authority name " + "must be provided.") + join_token = self.get_join_token(trust_domain, local_authority) + else: + logger.info("Join token is provided") + + # Save configuration setting values to overwrite values in the Helm chart + configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] = trust_domain + configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] = tenant_id + configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] = join_token + + logger.debug("Configuration settings value for Helm: %s" % str(configuration_settings)) + + create_identity = True + extension = Extension( + extension_type=extension_type, + auto_upgrade_minor_version=auto_upgrade_minor_version, + release_train=release_train, + version=version, + scope=ext_scope, + configuration_settings=configuration_settings, + configuration_protected_settings=configuration_protected_settings + ) + return extension, name, create_identity + + def get_join_token(self, trust_domain, local_authority): + """ + Invoke the az command to obtain a join token. + """ + + logger.info("Getting a join token from the control plane") + + # Invoke az workload-iam command to obtain the join token + cmd = [ + "az", "workload-iam", "local-authority", "attestation-method", "create", + "--td", trust_domain, + "--la", local_authority, + "--type", "joinTokenAttestationMethod", + "--query", "singleUseToken", + "--dn", "myJoinToken", + ] + cmd_str = " ".join(cmd) + + try: + # Note: We can't use get_default_cli() here because its invoke() method + # always prints the console output, which we want to avoid. + result = subprocess.run(cmd, capture_output=True, shell=True) + except Exception as e: + logger.error(f"Error while generating a join token: {cmd_str}") + raise e + + if result.returncode != 0: + raise CLIError(f"Failed to generate a join token (exit code {result.returncode}): {cmd_str}") + + try: + # Strip double quotes from the output + command_output = result.stdout.decode("utf-8") + token = command_output.strip("\r\n").strip("\"") + except Exception as e: + logger.error(f"Failed to parse output of join token command: {cmd_str}") + raise e + + return token diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py b/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py new file mode 100644 index 00000000000..0cb837d7f32 --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py @@ -0,0 +1,319 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access + +import unittest + +from azure.cli.core.azclierror import InvalidArgumentValueError +from azext_k8s_extension.partner_extensions.WorkloadIAM import ( + WorkloadIAM, + CONFIG_SETTINGS_USER_TRUST_DOMAIN, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY, + CONFIG_SETTINGS_USER_TENANT_ID, + CONFIG_SETTINGS_USER_JOIN_TOKEN, + CONFIG_SETTINGS_HELM_TRUST_DOMAIN, + CONFIG_SETTINGS_HELM_TENANT_ID, + CONFIG_SETTINGS_HELM_JOIN_TOKEN +) + +from knack.util import CLIError + +from unittest.mock import patch + +class TestWorkloadIAM(unittest.TestCase): + + def test_workload_iam_create_without_join_token_success(self): + """ + Test that, when the user doesn't provide a join token, the Create() method calls + get_join_token() and creates a new one, and that the final configuration settings + are the expected ones. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value=mock_join_token): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + + def test_workload_iam_create_with_join_token_and_local_authority_success(self): + """ + Test that, when the user provides a join token, the Create() method doesn't call + get_join_token(), and that the final configuration settings are the expected ones. The + provided local authority is only required to generate a new join token. As no token is + created, the local authority will just be ignored. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + CONFIG_SETTINGS_USER_JOIN_TOKEN: mock_join_token, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value='BAD_JOIN_TOKEN'): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + + def test_workload_iam_create_with_join_token_and_no_local_authority_success(self): + """ + Test that, when the user provides a join token, the Create() method doesn't call + get_join_token(), and that the final configuration settings are the expected ones. The + provided local authority is only required to generate a new join token, so the test should + pass even without it. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_JOIN_TOKEN: mock_join_token, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value='BAD_JOIN_TOKEN'): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + def test_workload_iam_create_with_trust_domain_local_authority_no_tenant_id(self): + """ + Test that, when the user doesn't provide a tenant ID, there is an error. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid configuration settings. Please provide a tenant ID.") + + def test_workload_iam_create_with_wrong_scope_fails(self): + """ + Test that when the user provides a scope that isn't "cluster" the method Create() fails. + """ + + bad_scope = 'namespace' + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope=bad_scope, auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=None, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid scope '{bad_scope}'. This extension can only be installed at 'cluster' scope.") + + + def test_workload_iam_create_with_not_enough_settings_fails(self): + """ + Test that when the user doesn't provide the trust domain or local authority the method + Create() fails. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + + # Missing local authority + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + str_settings = str(settings) + self.assertEqual(str(context.exception), + f"Invalid configuration settings. Either a join token or a local authority name " + "must be provided.") + + # Missing trust domain + + settings = { + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + str_settings = str(settings) + self.assertEqual(str(context.exception), + f"Invalid configuration settings. Please provide a trust domain name.") + + def test_workload_iam_get_join_token_with_valid_argument_success(self): + """ + Test that when get_join_token() succeedes it returns a token in the right format (between + double quotes) and that the arguments passed to "az workload-iam" are the expected ones. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_join_token = 'any_join_token' + + class MockResult(): + def __init__(self): + self.returncode = 0 + self.stdout = ('\"' + mock_join_token + '\"').encode('utf-8') + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run', + return_value=MockResult()): + # Test & assert + workload_iam = WorkloadIAM() + join_token = workload_iam.get_join_token(mock_trust_domain_name, mock_local_authority_name) + self.assertEqual(join_token, mock_join_token) + + + def test_workload_iam_get_join_token_with_bad_exit_code(self): + """ + Test that get_join_token() fails with the right error message if "az workload-iam" returns a + non-zero error code (and if no exception is raised). + """ + + # Set up mocks + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_join_token = 'any_join_token' + mock_exit_code = 1 + + cmd = [ + "az", "workload-iam", "local-authority", "attestation-method", "create", + "--td", mock_trust_domain_name, + "--la", mock_local_authority_name, + "--type", "joinTokenAttestationMethod", + "--query", "singleUseToken", + "--dn", "myJoinToken", + ] + + class MockResult(): + def __init__(self): + self.returncode = mock_exit_code + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run', + return_value=MockResult()): + # Test & assert + workload_iam = WorkloadIAM() + cmd_str = " ".join(cmd) + self.assertRaisesRegex(CLIError, + f"Failed to generate a join token \(exit code {mock_exit_code}\): {cmd_str}", + workload_iam.get_join_token, mock_trust_domain_name, mock_local_authority_name) diff --git a/src/k8s-extension/setup.py b/src/k8s-extension/setup.py index 94d5b7c14cb..fa846158d6e 100644 --- a/src/k8s-extension/setup.py +++ b/src/k8s-extension/setup.py @@ -33,7 +33,7 @@ # TODO: Add any additional SDK dependencies here DEPENDENCIES = [] -VERSION = "1.5.2" +VERSION = "1.5.3" with open("README.rst", "r", encoding="utf-8") as f: README = f.read()