Skip to content
60 changes: 34 additions & 26 deletions src/aws_encryption_sdk_cli/internal/master_key_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import copy
import logging
from collections import defaultdict
from importlib.metadata import EntryPoint, distributions

import aws_encryption_sdk
import pkg_resources
from aws_encryption_sdk import CachingCryptoMaterialsManager # noqa pylint: disable=unused-import
from aws_encryption_sdk import DefaultCryptoMaterialsManager # noqa pylint: disable=unused-import
from aws_encryption_sdk.key_providers.base import MasterKeyProvider # noqa pylint: disable=unused-import
Expand All @@ -38,41 +38,49 @@

__all__ = ("build_crypto_materials_manager_from_args",)
_LOGGER = logging.getLogger(LOGGER_NAME)
_ENTRY_POINTS = defaultdict(dict) # type: DefaultDict[str, Dict[str, pkg_resources.EntryPoint]]
_ENTRY_POINTS = defaultdict(dict) # type: DefaultDict[str, Dict[str, EntryPoint]]


def _discover_entry_points():
# type: () -> None
"""Discover all registered entry points."""
_LOGGER.debug("Discovering master key provider plugins")

for entry_point in pkg_resources.iter_entry_points(MASTER_KEY_PROVIDERS_ENTRY_POINT):
_LOGGER.info('Collecting plugin "%s" registered by "%s"', entry_point.name, entry_point.dist)
_LOGGER.debug(
"Plugin details: %s",
dict(
name=entry_point.name,
module_name=entry_point.module_name,
attrs=entry_point.attrs,
extras=entry_point.extras,
dist=entry_point.dist,
),
)

