diff --git a/util/instance-whitelist.py b/util/instance-whitelist.py index 094516bd52..d483078033 100755 --- a/util/instance-whitelist.py +++ b/util/instance-whitelist.py @@ -27,26 +27,24 @@ from botocore.exceptions import ClientError, EndpointConnectionError -def get_all_aws_regions(partition): - if partition == "commercial": - region = "us-east-1" - elif partition == "govcloud": - region = "us-gov-west-1" - elif partition == "china": - region = "cn-north-1" - else: - print("Unsupported partition %s" % partition) - sys.exit(1) - +def get_all_aws_regions(region): ec2 = boto3.client("ec2", region_name=region) return set(sorted(r.get("RegionName") for r in ec2.describe_regions().get("Regions"))) -def get_batch_instance_whitelist(args, region): - +def get_batch_instance_whitelist(region, aws_credentials=None): instances = [] # try to create a dummy compute environmment - batch_client = boto3.client("batch", region_name=region) + if aws_credentials: + batch_client = boto3.client( + "batch", + region_name=region, + aws_access_key_id=aws_credentials.get("AccessKeyId"), + aws_secret_access_key=aws_credentials.get("SecretAccessKey"), + aws_session_token=aws_credentials.get("SessionToken"), + ) + else: + batch_client = boto3.client("batch", region_name=region) try: response = batch_client.create_compute_environment( @@ -77,9 +75,17 @@ def get_batch_instance_whitelist(args, region): return instances -def upload_to_s3(args, region, instances, key): - - s3_client = boto3.resource("s3", region_name=region) +def upload_to_s3(args, region, instances, key, aws_credentials=None): + if aws_credentials: + s3_client = boto3.resource( + "s3", + region_name=region, + aws_access_key_id=aws_credentials.get("AccessKeyId"), + aws_secret_access_key=aws_credentials.get("SecretAccessKey"), + aws_session_token=aws_credentials.get("SessionToken"), + ) + else: + s3_client = boto3.resource("s3", region_name=region) bucket = args.bucket if args.bucket else "%s-aws-parallelcluster" % region @@ 
-101,16 +107,42 @@ def upload_to_s3(args, region, instances, key): return response -def main(args): +def main(main_region, args): # For all regions for region in args.regions: - batch_instances = get_batch_instance_whitelist(args, region) - if args.efa: - efa_instances = args.efa.split(",") - instances = {"Features": {"efa": {"instances": efa_instances}, "batch": {"instances": batch_instances}}} - upload_to_s3(args, region, instances, "features/feature_whitelist.json") - else: - upload_to_s3(args, region, batch_instances, "instances/batch_instances.json") + push_whitelist(args, region) + + if main_region == region: + for credential in credentials: + credential_region = credential[0] + credential_endpoint = credential[1] + credential_arn = credential[2] + credential_external_id = credential[3] + + try: + sts = boto3.client("sts", region_name=main_region, endpoint_url=credential_endpoint) + assumed_role_object = sts.assume_role( + RoleArn=credential_arn, + ExternalId=credential_external_id, + RoleSessionName=credential_region + "upload_instance_slot_map_sts_session", + ) + aws_credentials = assumed_role_object["Credentials"] + + push_whitelist(args, credential_region, aws_credentials) + + except ClientError: + print("Warning: non authorized in region '{0}', skipping".format(credential_region)) + pass + + +def push_whitelist(args, region, aws_credentials=None): + batch_instances = get_batch_instance_whitelist(region, aws_credentials) + if args.efa: + efa_instances = args.efa.split(",") + instances = {"Features": {"efa": {"instances": efa_instances}, "batch": {"instances": batch_instances}}} + upload_to_s3(args, region, instances, "features/feature_whitelist.json", aws_credentials) + else: + upload_to_s3(args, region, batch_instances, "instances/batch_instances.json", aws_credentials) if __name__ == "__main__": @@ -123,6 +155,13 @@ def main(args): help='Valid Regions, can include "all", or comma seperated list of regions', required=True, ) + parser.add_argument( + 
"--credential", + type=str, + action="append", + help="STS credential endpoint, in the format <region>,<endpoint>,<ARN>,<externalId>. Could be specified multiple times", + required=False, + ) parser.add_argument( "--bucket", type=str, help="Bucket to upload too, defaults to [region]-aws-parallelcluster", required=False ) @@ -130,9 +169,27 @@ def main(args): parser.add_argument("--dryrun", type=str, help="Doesn't push anything to S3, just outputs", required=True) args = parser.parse_args() + if args.partition == "commercial": + main_region = "us-east-1" + elif args.partition == "govcloud": + main_region = "us-gov-west-1" + elif args.partition == "china": + main_region = "cn-north-1" + else: + print("Unsupported partition %s" % args.partition) + sys.exit(1) + + credentials = [] + if args.credential: + credentials = [ + tuple(credential_tuple.strip().split(",")) + for credential_tuple in args.credential + if credential_tuple.strip() + ] + if args.regions == "all": - args.regions = get_all_aws_regions(args.partition) + args.regions = get_all_aws_regions(main_region) else: args.regions = args.regions.split(",") - main(args) + main(main_region, args)