diff --git a/providers/amazon/tests/system/amazon/aws/example_emr_eks.py b/providers/amazon/tests/system/amazon/aws/example_emr_eks.py
index f73335b5f1e94..82f0a38a564e9 100644
--- a/providers/amazon/tests/system/amazon/aws/example_emr_eks.py
+++ b/providers/amazon/tests/system/amazon/aws/example_emr_eks.py
@@ -17,11 +17,13 @@
 from __future__ import annotations
 
 import json
+import logging
 import subprocess
 import time
 from datetime import datetime
 
 import boto3
+from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_exponential
 
 from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates
 from airflow.providers.amazon.aws.operators.eks import EksCreateClusterOperator, EksDeleteClusterOperator
@@ -59,7 +61,6 @@
 JOB_ROLE_ARN_KEY = "JOB_ROLE_ARN"
 JOB_ROLE_NAME_KEY = "JOB_ROLE_NAME"
 SUBNETS_KEY = "SUBNETS"
-UPDATE_TRUST_POLICY_WAIT_TIME_KEY = "UPDATE_TRUST_POLICY_WAIT_TIME"
 
 sys_test_context_task = (
     SystemTestContextBuilder()
@@ -67,7 +68,6 @@
     .add_variable(JOB_ROLE_ARN_KEY)
     .add_variable(JOB_ROLE_NAME_KEY)
     .add_variable(SUBNETS_KEY, split_string=True)
-    .add_variable(UPDATE_TRUST_POLICY_WAIT_TIME_KEY, optional=True, default_value="10")
     .build()
 )
 
@@ -141,7 +141,7 @@ def delete_iam_oidc_identity_provider(cluster_name):
 
 
 @task
-def update_trust_policy_execution_role(cluster_name, cluster_namespace, role_name, wait_time):
+def update_trust_policy_execution_role(cluster_name, cluster_namespace, role_name):
     # Remove any already existing trusted entities added with "update-role-trust-policy"
     # Prevent getting an error "Cannot exceed quota for ACLSizePerRole"
     client = boto3.client("iam")
@@ -177,8 +177,77 @@ def update_trust_policy_execution_role(cluster_name, cluster_namespace, role_nam
     if build.returncode != 0:
         raise RuntimeError(err)
 
-    # Wait for IAM changes to propagate to avoid authentication failures
-    time.sleep(int(wait_time))
+
+@task
+def wait_for_trust_policy_propagation(cluster_name, role_name):
+    """
+    Wait until the role's trust policy references the cluster's OIDC provider.
+
+    Polls the role with exponential backoff (5-30 seconds between attempts, up to
+    5 minutes overall) instead of relying on a fixed sleep alone, then pauses a
+    further 15 seconds as a buffer for caches in other services.
+
+    :param cluster_name: Name of the EKS cluster whose OIDC issuer must appear in the policy.
+    :param role_name: Name of the IAM role whose trust policy is inspected.
+    :raises RuntimeError: If the OIDC provider is still missing once the retry window ends.
+    """
+    log = logging.getLogger(__name__)
+
+    # Determine the expected OIDC provider ARN from the EKS cluster
+    eks_client = boto3.client("eks")
+    oidc_issuer_url = eks_client.describe_cluster(name=cluster_name)["cluster"]["identity"]["oidc"]["issuer"]
+    oidc_issuer_endpoint = oidc_issuer_url.replace("https://", "")
+    account_id = boto3.client("sts").get_caller_identity()["Account"]
+    expected_oidc_provider_arn = f"arn:aws:iam::{account_id}:oidc-provider/{oidc_issuer_endpoint}"
+
+    # Create the IAM client once rather than on every retry attempt
+    iam_client = boto3.client("iam")
+
+    def _as_list(value):
+        # IAM policy elements (Statement, Action, Principal.Federated) may hold either
+        # a single value or a list of values; normalise both shapes to a list.
+        if value is None:
+            return []
+        return list(value) if isinstance(value, (list, tuple)) else [value]
+
+    def _references_oidc_provider(statement):
+        if "sts:AssumeRoleWithWebIdentity" not in _as_list(statement.get("Action")):
+            return False
+        principal = statement.get("Principal")
+        # Principal can be the bare string "*", which carries no Federated entry
+        if not isinstance(principal, dict):
+            return False
+        return any(oidc_issuer_endpoint in arn for arn in _as_list(principal.get("Federated")))
+
+    @retry(
+        retry=retry_if_exception_type(RuntimeError),
+        wait=wait_exponential(multiplier=1, min=5, max=30),
+        stop=stop_after_delay(300),
+        reraise=True,
+    )
+    def _validate_trust_policy():
+        # Verify the trust policy document contains the expected OIDC provider
+        role = iam_client.get_role(RoleName=role_name)["Role"]
+        statements = _as_list(role["AssumeRolePolicyDocument"].get("Statement"))
+
+        if not any(_references_oidc_provider(statement) for statement in statements):
+            log.info(
+                "Trust policy does not yet contain OIDC provider %s, retrying...",
+                expected_oidc_provider_arn,
+            )
+            raise RuntimeError(
+                f"Trust policy for role {role_name} does not yet contain "
+                f"the expected OIDC provider: {expected_oidc_provider_arn}"
+            )
+
+        log.info("Trust policy document confirmed for role %s", role_name)
+
+    _validate_trust_policy()
+
+    # Brief buffer after IAM confirms the trust policy document — cross-service
+    # caches (EKS/EMR) may still serve the old policy for a few seconds.
+    time.sleep(15)
+    log.info("Trust policy validation complete, proceeding.")
 
 
 @task(trigger_rule=TriggerRule.ALL_DONE)
@@ -200,7 +269,6 @@ def delete_virtual_cluster(virtual_cluster_id):
     subnets = test_context[SUBNETS_KEY]
     job_role_arn = test_context[JOB_ROLE_ARN_KEY]
     job_role_name = test_context[JOB_ROLE_NAME_KEY]
-    update_trust_policy_wait_time = test_context[UPDATE_TRUST_POLICY_WAIT_TIME_KEY]
 
     s3_bucket_name = f"{env_id}-bucket"
     eks_cluster_name = f"{env_id}-cluster"
@@ -328,9 +396,8 @@ def delete_virtual_cluster(virtual_cluster_id):
         create_cluster_and_nodegroup,
         await_create_nodegroup,
         run_eksctl_commands(eks_cluster_name, eks_namespace),
-        update_trust_policy_execution_role(
-            eks_cluster_name, eks_namespace, job_role_name, update_trust_policy_wait_time
-        ),
+        update_trust_policy_execution_role(eks_cluster_name, eks_namespace, job_role_name),
+        wait_for_trust_policy_propagation(eks_cluster_name, job_role_name),
         # TEST BODY
         create_emr_eks_cluster,
         job_starter,