diff --git a/tests/integration-tests/clusters_factory.py b/tests/integration-tests/clusters_factory.py index 0b0c8832fc..3aec85cb97 100644 --- a/tests/integration-tests/clusters_factory.py +++ b/tests/integration-tests/clusters_factory.py @@ -430,7 +430,7 @@ def create_cluster(self, cluster, log_error=True, raise_on_error=True, **kwargs) timeout=7200, raise_on_error=raise_on_error, log_error=log_error, - custom_cli_credentials=kwargs.get("custom_cli_credentials"), + custom_cli_credentials=cluster.custom_cli_credentials, ) logging.info("create-cluster response: %s", result.stdout) response = json.loads(result.stdout) @@ -481,11 +481,10 @@ def _build_command(cluster, kwargs): kwargs["suppress_validators"] = validators_list for k, val in kwargs.items(): - if k != "custom_cli_credentials": - if isinstance(val, (list, tuple)): - command.extend([f"--{kebab_case(k)}"] + list(map(str, val))) - else: - command.extend([f"--{kebab_case(k)}", str(val)]) + if isinstance(val, (list, tuple)): + command.extend([f"--{kebab_case(k)}"] + list(map(str, val))) + else: + command.extend([f"--{kebab_case(k)}", str(val)]) return command, wait diff --git a/tests/integration-tests/conftest.py b/tests/integration-tests/conftest.py index 1b5d4838c2..afb43ce44b 100644 --- a/tests/integration-tests/conftest.py +++ b/tests/integration-tests/conftest.py @@ -373,7 +373,7 @@ def clusters_factory(request, region): """ factory = ClustersFactory(delete_logs_on_success=request.config.getoption("delete_logs_on_success")) - def _cluster_factory(cluster_config, upper_case_cluster_name=False, **kwargs): + def _cluster_factory(cluster_config, upper_case_cluster_name=False, custom_cli_credentials=None, **kwargs): cluster_config = _write_config_to_outdir(request, cluster_config, "clusters_configs") cluster = Cluster( name=request.config.getoption("cluster") @@ -386,7 +386,7 @@ def _cluster_factory(cluster_config, upper_case_cluster_name=False, **kwargs): config_file=cluster_config, ssh_key=request.config.getoption("key_path"), region=region, - custom_cli_credentials=kwargs.get("custom_cli_credentials"), + custom_cli_credentials=custom_cli_credentials, ) if not request.config.getoption("cluster"): cluster.creation_response = factory.create_cluster(cluster, **kwargs) diff --git a/tests/integration-tests/framework/credential_providers.py b/tests/integration-tests/framework/credential_providers.py index d450b08f55..ee24bf8269 100644 --- a/tests/integration-tests/framework/credential_providers.py +++ b/tests/integration-tests/framework/credential_providers.py @@ -23,7 +23,7 @@ def register_cli_credentials_for_region(region, iam_role): cli_credentials[region] = iam_role -def run_pcluster_command(*args, **kwargs): +def run_pcluster_command(*args, custom_cli_credentials=None, **kwargs): """Run a command after assuming the role configured through register_cli_credentials_for_region.""" region = kwargs.get("region") @@ -31,10 +31,7 @@ def run_pcluster_command(*args, **kwargs): region = os.environ["AWS_DEFAULT_REGION"] if region in cli_credentials: - with sts_credential_provider( - region, credential_arn=kwargs.get("custom_cli_credentials") or cli_credentials.get(region) - ): - kwargs.pop("custom_cli_credentials", None) + with sts_credential_provider(region, credential_arn=custom_cli_credentials or cli_credentials.get(region)): return run_command(*args, **kwargs) else: return run_command(*args, **kwargs)