diff --git a/util/upload-instance-slot-map.py b/util/upload-instance-slot-map.py
index 8676f04636..bd5ebc1df1 100644
--- a/util/upload-instance-slot-map.py
+++ b/util/upload-instance-slot-map.py
@@ -43,26 +43,60 @@ def get_all_aws_regions(region):
     return sorted(r.get("RegionName") for r in ec2.describe_regions().get("Regions"))
 
 
-def upload(regions):
-    for region in regions:
-        bucket_name = region + "-aws-parallelcluster"
-        print(bucket_name)
-        try:
+def push_to_s3(region, aws_credentials=None):
+    bucket_name = region + "-aws-parallelcluster"
+    print(bucket_name)
+    try:
+        if aws_credentials:
+            s3 = boto3.resource(
+                "s3",
+                region_name=region,
+                aws_access_key_id=aws_credentials["AccessKeyId"],
+                aws_secret_access_key=aws_credentials["SecretAccessKey"],
+                aws_session_token=aws_credentials["SessionToken"],
+            )
+        else:
             s3 = boto3.resource("s3", region_name=region)
-            s3.meta.client.head_bucket(Bucket=bucket_name)
-        except ClientError as e:
-            # If a client error is thrown, then check that it was a 404 error.
-            # If it was a 404 error, then the bucket does not exist.
-            error_code = int(e.response["Error"]["Code"])
-            if error_code == 404:
-                print("Bucket %s does not exist", bucket_name)
-                continue
-            raise
-
-        bucket = s3.Bucket(bucket_name)
-        bucket.upload_file("instances.json", "instances/instances.json")
-        object_acl = s3.ObjectAcl(bucket_name, "instances/instances.json")
-        object_acl.put(ACL="public-read")
+        s3.meta.client.head_bucket(Bucket=bucket_name)
+    except ClientError as e:
+        # If a client error is thrown, then check that it was a 404 error.
+        # If it was a 404 error, then the bucket does not exist.
+        error_code = int(e.response["Error"]["Code"])
+        if error_code == 404:
+            print("Bucket %s does not exist" % bucket_name)
+            return
+        raise
+
+    bucket = s3.Bucket(bucket_name)
+    bucket.upload_file("instances.json", "instances/instances.json")
+    object_acl = s3.ObjectAcl(bucket_name, "instances/instances.json")
+    object_acl.put(ACL="public-read")
+
+
+def upload(regions, main_region, credentials):
+    for region in regions:
+        push_to_s3(region)
+
+        if main_region == region:
+            for credential in credentials:
+                credential_region = credential[0]
+                credential_endpoint = credential[1]
+                credential_arn = credential[2]
+                credential_external_id = credential[3]
+
+                try:
+                    sts = boto3.client("sts", region_name=main_region, endpoint_url=credential_endpoint)
+                    assumed_role_object = sts.assume_role(
+                        RoleArn=credential_arn,
+                        ExternalId=credential_external_id,
+                        RoleSessionName=credential_region + "upload_instance_slot_map_sts_session",
+                    )
+                    aws_credentials = assumed_role_object["Credentials"]
+
+                    push_to_s3(credential_region, aws_credentials)
+                except ClientError:
+                    print("Warning: not authorized in region '{0}', skipping".format(credential_region))
+                    pass
 
 
 if __name__ == "__main__":
@@ -77,20 +111,35 @@ def upload(regions):
         required=False,
         default="instance-details.json",
     )
+    parser.add_argument(
+        "--credential",
+        type=str,
+        action="append",
+        help="STS credential endpoint, in the format <region>,<endpoint>,<arn>,<external_id>. Could be specified multiple times",
+        required=False,
+    )
     args = parser.parse_args()
 
     if args.partition == "commercial":
-        region = "us-east-1"
+        main_region = "us-east-1"
     elif args.partition == "govcloud":
-        region = "us-gov-west-1"
+        main_region = "us-gov-west-1"
     elif args.partition == "china":
-        region = "cn-north-1"
+        main_region = "cn-north-1"
     else:
         print("Unsupported partition %s" % args.partition)
         sys.exit(1)
 
+    credentials = []
+    if args.credential:
+        credentials = [
+            tuple(credential_tuple.strip().split(","))
+            for credential_tuple in args.credential
+            if credential_tuple.strip()
+        ]
+
    dump_instances(args.instance_details)
 
-    regions = get_all_aws_regions(region)
+    regions = get_all_aws_regions(main_region)
 
-    upload(regions)
+    upload(regions, main_region, credentials)
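
For reviewers, here is a minimal standalone sketch of the pattern this change introduces: assume a role via an explicit STS endpoint, then build the S3 resource from the temporary credentials, as push_to_s3() does when aws_credentials is passed. The endpoint, account ID, role ARN, and external ID below are hypothetical placeholders, not values from this patch.

import boto3

# Assume a role through a specific regional STS endpoint
# (mirrors the sts.assume_role call added in upload()).
sts = boto3.client(
    "sts",
    region_name="us-east-1",
    endpoint_url="https://sts.us-east-1.amazonaws.com",  # hypothetical endpoint
)
assumed = sts.assume_role(
    RoleArn="arn:aws:iam::123456789012:role/UploadRole",  # hypothetical ARN
    ExternalId="example-external-id",                     # hypothetical external ID
    RoleSessionName="upload_instance_slot_map_sts_session",
)
creds = assumed["Credentials"]

# Build an S3 resource from the temporary credentials and upload the
# instance slot map, following the same steps as push_to_s3().
s3 = boto3.resource(
    "s3",
    region_name="us-east-1",
    aws_access_key_id=creds["AccessKeyId"],
    aws_secret_access_key=creds["SecretAccessKey"],
    aws_session_token=creds["SessionToken"],
)
s3.Bucket("us-east-1-aws-parallelcluster").upload_file(
    "instances.json", "instances/instances.json"
)

With the new --credential option, the script would be invoked along these lines (all credential values hypothetical):

python util/upload-instance-slot-map.py --partition commercial \
    --credential us-gov-west-1,https://sts.us-gov-west-1.amazonaws.com,arn:aws-us-gov:iam::123456789012:role/UploadRole,example-external-id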