Skip to content

Commit

Permalink
Merge pull request #246 from CloudWanderer-io/feature/improve-account…
Browse files Browse the repository at this point in the history
…-id-and-region-fetching

Allow the setting of account_id and enabled_regions in CloudWandererA…
  • Loading branch information
Sam-Martin committed Dec 13, 2021
2 parents 826f749 + 0e87b5c commit 3ac1e6c
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.26.0

- Allow the setting of `account_id` and `enabled_regions` in `CloudWandererAWSInterface` if you already know these values and want to avoid unnecessary API calls.
- Added the option of passing a `CloudWandererBoto3SessionGetterClientConfig` for configuring internal getter clients in `CloudWandererBoto3Session`.
# 0.25.2

- Fix #242 by moving to using `MANIFEST.in`
Expand Down
9 changes: 7 additions & 2 deletions cloudwanderer/aws_interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""The CloudWanderer AWS Interface."""
from .interface import CloudWandererAWSInterface
from .session import CloudWandererBoto3Session
from .models import AWSResourceTypeFilter
from .session import CloudWandererBoto3Session, CloudWandererBoto3SessionGetterClientConfig

__all__ = ["CloudWandererAWSInterface", "CloudWandererBoto3Session", "AWSResourceTypeFilter"]
__all__ = [
"CloudWandererAWSInterface",
"CloudWandererBoto3Session",
"AWSResourceTypeFilter",
"CloudWandererBoto3SessionGetterClientConfig",
]
22 changes: 16 additions & 6 deletions cloudwanderer/aws_interface/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,35 @@ class CloudWandererAWSInterface(CloudInterface):
def __init__(
self,
cloudwanderer_boto3_session: Optional[CloudWandererBoto3Session] = None,
account_id: Optional[str] = None,
enabled_regions: Optional[List[str]] = None,
) -> None:
"""Simplifies lookup of Boto3 services and resources.
Arguments:
cloudwanderer_boto3_session:
A CloudWandererBoto3Session session, if not provided the default will be used.
account_id:
The AWS account ID we're fetching resources from. This will be fetched automatically via API
call if not supplied.
enabled_regions:
The list of regions enabled in this AWS account. This will be fetched automatically via API
call if not supplied.
"""
self.cloudwanderer_boto3_session = cloudwanderer_boto3_session or CloudWandererBoto3Session()
self.account_id = account_id
self.enabled_regions = enabled_regions

def get_enabled_regions(self) -> List[str]:
"""Return the list of regions enabled.
Fulfils the interface requirements for :class:`cloudwanderer.cloud_wanderer.CloudWanderer` to call.
"""
return self.cloudwanderer_boto3_session.get_enabled_regions()
return self.enabled_regions or self.cloudwanderer_boto3_session.get_enabled_regions()

def get_account_id(self) -> str:
"""Return the ID of the account we're getting resources from."""
return self.account_id or self.cloudwanderer_boto3_session.get_account_id()

