Skip to content

Commit

Permalink
Replace pkg_resources with importlib.metadata to avoid VersionConflic…
Browse files Browse the repository at this point in the history
…t errors (#12694)

Using `pkg_resources.iter_entry_points` validates the version
constraints, and if any fail it will throw an Exception for that
entrypoint.

This sounds nice, but is a huge mis-feature.

So instead of that, switch to using importlib.metadata (well, it's
backport importlib_metadata) that just gives us the entrypoints - no
other verification of requirements is performed.

This has two advantages:

1. providers and plugins load much more reliably.
2. it's faster too

Closes #12692

(cherry picked from commit 7ef9aa7)
  • Loading branch information
ashb authored and kaxil committed Dec 3, 2020
1 parent fc6d0a8 commit b49838f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 40 deletions.
39 changes: 31 additions & 8 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@
import warnings
from typing import Any, Dict, List, Type

import pkg_resources
from six import with_metaclass

try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata

from airflow import settings
from airflow.models.baseoperator import BaseOperatorLink

Expand Down Expand Up @@ -109,6 +113,23 @@ def on_load(cls, *args, **kwargs):
"""


def entry_points_with_dist(group):
"""
Return EntryPoint objects of the given group, along with the distribution information.
This is like the ``entry_points()`` function from importlib.metadata,
except it also returns the distribution the entry_point was loaded from.
:param group: FIlter results to only this entrypoint group
:return: Generator of (EntryPoint, Distribution) objects for the specified groups
"""
for dist in importlib_metadata.distributions():
for e in dist.entry_points:
if e.group != group:
continue
yield (e, dist)


def load_entrypoint_plugins(entry_points, airflow_plugins):
"""
Load AirflowPlugin subclasses from the entrypoints
Expand All @@ -122,16 +143,18 @@ def load_entrypoint_plugins(entry_points, airflow_plugins):
:rtype: list[airflow.plugins_manager.AirflowPlugin]
"""
global import_errors # pylint: disable=global-statement
for entry_point in entry_points:
for entry_point, dist in entry_points:
log.debug('Importing entry_point plugin %s', entry_point.name)
try:
plugin_obj = entry_point.load()
plugin_obj.__usable_import_name = entry_point.module_name
if is_valid_plugin(plugin_obj, airflow_plugins):
if callable(getattr(plugin_obj, 'on_load', None)):
plugin_obj.on_load()
plugin_obj.__usable_import_name = entry_point.module
if not is_valid_plugin(plugin_obj, airflow_plugins):
continue

if callable(getattr(plugin_obj, 'on_load', None)):
plugin_obj.on_load()

airflow_plugins.append(plugin_obj)
airflow_plugins.append(plugin_obj)
except Exception as e: # pylint: disable=broad-except
log.exception("Failed to import plugin %s", entry_point.name)
import_errors[entry_point.module_name] = str(e)
Expand Down Expand Up @@ -204,7 +227,7 @@ def is_valid_plugin(plugin_obj, existing_plugins):
import_errors[filepath] = str(e)

plugins = load_entrypoint_plugins(
pkg_resources.iter_entry_points('airflow.plugins'),
entry_points_with_dist('airflow.plugins'),
plugins
)

Expand Down
56 changes: 30 additions & 26 deletions tests/plugins/test_plugins_manager_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import six
from tests.compat import mock
import logging

import pkg_resources
import pytest

from airflow.www_rbac import app as application
from tests.compat import mock


class PluginsTestRBAC(unittest.TestCase):
def setUp(self):
class TestPluginsRBAC(object):
def setup_method(self, method):
self.app, self.appbuilder = application.create_app(testing=True)

def test_flaskappbuilder_views(self):
Expand All @@ -41,18 +40,18 @@ def test_flaskappbuilder_views(self):
plugin_views = [view for view in self.appbuilder.baseviews
if view.blueprint.name == appbuilder_class_name]

self.assertTrue(len(plugin_views) == 1)
assert len(plugin_views) == 1

# view should have a menu item matching category of v_appbuilder_package
links = [menu_item for menu_item in self.appbuilder.menu.menu
if menu_item.name == v_appbuilder_package['category']]

self.assertTrue(len(links) == 1)
assert len(links) == 1

# menu link should also have a link matching the name of the package.
link = links[0]
self.assertEqual(link.name, v_appbuilder_package['category'])
self.assertEqual(link.childs[0].name, v_appbuilder_package['name'])
assert link.name == v_appbuilder_package['category']
assert link.childs[0].name == v_appbuilder_package['name']

def test_flaskappbuilder_menu_links(self):
from tests.plugins.test_plugin import appbuilder_mitem
Expand All @@ -61,40 +60,45 @@ def test_flaskappbuilder_menu_links(self):
links = [menu_item for menu_item in self.appbuilder.menu.menu
if menu_item.name == appbuilder_mitem['category']]

self.assertTrue(len(links) == 1)
assert len(links) == 1

# menu link should also have a link matching the name of the package.
link = links[0]
self.assertEqual(link.name, appbuilder_mitem['category'])
self.assertEqual(link.childs[0].name, appbuilder_mitem['name'])
assert link.name == appbuilder_mitem['category']
assert link.childs[0].name == appbuilder_mitem['name']

def test_app_blueprints(self):
from tests.plugins.test_plugin import bp

# Blueprint should be present in the app
self.assertTrue('test_plugin' in self.app.blueprints)
self.assertEqual(self.app.blueprints['test_plugin'].name, bp.name)
assert 'test_plugin' in self.app.blueprints
assert self.app.blueprints['test_plugin'].name == bp.name

@unittest.skipIf(six.PY2, 'self.assertLogs not available for Python 2')
@mock.patch('pkg_resources.iter_entry_points')
def test_entrypoint_plugin_errors_dont_raise_exceptions(self, mock_ep_plugins):
@pytest.mark.quarantined
def test_entrypoint_plugin_errors_dont_raise_exceptions(self, caplog):
"""
Test that Airflow does not raise an Error if there is any Exception because of the
Plugin.
"""
from airflow.plugins_manager import load_entrypoint_plugins, import_errors
from airflow.plugins_manager import import_errors, load_entrypoint_plugins, entry_points_with_dist

mock_dist = mock.Mock()

mock_entrypoint = mock.Mock()
mock_entrypoint.name = 'test-entrypoint'
mock_entrypoint.group = 'airflow.plugins'
mock_entrypoint.module_name = 'test.plugins.test_plugins_manager'
mock_entrypoint.load.side_effect = Exception('Version Conflict')
mock_ep_plugins.return_value = [mock_entrypoint]
mock_entrypoint.load.side_effect = ImportError('my_fake_module not found')
mock_dist.entry_points = [mock_entrypoint]

with mock.patch('importlib_metadata.distributions', return_value=[mock_dist]), caplog.at_level(
logging.ERROR, logger='airflow.plugins_manager'
):
load_entrypoint_plugins(entry_points_with_dist('airflow.plugins'), [])

with self.assertLogs("airflow.plugins_manager", level="ERROR") as log_output:
load_entrypoint_plugins(pkg_resources.iter_entry_points('airflow.plugins'), [])
received_logs = log_output.output[0]
received_logs = caplog.text
# Assert Traceback is shown too
assert "Traceback (most recent call last):" in received_logs
assert "Version Conflict" in received_logs
assert "my_fake_module not found" in received_logs
assert "Failed to import plugin test-entrypoint" in received_logs
assert ('test.plugins.test_plugins_manager', 'Version Conflict') in import_errors.items()
assert ("test.plugins.test_plugins_manager", "my_fake_module not found") in import_errors.items()
17 changes: 11 additions & 6 deletions tests/plugins/test_plugins_manager_www.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import unicode_literals

import six
from mock import MagicMock, PropertyMock
from mock import MagicMock, Mock

from flask.blueprints import Blueprint
from flask_admin.menu import MenuLink, MenuView
Expand Down Expand Up @@ -119,11 +119,16 @@ def setUp(self):
]

def _build_mock(self, plugin_obj):
m = MagicMock(**{
'load.return_value': plugin_obj
})
type(m).name = PropertyMock(return_value='plugin-' + plugin_obj.name)
return m

mock_dist = Mock()

mock_entrypoint = Mock()
mock_entrypoint.name = 'plugin-' + plugin_obj.name
mock_entrypoint.group = 'airflow.plugins'
mock_entrypoint.load.return_value = plugin_obj
mock_dist.entry_points = [mock_entrypoint]

return (mock_entrypoint, mock_dist)

def test_load_entrypoint_plugins(self):
self.assertListEqual(
Expand Down

0 comments on commit b49838f

Please sign in to comment.