if PLUGIN_NAMESPACE_DIVIDER in entry_point.name:
_LOGGER.warning(
'Invalid substring "%s" in discovered entry point "%s". It will not be usable.',
PLUGIN_NAMESPACE_DIVIDER,
entry_point.name,
for dist in distributions():
dist_name = dist.metadata.get("Name") or dist.metadata.get("name") or "unknown"

for entry_point in dist.entry_points:
if entry_point.group != MASTER_KEY_PROVIDERS_ENTRY_POINT:
continue

# entry_point.value looks like "pkg.module:attr"
module_name, _, attr = entry_point.value.partition(":")

_LOGGER.info('Collecting plugin "%s" registered by "%s"', entry_point.name, dist_name)
_LOGGER.debug(
"Plugin details: %s",
dict(
name=entry_point.name,
module_name=module_name,
attrs=[attr] if attr else [],
extras=getattr(entry_point, "extras", ()),
dist=dist_name,
),
)
continue

# mypy has trouble with pkgs_resources.iter_entry_points members
_ENTRY_POINTS[entry_point.name][entry_point.dist.project_name] = entry_point # type: ignore
if PLUGIN_NAMESPACE_DIVIDER in entry_point.name:
_LOGGER.warning(
'Invalid substring "%s" in discovered entry point "%s". It will not be usable.',
PLUGIN_NAMESPACE_DIVIDER,
entry_point.name,
)
continue

_ENTRY_POINTS[entry_point.name][dist_name] = entry_point


def _entry_points():
# type: () -> DefaultDict[str, Dict[str, pkg_resources.EntryPoint]]
# type: () -> DefaultDict[str, Dict[str, EntryPoint]]
"""Discover all entry points for required groups if they have not already been found.

:returns: Mapping of group to name to entry points
Expand Down Expand Up @@ -109,7 +117,7 @@ def _load_master_key_provider(name):

raise BadUserArgumentError(
"Multiple entry points discovered and no package specified. Packages discovered registered by: ({})".format(
", ".join([str(entry.dist) for entry in entry_points.values()])
", ".join(entry_points.keys())
)
)

Expand All @@ -123,7 +131,7 @@ def _load_master_key_provider(name):
).format(
requested=name,
entry_point=entry_point_name,
discovered=", ".join([str(entry.dist) for entry in entry_points.values()]),
discovered=", ".join(entry_points.keys()),
)
)

Expand Down
70 changes: 54 additions & 16 deletions test/unit/internal/test_master_key_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@ def patch_aws_encryption_sdk(mocker):


@pytest.fixture
def patch_iter_entry_points(mocker):
mocker.patch.object(master_key_parsing.pkg_resources, "iter_entry_points")
yield master_key_parsing.pkg_resources.iter_entry_points
def patch_distributions(mocker):
yield mocker.patch.object(master_key_parsing, "distributions")


@pytest.fixture
Expand All @@ -84,8 +83,10 @@ def entry_points_cleaner():


# "name" is a special, non-overridable attribute on mock objects
FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "module_name", "attrs", "extras", "dist"])
FakeEntryPoint.__new__.__defaults__ = ("MODULE", "ATTRS", "EXTRAS", MagicMock(project_name="PROJECT"))
FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "value", "group", "extras"])
FakeEntryPoint.__new__.__defaults__ = ("MODULE:ATTR", "GROUP", "EXTRAS")

FakeDistribution = namedtuple("FakeDistribution", ["metadata", "entry_points"])


def test_entry_points(monkeypatch):
Expand All @@ -99,8 +100,15 @@ def test_entry_points_aws_kms():
assert master_key_parsing._entry_points()["aws-kms"]["aws-encryption-sdk-cli"].load() is aws_kms_master_key_provider


def test_entry_points_invalid_substring(logger_stream, patch_iter_entry_points):
patch_iter_entry_points.return_value = [FakeEntryPoint("BAD::NAME")]
def test_entry_points_invalid_substring(logger_stream, patch_distributions):
bad_ep = FakeEntryPoint(
name="BAD::NAME",
value="module:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
)
fake_dist = FakeDistribution(metadata={"Name": "fake-dist"}, entry_points=[bad_ep])
patch_distributions.return_value = [fake_dist]

master_key_parsing._discover_entry_points()

key = 'Invalid substring "::" in discovered entry point "BAD::NAME". It will not be usable.'
Expand All @@ -109,11 +117,28 @@ def test_entry_points_invalid_substring(logger_stream, patch_iter_entry_points):
assert "BAD::NAME" not in master_key_parsing._ENTRY_POINTS


def test_entry_points_multiple_per_name(entry_points_cleaner, patch_iter_entry_points):
entry_point_a = FakeEntryPoint(name="aws-kms", dist=MagicMock(project_name="aws-encryption-sdk-cli"))
entry_point_b = FakeEntryPoint(name="aws-kms", dist=MagicMock(project_name="some-other-thing"))
entry_point_c = FakeEntryPoint(name="zzz", dist=MagicMock(project_name="yet-another-thing"))
patch_iter_entry_points.return_value = [entry_point_a, entry_point_b, entry_point_c]
def test_entry_points_multiple_per_name(entry_points_cleaner, patch_distributions):
entry_point_a = FakeEntryPoint(
name="aws-kms",
value="module_a:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
)
entry_point_b = FakeEntryPoint(
name="aws-kms",
value="module_b:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
)
entry_point_c = FakeEntryPoint(
name="zzz",
value="module_c:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
)

dist_a = FakeDistribution(metadata={"Name": "aws-encryption-sdk-cli"}, entry_points=[entry_point_a])
dist_b = FakeDistribution(metadata={"Name": "some-other-thing"}, entry_points=[entry_point_b])
dist_c = FakeDistribution(metadata={"Name": "yet-another-thing"}, entry_points=[entry_point_c])

patch_distributions.return_value = [dist_a, dist_b, dist_c]

test = master_key_parsing._entry_points()

Expand Down Expand Up @@ -142,9 +167,15 @@ def test_load_master_key_provider_known_name_only_multiple_entry_points(monkeypa
"aws-kms",
{
"aws-encryption-sdk-cli": FakeEntryPoint(
name="aws-kms", dist=MagicMock(project_name="aws-encryption-sdk-cli")
name="aws-kms",
value="module_a:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
),
"my-fake-package": FakeEntryPoint(
name="aws-kms",
value="module_b:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
),
"my-fake-package": FakeEntryPoint(name="aws-kms", module_name="my-fake-package"),
},
)

Expand All @@ -164,14 +195,21 @@ def test_load_master_key_provider_known_name_unknown_name(monkeypatch):
monkeypatch.setitem(
master_key_parsing._ENTRY_POINTS,
"aws-kms",
{"my-fake-package": FakeEntryPoint(name="aws-kms", module_name="my-fake-package")},
{
"my-fake-package": FakeEntryPoint(
name="aws-kms",
value="module_b:attr",
group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT,
)
},
)

with pytest.raises(BadUserArgumentError) as excinfo:
master_key_parsing._load_master_key_provider("aws-encryption-sdk-cli::aws-kms")

excinfo.match(
r'Requested master key provider not found: "aws-encryption-sdk-cli::aws-kms". Packages discovered for *'
r'Requested master key provider not found: "aws-encryption-sdk-cli::aws-kms". '
r'Packages discovered for "aws-kms" registered by: .*'
)


Expand Down
Loading