Skip to content

Commit

Permalink
Check that needed (backport) providers are installed (#12902)
Browse files Browse the repository at this point in the history
It's all very well telling people they need to use a new name, but if
that results in them getting an ImportError they aren't going to be very
happy.

The check is also aware of backport vs. regular provider packages, for the
brave souls who might upgrade to 2.0.0 and _then_ run this via `python -m airflow.upgrade.checker`.
  • Loading branch information
ashb committed Dec 11, 2020
1 parent e67b03b commit 9b1759c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 17 deletions.
61 changes: 51 additions & 10 deletions airflow/upgrade/rules/import_changes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,31 @@
# specific language governing permissions and limitations
# under the License.

import itertools
from typing import NamedTuple, Optional, List

from cached_property import cached_property
from packaging.version import Version

from airflow import conf
from airflow.upgrade.rules.base_rule import BaseRule
from airflow.upgrade.rules.renamed_classes import ALL
from airflow.utils.dag_processing import list_py_file_paths

try:
from importlib_metadata import PackageNotFoundError, distribution
except ImportError:
from importlib.metadata import PackageNotFoundError, distribution


class ImportChange(
NamedTuple(
"ImportChange",
[("old_path", str), ("new_path", str), ("providers_package", Optional[None])],
[("old_path", str), ("new_path", str), ("providers_package", Optional[str])],
)
):
def info(self, file_path=None):
msg = "Using `{}` will be replaced by `{}`".format(self.old_path, self.new_path)
if self.providers_package:
msg += " and requires `{}` providers package".format(
self.providers_package
)
msg = "Using `{}` should be replaced by `{}`".format(self.old_path, self.new_path)
if file_path:
msg += ". Affected file: {}".format(file_path)
return msg
Expand All @@ -49,9 +52,20 @@ def old_class(self):
def new_class(self):
return self.new_path.split(".")[-1]

@classmethod
def provider_stub_from_module(cls, module):
if "providers" not in module:
return None

# [2:] strips off the airflow.providers. part
parts = module.split(".")[2:]
if parts[0] in ('apache', 'cncf', 'microsoft'):
return '-'.join(parts[:2])
return parts[0]

@classmethod
def from_new_old_paths(cls, new_path, old_path):
providers_package = new_path.split(".")[2] if "providers" in new_path else None
providers_package = cls.provider_stub_from_module(new_path)
return cls(
old_path=old_path, new_path=new_path, providers_package=providers_package
)
Expand All @@ -73,17 +87,44 @@ class ImportChangesRule(BaseRule):
@staticmethod
def _check_file(file_path):
problems = []
providers = set()
with open(file_path, "r") as file:
content = file.read()
for change in ImportChangesRule.ALL_CHANGES:
if change.old_class in content:
problems.append(change.info(file_path))
return problems
if change.providers_package:
providers.add(change.providers_package)
return problems, providers

@staticmethod
def _check_missing_providers(providers):
    """Yield an install hint for every provider whose distribution is absent.

    The distribution name is built from a prefix chosen by the major version
    of the importable ``airflow`` package (regular provider packages for
    2 and above, backport provider packages otherwise) followed by the
    provider stub.

    :param providers: iterable of provider stubs, e.g. ``apache-hive``
    :return: generator of ``Please install `<distribution>``` messages, one
        per distribution that cannot be found
    """
    # Look airflow up at call time so the version reflects what is installed.
    installed_version = Version(__import__("airflow").__version__)
    needs_backport = installed_version.major < 2
    prefix = (
        "apache-airflow-backport-providers-"
        if needs_backport
        else "apache-airflow-providers-"
    )

    for stub in providers:
        package = prefix + stub
        try:
            distribution(package)
        except PackageNotFoundError:
            yield "Please install `{}`".format(package)

def check(self):
dag_folder = conf.get("core", "dags_folder")
files = list_py_file_paths(directory=dag_folder, include_examples=False)
problems = []
providers = set()
# Split in to two groups - install backports first, then make changes
for file in files:
problems.extend(self._check_file(file))
return problems
new_problems, new_providers = self._check_file(file)
problems.extend(new_problems)
providers |= new_providers

return itertools.chain(
self._check_missing_providers(sorted(providers)),
problems,
)
13 changes: 6 additions & 7 deletions tests/upgrade/rules/test_import_changes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@ def test_info(self):
)
assert change.info(
"file.py"
) == "Using `{}` will be replaced by `{}` and requires `{}` providers package. " \
"Affected file: file.py".format(OLD_PATH, NEW_PATH, PROVIDER)
) == "Using `{}` should be replaced by `{}`. Affected file: file.py".format(OLD_PATH, NEW_PATH)
assert change.old_class == OLD_CLASS
assert change.new_class == NEW_CLASS

def test_from_new_old_paths(self):
paths_tuple = (NEW_PATH, OLD_PATH)
change = ImportChange.from_new_old_paths(*paths_tuple)
assert change.info() == "Using `{}` will be replaced by `{}` and requires `{}` " \
"providers package".format(OLD_PATH, NEW_PATH, PROVIDER)
assert change.info() == "Using `{}` should be replaced by `{}`".format(OLD_PATH, NEW_PATH)


class TestImportChangesRule:
Expand All @@ -58,11 +56,12 @@ def test_check(self, mock_list_files):

temp.write("from airflow.contrib import %s" % OLD_CLASS)
temp.flush()
msgs = ImportChangesRule().check()
msgs = list(ImportChangesRule().check())

assert len(msgs) == 1
assert len(msgs) == 2
msg = msgs[0]
assert msg == 'Please install `apache-airflow-backport-providers-dummy`'
msg = msgs[1]
assert temp.name in msg
assert OLD_PATH in msg
assert OLD_CLASS in msg
assert "requires `{}`".format(PROVIDER) in msg

0 comments on commit 9b1759c

Please sign in to comment.