def get_resource(
self,
Expand Down Expand Up @@ -292,11 +306,7 @@ def _inflate_action_set_regions(self, action_set_templates: List[TemplateActionS
enabled_regions = self.cloudwanderer_boto3_session.get_enabled_regions()
result = []
for action_set_template in action_set_templates:
result.append(
action_set_template.inflate(
regions=enabled_regions, account_id=self.cloudwanderer_boto3_session.get_account_id()
)
)
result.append(action_set_template.inflate(regions=enabled_regions, account_id=self.get_account_id()))

return result

Expand Down
33 changes: 30 additions & 3 deletions cloudwanderer/aws_interface/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Subclass of Boto3 Session class to provide additional helper methods."""
import logging
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import boto3
import botocore
Expand All @@ -17,6 +17,30 @@
logger = logging.getLogger(__name__)


class CloudWandererBoto3SessionGetterClientConfig:
"""Allows the specification of internal getter client config :class:`CloudWandererBoto3SessionGetterClientConfig` .
Example:
Configure the sts client (used in :attr:`CloudWandererBoto3Session.get_account_id`) to use
a regional endpoint url.
>>> from cloudwanderer.aws_interface import (
... CloudWandererBoto3SessionGetterClientConfig,
... CloudWandererBoto3Session
... )
>>> getter_client_config = CloudWandererBoto3SessionGetterClientConfig(
... sts={"endpoint_url": "https://eu-west-1.sts.amazonaws.com"}
... )
>>> cloudwanderer_boto3_session = CloudWandererBoto3Session(getter_client_config=getter_client_config)
"""

def __init__(self, **kwargs: Dict[str, Dict[str, Any]]) -> None:
self.client_configs: Dict[str, Dict[str, Any]] = kwargs

def __call__(self, service_name: str) -> Dict[str, Any]:
return self.client_configs.get(service_name, {})


class CloudWandererBoto3Session(boto3.session.Session):
"""Subclass of Boto3 Session class to provide additional helper methods."""

Expand All @@ -30,6 +54,7 @@ def __init__(
profile_name=None,
resource_factory=None,
service_mapping_loader: Loader = None,
getter_client_config: Optional[CloudWandererBoto3SessionGetterClientConfig] = None,
) -> None:
self.service_mapping_loader = service_mapping_loader
super().__init__(
Expand All @@ -47,10 +72,12 @@ def __init__(
cloudwanderer_boto3_session=self,
)

self.getter_client_config = getter_client_config or CloudWandererBoto3SessionGetterClientConfig()

@memoized_method()
def get_account_id(self) -> str:
"""Return the AWS Account ID our Boto3 session is authenticated against."""
sts = self.client("sts")
sts = self.client("sts", **self.getter_client_config("sts"))
return sts.get_caller_identity()["Account"]

def _setup_loader(self) -> None:
Expand All @@ -60,7 +87,7 @@ def _setup_loader(self) -> None:
@memoized_method()
def get_enabled_regions(self) -> List[str]:
"""Return a list of enabled regions in this account."""
regions = self.client("ec2").describe_regions()["Regions"]
regions = self.client("ec2", **self.getter_client_config("ec2")).describe_regions()["Regions"]
return [region["RegionName"] for region in regions if region["OptInStatus"] != "not-opted-in"]

def resource( # type: ignore[override]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
long_description = re.sub(r"..\s+doctest\s+::", ".. code-block ::", f.read())

setup(
version="0.25.2",
version="0.26.0",
python_requires=">=3.6.0",
name="cloudwanderer",
packages=find_packages(include=["cloudwanderer", "cloudwanderer.*"]),
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/aws_interface/test_aws_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from unittest.mock import MagicMock

from cloudwanderer.aws_interface import CloudWandererAWSInterface


def test_get_account_id():
subject = CloudWandererAWSInterface(
cloudwanderer_boto3_session=MagicMock(**{"get_account_id.return_value": "0123456789012"})
)

assert subject.get_account_id() == "0123456789012"


def test_get_account_id_init_arg():
subject = CloudWandererAWSInterface(account_id="0123456789012")

assert subject.get_account_id() == "0123456789012"


def test_get_enabled_regions():
subject = CloudWandererAWSInterface(
cloudwanderer_boto3_session=MagicMock(**{"get_enabled_regions.return_value": ["eu-west-1"]})
)

assert subject.get_enabled_regions() == ["eu-west-1"]


def test_get_enabled_regions_init_arg():
subject = CloudWandererAWSInterface(enabled_regions=["eu-west-1"])

assert subject.get_enabled_regions() == ["eu-west-1"]
59 changes: 59 additions & 0 deletions tests/unit/aws_interface/test_cloudwanderer_boto3_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import MagicMock

from cloudwanderer.aws_interface import CloudWandererBoto3Session, CloudWandererBoto3SessionGetterClientConfig


def test_get_account_id():
botocore_session = MagicMock()
subject = CloudWandererBoto3Session(
aws_access_key_id="A",
aws_secret_access_key="A",
aws_session_token="A",
botocore_session=botocore_session,
getter_client_config=CloudWandererBoto3SessionGetterClientConfig(
sts={"endpoint_url": "sts.eu-west-1.amazonaws.com"}
),
)

subject.get_account_id()

botocore_session.create_client.assert_called_with(
"sts",
region_name=None,
api_version=None,
use_ssl=True,
verify=None,
endpoint_url="sts.eu-west-1.amazonaws.com",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
config=None,
)


def test_get_enabled_regions():
botocore_session = MagicMock()
subject = CloudWandererBoto3Session(
aws_access_key_id="A",
aws_secret_access_key="A",
aws_session_token="A",
botocore_session=botocore_session,
getter_client_config=CloudWandererBoto3SessionGetterClientConfig(
ec2={"endpoint_url": "ec2.eu-west-1.amazonaws.com"}
),
)

subject.get_enabled_regions()

botocore_session.create_client.assert_called_with(
"ec2",
region_name=None,
api_version=None,
use_ssl=True,
verify=None,
endpoint_url="ec2.eu-west-1.amazonaws.com",
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
config=None,
)

0 comments on commit 3ac1e6c

Please sign in to comment.