Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions tests/integration-tests/clusters_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def create_cluster(self, cluster, log_error=True, raise_on_error=True, **kwargs)
timeout=7200,
raise_on_error=raise_on_error,
log_error=log_error,
custom_cli_credentials=kwargs.get("custom_cli_credentials"),
custom_cli_credentials=cluster.custom_cli_credentials,
)
logging.info("create-cluster response: %s", result.stdout)
response = json.loads(result.stdout)
Expand Down Expand Up @@ -481,11 +481,10 @@ def _build_command(cluster, kwargs):
kwargs["suppress_validators"] = validators_list

for k, val in kwargs.items():
if k != "custom_cli_credentials":
if isinstance(val, (list, tuple)):
command.extend([f"--{kebab_case(k)}"] + list(map(str, val)))
else:
command.extend([f"--{kebab_case(k)}", str(val)])
if isinstance(val, (list, tuple)):
command.extend([f"--{kebab_case(k)}"] + list(map(str, val)))
else:
command.extend([f"--{kebab_case(k)}", str(val)])

return command, wait

Expand Down
4 changes: 2 additions & 2 deletions tests/integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def clusters_factory(request, region):
"""
factory = ClustersFactory(delete_logs_on_success=request.config.getoption("delete_logs_on_success"))

def _cluster_factory(cluster_config, upper_case_cluster_name=False, **kwargs):
def _cluster_factory(cluster_config, upper_case_cluster_name=False, custom_cli_credentials=None, **kwargs):
cluster_config = _write_config_to_outdir(request, cluster_config, "clusters_configs")
cluster = Cluster(
name=request.config.getoption("cluster")
Expand All @@ -386,7 +386,7 @@ def _cluster_factory(cluster_config, upper_case_cluster_name=False, **kwargs):
config_file=cluster_config,
ssh_key=request.config.getoption("key_path"),
region=region,
custom_cli_credentials=kwargs.get("custom_cli_credentials"),
custom_cli_credentials=custom_cli_credentials,
)
if not request.config.getoption("cluster"):
cluster.creation_response = factory.create_cluster(cluster, **kwargs)
Expand Down
7 changes: 2 additions & 5 deletions tests/integration-tests/framework/credential_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,15 @@ def register_cli_credentials_for_region(region, iam_role):
cli_credentials[region] = iam_role


def run_pcluster_command(*args, **kwargs):
def run_pcluster_command(*args, custom_cli_credentials=None, **kwargs):
"""Run a command after assuming the role configured through register_cli_credentials_for_region."""

region = kwargs.get("region")
if not region:
region = os.environ["AWS_DEFAULT_REGION"]

if region in cli_credentials:
with sts_credential_provider(
region, credential_arn=kwargs.get("custom_cli_credentials") or cli_credentials.get(region)
):
kwargs.pop("custom_cli_credentials", None)
with sts_credential_provider(region, credential_arn=custom_cli_credentials or cli_credentials.get(region)):
return run_command(*args, **kwargs)
else:
return run_command(*args, **kwargs)
Expand Down