Add new parameter --credential
The tool can now push the instance whitelist using additional credentials to connect to a region.
Each additional credential is a comma-separated list in the format region,endpoint,ARN,externalId.
The option can be specified multiple times, once per credential.

Signed-off-by: Luca Carrogu <carrogu@amazon.com>
lukeseawalker committed Jul 5, 2019
1 parent b7c5122 commit 4cf8b80
Showing 1 changed file with 84 additions and 27 deletions.
111 changes: 84 additions & 27 deletions util/instance-whitelist.py
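Before the diff, a minimal standalone sketch of how one --credential value is parsed and used, mirroring the logic introduced below; the region, endpoint, role ARN, and external ID are hypothetical placeholders, not values from the commit.

import boto3

# One --credential value, in the documented format: <region>,<endpoint>,<ARN>,<externalId>
credential = "eu-south-1,https://sts.us-east-1.amazonaws.com,arn:aws:iam::123456789012:role/WhitelistUpload,example-external-id"
region, endpoint, arn, external_id = credential.strip().split(",")

# Assume the cross-account role through the given STS endpoint, as main() does
sts = boto3.client("sts", region_name="us-east-1", endpoint_url=endpoint)
assumed_role = sts.assume_role(
    RoleArn=arn,
    ExternalId=external_id,
    RoleSessionName=region + "upload_instance_slot_map_sts_session",
)
aws_credentials = assumed_role["Credentials"]

# The temporary credentials are then handed to the regional service clients
batch_client = boto3.client(
    "batch",
    region_name=region,
    aws_access_key_id=aws_credentials.get("AccessKeyId"),
    aws_secret_access_key=aws_credentials.get("SecretAccessKey"),
    aws_session_token=aws_credentials.get("SessionToken"),
)

Running this requires caller credentials allowed to assume the target role; otherwise sts.assume_role raises a ClientError, which the tool catches and reports as a warning.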
@@ -27,26 +27,24 @@
 from botocore.exceptions import ClientError, EndpointConnectionError
 
 
-def get_all_aws_regions(partition):
-    if partition == "commercial":
-        region = "us-east-1"
-    elif partition == "govcloud":
-        region = "us-gov-west-1"
-    elif partition == "china":
-        region = "cn-north-1"
-    else:
-        print("Unsupported partition %s" % partition)
-        sys.exit(1)
-
+def get_all_aws_regions(region):
     ec2 = boto3.client("ec2", region_name=region)
     return set(sorted(r.get("RegionName") for r in ec2.describe_regions().get("Regions")))
 
 
-def get_batch_instance_whitelist(args, region):
-
+def get_batch_instance_whitelist(region, aws_credentials=None):
     instances = []
     # try to create a dummy compute environment
-    batch_client = boto3.client("batch", region_name=region)
+    if aws_credentials:
+        batch_client = boto3.client(
+            "batch",
+            region_name=region,
+            aws_access_key_id=aws_credentials.get("AccessKeyId"),
+            aws_secret_access_key=aws_credentials.get("SecretAccessKey"),
+            aws_session_token=aws_credentials.get("SessionToken"),
+        )
+    else:
+        batch_client = boto3.client("batch", region_name=region)
 
     try:
         response = batch_client.create_compute_environment(
@@ -77,9 +75,17 @@ def get_batch_instance_whitelist(args, region):
     return instances
 
 
-def upload_to_s3(args, region, instances, key):
-
-    s3_client = boto3.resource("s3", region_name=region)
+def upload_to_s3(args, region, instances, key, aws_credentials=None):
+    if aws_credentials:
+        s3_client = boto3.resource(
+            "s3",
+            region_name=region,
+            aws_access_key_id=aws_credentials.get("AccessKeyId"),
+            aws_secret_access_key=aws_credentials.get("SecretAccessKey"),
+            aws_session_token=aws_credentials.get("SessionToken"),
+        )
+    else:
+        s3_client = boto3.resource("s3", region_name=region)
 
     bucket = args.bucket if args.bucket else "%s-aws-parallelcluster" % region
 
@@ -101,16 +107,42 @@ def upload_to_s3(args, region, instances, key):
     return response
 
 
-def main(args):
+def main(main_region, args):
     # For all regions
     for region in args.regions:
-        batch_instances = get_batch_instance_whitelist(args, region)
-        if args.efa:
-            efa_instances = args.efa.split(",")
-            instances = {"Features": {"efa": {"instances": efa_instances}, "batch": {"instances": batch_instances}}}
-            upload_to_s3(args, region, instances, "features/feature_whitelist.json")
-        else:
-            upload_to_s3(args, region, batch_instances, "instances/batch_instances.json")
+        push_whitelist(args, region)
+
+        if main_region == region:
+            for credential in credentials:
+                credential_region = credential[0]
+                credential_endpoint = credential[1]
+                credential_arn = credential[2]
+                credential_external_id = credential[3]
+
+                try:
+                    sts = boto3.client("sts", region_name=main_region, endpoint_url=credential_endpoint)
+                    assumed_role_object = sts.assume_role(
+                        RoleArn=credential_arn,
+                        ExternalId=credential_external_id,
+                        RoleSessionName=credential_region + "upload_instance_slot_map_sts_session",
+                    )
+                    aws_credentials = assumed_role_object["Credentials"]
+
+                    push_whitelist(args, credential_region, aws_credentials)
+
+                except ClientError:
+                    print("Warning: not authorized in region '{0}', skipping".format(credential_region))
+                    pass
+
+
+def push_whitelist(args, region, aws_credentials=None):
+    batch_instances = get_batch_instance_whitelist(region, aws_credentials)
+    if args.efa:
+        efa_instances = args.efa.split(",")
+        instances = {"Features": {"efa": {"instances": efa_instances}, "batch": {"instances": batch_instances}}}
+        upload_to_s3(args, region, instances, "features/feature_whitelist.json", aws_credentials)
+    else:
+        upload_to_s3(args, region, batch_instances, "instances/batch_instances.json", aws_credentials)
 
 
 if __name__ == "__main__":
@@ -123,16 +155,41 @@ def main(args):
         help='Valid Regions, can include "all", or comma separated list of regions',
         required=True,
     )
+    parser.add_argument(
+        "--credential",
+        type=str,
+        action="append",
+        help="STS credential endpoint, in the format <region>,<endpoint>,<ARN>,<externalId>. Can be specified multiple times",
+        required=False,
+    )
     parser.add_argument(
         "--bucket", type=str, help="Bucket to upload to, defaults to [region]-aws-parallelcluster", required=False
     )
     parser.add_argument("--efa", type=str, help="Comma separated list of instances supported by EFA", required=False)
     parser.add_argument("--dryrun", type=str, help="Doesn't push anything to S3, just outputs", required=True)
     args = parser.parse_args()
 
+    if args.partition == "commercial":
+        main_region = "us-east-1"
+    elif args.partition == "govcloud":
+        main_region = "us-gov-west-1"
+    elif args.partition == "china":
+        main_region = "cn-north-1"
+    else:
+        print("Unsupported partition %s" % args.partition)
+        sys.exit(1)
+
+    credentials = []
+    if args.credential:
+        credentials = [
+            tuple(credential_tuple.strip().split(","))
+            for credential_tuple in args.credential
+            if credential_tuple.strip()
+        ]
+
     if args.regions == "all":
-        args.regions = get_all_aws_regions(args.partition)
+        args.regions = get_all_aws_regions(main_region)
     else:
         args.regions = args.regions.split(",")
 
-    main(args)
+    main(main_region, args)

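For reference, a hypothetical invocation with two additional credentials; all values are placeholders, and --partition (consumed via args.partition in the script) is assumed to be defined in a hunk not shown here:

python util/instance-whitelist.py \
    --partition commercial \
    --regions all \
    --efa c5n.18xlarge,i3en.24xlarge,p3dn.24xlarge \
    --dryrun true \
    --credential eu-south-1,https://sts.us-east-1.amazonaws.com,arn:aws:iam::123456789012:role/Upload,extid-1 \
    --credential af-south-1,https://sts.us-east-1.amazonaws.com,arn:aws:iam::123456789012:role/Upload,extid-2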