diff --git a/src/aws_encryption_sdk_cli/internal/master_key_parsing.py b/src/aws_encryption_sdk_cli/internal/master_key_parsing.py index 165d60b7..c273e68e 100644 --- a/src/aws_encryption_sdk_cli/internal/master_key_parsing.py +++ b/src/aws_encryption_sdk_cli/internal/master_key_parsing.py @@ -14,9 +14,9 @@ import copy import logging from collections import defaultdict +from importlib.metadata import EntryPoint, distributions import aws_encryption_sdk -import pkg_resources from aws_encryption_sdk import CachingCryptoMaterialsManager # noqa pylint: disable=unused-import from aws_encryption_sdk import DefaultCryptoMaterialsManager # noqa pylint: disable=unused-import from aws_encryption_sdk.key_providers.base import MasterKeyProvider # noqa pylint: disable=unused-import @@ -38,7 +38,7 @@ __all__ = ("build_crypto_materials_manager_from_args",) _LOGGER = logging.getLogger(LOGGER_NAME) -_ENTRY_POINTS = defaultdict(dict) # type: DefaultDict[str, Dict[str, pkg_resources.EntryPoint]] +_ENTRY_POINTS = defaultdict(dict) # type: DefaultDict[str, Dict[str, EntryPoint]] def _discover_entry_points(): @@ -46,33 +46,41 @@ def _discover_entry_points(): """Discover all registered entry points.""" _LOGGER.debug("Discovering master key provider plugins") - for entry_point in pkg_resources.iter_entry_points(MASTER_KEY_PROVIDERS_ENTRY_POINT): - _LOGGER.info('Collecting plugin "%s" registered by "%s"', entry_point.name, entry_point.dist) - _LOGGER.debug( - "Plugin details: %s", - dict( - name=entry_point.name, - module_name=entry_point.module_name, - attrs=entry_point.attrs, - extras=entry_point.extras, - dist=entry_point.dist, - ), - ) - - if PLUGIN_NAMESPACE_DIVIDER in entry_point.name: - _LOGGER.warning( - 'Invalid substring "%s" in discovered entry point "%s". It will not be usable.', - PLUGIN_NAMESPACE_DIVIDER, - entry_point.name, + for dist in distributions(): + dist_name = dist.metadata.get("Name") or dist.metadata.get("name") or "unknown" + + for entry_point in dist.entry_points: + if entry_point.group != MASTER_KEY_PROVIDERS_ENTRY_POINT: + continue + + # entry_point.value looks like "pkg.module:attr" + module_name, _, attr = entry_point.value.partition(":") + + _LOGGER.info('Collecting plugin "%s" registered by "%s"', entry_point.name, dist_name) + _LOGGER.debug( + "Plugin details: %s", + dict( + name=entry_point.name, + module_name=module_name, + attrs=[attr] if attr else [], + extras=getattr(entry_point, "extras", ()), + dist=dist_name, + ), ) - continue - # mypy has trouble with pkgs_resources.iter_entry_points members - _ENTRY_POINTS[entry_point.name][entry_point.dist.project_name] = entry_point # type: ignore + if PLUGIN_NAMESPACE_DIVIDER in entry_point.name: + _LOGGER.warning( + 'Invalid substring "%s" in discovered entry point "%s". It will not be usable.', + PLUGIN_NAMESPACE_DIVIDER, + entry_point.name, + ) + continue + + _ENTRY_POINTS[entry_point.name][dist_name] = entry_point def _entry_points(): - # type: () -> DefaultDict[str, Dict[str, pkg_resources.EntryPoint]] + # type: () -> DefaultDict[str, Dict[str, EntryPoint]] """Discover all entry points for required groups if they have not already been found. :returns: Mapping of group to name to entry points @@ -109,7 +117,7 @@ def _load_master_key_provider(name): raise BadUserArgumentError( "Multiple entry points discovered and no package specified. Packages discovered registered by: ({})".format( - ", ".join([str(entry.dist) for entry in entry_points.values()]) + ", ".join(entry_points.keys()) ) ) @@ -123,7 +131,7 @@ def _load_master_key_provider(name): ).format( requested=name, entry_point=entry_point_name, - discovered=", ".join([str(entry.dist) for entry in entry_points.values()]), + discovered=", ".join(entry_points.keys()), ) ) diff --git a/test/unit/internal/test_master_key_parsing.py b/test/unit/internal/test_master_key_parsing.py index f6cb8ffe..20cfd09d 100644 --- a/test/unit/internal/test_master_key_parsing.py +++ b/test/unit/internal/test_master_key_parsing.py @@ -59,9 +59,8 @@ def patch_aws_encryption_sdk(mocker): @pytest.fixture -def patch_iter_entry_points(mocker): - mocker.patch.object(master_key_parsing.pkg_resources, "iter_entry_points") - yield master_key_parsing.pkg_resources.iter_entry_points +def patch_distributions(mocker): + yield mocker.patch.object(master_key_parsing, "distributions") @pytest.fixture @@ -84,8 +83,10 @@ def entry_points_cleaner(): # "name" is a special, non-overridable attribute on mock objects -FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "module_name", "attrs", "extras", "dist"]) -FakeEntryPoint.__new__.__defaults__ = ("MODULE", "ATTRS", "EXTRAS", MagicMock(project_name="PROJECT")) +FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "value", "group", "extras"]) +FakeEntryPoint.__new__.__defaults__ = ("MODULE:ATTR", "GROUP", "EXTRAS") + +FakeDistribution = namedtuple("FakeDistribution", ["metadata", "entry_points"]) def test_entry_points(monkeypatch): @@ -99,8 +100,15 @@ def test_entry_points_aws_kms(): assert master_key_parsing._entry_points()["aws-kms"]["aws-encryption-sdk-cli"].load() is aws_kms_master_key_provider -def test_entry_points_invalid_substring(logger_stream, patch_iter_entry_points): - patch_iter_entry_points.return_value = [FakeEntryPoint("BAD::NAME")] +def test_entry_points_invalid_substring(logger_stream, patch_distributions): + bad_ep = FakeEntryPoint( + name="BAD::NAME", + value="module:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ) + fake_dist = FakeDistribution(metadata={"Name": "fake-dist"}, entry_points=[bad_ep]) + patch_distributions.return_value = [fake_dist] + master_key_parsing._discover_entry_points() key = 'Invalid substring "::" in discovered entry point "BAD::NAME". It will not be usable.' @@ -109,11 +117,28 @@ def test_entry_points_invalid_substring(logger_stream, patch_iter_entry_points): assert "BAD::NAME" not in master_key_parsing._ENTRY_POINTS -def test_entry_points_multiple_per_name(entry_points_cleaner, patch_iter_entry_points): - entry_point_a = FakeEntryPoint(name="aws-kms", dist=MagicMock(project_name="aws-encryption-sdk-cli")) - entry_point_b = FakeEntryPoint(name="aws-kms", dist=MagicMock(project_name="some-other-thing")) - entry_point_c = FakeEntryPoint(name="zzz", dist=MagicMock(project_name="yet-another-thing")) - patch_iter_entry_points.return_value = [entry_point_a, entry_point_b, entry_point_c] +def test_entry_points_multiple_per_name(entry_points_cleaner, patch_distributions): + entry_point_a = FakeEntryPoint( + name="aws-kms", + value="module_a:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ) + entry_point_b = FakeEntryPoint( + name="aws-kms", + value="module_b:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ) + entry_point_c = FakeEntryPoint( + name="zzz", + value="module_c:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ) + + dist_a = FakeDistribution(metadata={"Name": "aws-encryption-sdk-cli"}, entry_points=[entry_point_a]) + dist_b = FakeDistribution(metadata={"Name": "some-other-thing"}, entry_points=[entry_point_b]) + dist_c = FakeDistribution(metadata={"Name": "yet-another-thing"}, entry_points=[entry_point_c]) + + patch_distributions.return_value = [dist_a, dist_b, dist_c] test = master_key_parsing._entry_points() @@ -142,9 +167,15 @@ def test_load_master_key_provider_known_name_only_multiple_entry_points(monkeypa "aws-kms", { "aws-encryption-sdk-cli": FakeEntryPoint( - name="aws-kms", dist=MagicMock(project_name="aws-encryption-sdk-cli") + name="aws-kms", + value="module_a:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ), + "my-fake-package": FakeEntryPoint( + name="aws-kms", + value="module_b:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, ), - "my-fake-package": FakeEntryPoint(name="aws-kms", module_name="my-fake-package"), }, ) @@ -164,14 +195,21 @@ def test_load_master_key_provider_known_name_unknown_name(monkeypatch): monkeypatch.setitem( master_key_parsing._ENTRY_POINTS, "aws-kms", - {"my-fake-package": FakeEntryPoint(name="aws-kms", module_name="my-fake-package")}, + { + "my-fake-package": FakeEntryPoint( + name="aws-kms", + value="module_b:attr", + group=master_key_parsing.MASTER_KEY_PROVIDERS_ENTRY_POINT, + ) + }, ) with pytest.raises(BadUserArgumentError) as excinfo: master_key_parsing._load_master_key_provider("aws-encryption-sdk-cli::aws-kms") excinfo.match( - r'Requested master key provider not found: "aws-encryption-sdk-cli::aws-kms". Packages discovered for *' + r'Requested master key provider not found: "aws-encryption-sdk-cli::aws-kms". ' + r'Packages discovered for "aws-kms" registered by: .*' )