Skip to content

Commit

Permalink
Adding ChainBetweenDAGAndOperatorNotAllowedRule for easing upgrade to …
Browse files Browse the repository at this point in the history
…2.0 (#11839)

* Adding ChainBetweenDAGAndOperatorNotAllowedRule for checking upgrade to Airflow 2.0 (#11040)

* Cleaning up tests for #11839
  • Loading branch information
jmelot committed Oct 27, 2020
1 parent 9c68453 commit 6b7588f
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import

import re

from airflow import conf
from airflow.upgrade.rules.base_rule import BaseRule
from airflow.utils.dag_processing import list_py_file_paths


class ChainBetweenDAGAndOperatorNotAllowedRule(BaseRule):
    """Upgrade check flagging bit-shift (``>>`` / ``<<``) chaining that involves a DAG.

    Every Python file listed under the configured ``core.dags_folder`` is scanned
    textually (with regular expressions, not by parsing) for ``DAG(...) >> x``,
    ``x << DAG(...)`` and the same constructs written with a variable that was
    assigned from a ``DAG(`` call in that file.
    """

    title = "Chain between DAG and operator not allowed."

    description = "Assigning task to a DAG using bitwise shift (bit-shift) operators are no longer supported."

    def _change_info(self, file_path, line_number):
        """Return the message reported for one offending location.

        :param file_path: path of the file in which the problem was found
        :param line_number: 1-based line on which the offending construct ends
        :return: the rule title followed by the file path and line number
        """
        return "{} Affected file: {} (line {})".format(
            self.title, file_path, line_number
        )

    @staticmethod
    def _build_problem_pattern(dag_vars, python_space):
        """Compile one regex matching any DAG/operator bit-shift construct.

        :param dag_vars: iterable of variable names assigned from ``DAG(``; they
            consist only of ``[A-Za-z0-9_]`` so they need no regex escaping
        :param python_space: regex fragment for the whitespace (and optional
            backslash continuation) allowed between two tokens
        :return: compiled pattern that matches when any construct is present
        """
        alternatives = [
            # DAG(...) >> something
            r"DAG\([^\)]+\){}>>".format(python_space),
            # something << DAG(...)
            r"<<{}DAG\(".format(python_space),
        ]
        for dag_var in dag_vars:
            # dag_var >> something; the leading whitespace/start anchor keeps
            # names that merely end with dag_var from matching.
            alternatives.append(r"(?:\s|^){}{}>>".format(dag_var, python_space))
            # something << dag_var; the trailing lookahead keeps longer names
            # (``dag_task``) and attribute access (``dag.get_task(...)``), where
            # the right operand is not the DAG itself, from matching.
            alternatives.append(r"<<\s*{}{}(?![\w.])".format(python_space, dag_var))
        return re.compile("|".join("(?:{})".format(alt) for alt in alternatives))

    def _check_file(self, file_path):
        """Scan a single file and return one message per offending construct.

        :param file_path: path of the Python file to inspect
        :return: list of messages built by ``_change_info``; empty when the file
            contains none of the searched constructs
        """
        with open(file_path, "r") as file_pointer:
            lines = file_pointer.readlines()

        python_space = r"\s*\\?\s*\n?\s*"
        # Find all the dag variable names. Sorted + de-duplicated so the combined
        # pattern is deterministic and does not repeat alternatives.
        dag_vars = sorted(set(re.findall(
            r"([A-Za-z0-9_]+){}={}DAG\(".format(python_space, python_space),
            "".join(lines),
        )))
        # The pattern does not depend on the line being inspected, so build and
        # compile it once per file instead of once per line.
        problem_pattern = self._build_problem_pattern(dag_vars, python_space)

        problems = []
        history = ""
        for line_number, line in enumerate(lines, 1):
            # Someone could have put the bitshift operator on a different line than
            # the dag they were using it on, so search for dag >> or << dag in all
            # previous lines that did not contain a logged issue.
            history += line
            if problem_pattern.search(history):
                problems.append(self._change_info(file_path, line_number))
                # If we found a problem, clear our history so we don't re-log the
                # problem on the next line.
                history = ""
        return problems

    def check(self):
        """Run the rule against every Python file in the configured DAG folder.

        :return: list of messages, one per offending construct, across all files
            returned by ``list_py_file_paths`` for ``core.dags_folder``
        """
        dag_folder = conf.get("core", "dags_folder")
        file_paths = list_py_file_paths(directory=dag_folder, include_examples=False)
        problems = []
        for file_path in file_paths:
            problems.extend(self._check_file(file_path))
        return problems
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from contextlib import contextmanager
from unittest import TestCase

from tempfile import NamedTemporaryFile
from tests.compat import mock

from airflow.upgrade.rules.chain_between_dag_and_operator_not_allowed_rule import \
ChainBetweenDAGAndOperatorNotAllowedRule


@contextmanager
def create_temp_file(mock_list_files, lines):
    """Yield a temporary file containing ``lines`` and register it with the mock.

    The file's path is set as the sole entry of ``mock_list_files.return_value``.
    The file is removed when the context exits.

    :param mock_list_files: object whose ``return_value`` attribute is set to a
        one-element list holding the temporary file's name
    :param lines: iterable of strings written to the file joined by newlines
        (no trailing newline is added)
    """
    with NamedTemporaryFile("w+") as temp_file:
        mock_list_files.return_value = [temp_file.name]
        # The content is one joined string, so ``write`` is the right call;
        # ``writelines`` would iterate it character by character.
        temp_file.write("\n".join(lines))
        # Flush so readers opening the file by name see the content.
        temp_file.flush()
        yield temp_file


@mock.patch("airflow.upgrade.rules.chain_between_dag_and_operator_not_allowed_rule.list_py_file_paths")
class TestChainBetweenDAGAndOperatorNotAllowedRule(TestCase):
    """Tests for ChainBetweenDAGAndOperatorNotAllowedRule.

    ``list_py_file_paths`` is patched for every test method; each test points
    it at a temporary file holding the DAG source under test.
    """

    msg_template = "{} Affected file: {} (line {})"

    def _assert_single_problem(self, mock_list_files, source_lines, flagged_line):
        """Check that the rule reports exactly one problem at ``flagged_line``."""
        with create_temp_file(mock_list_files, source_lines) as dag_file:
            rule = ChainBetweenDAGAndOperatorNotAllowedRule()
            reported = rule.check()
            wanted = [self.msg_template.format(rule.title, dag_file.name, flagged_line)]
            assert wanted == reported

    def test_rule_metadata(self, _):
        # Title and description must both be plain strings.
        rule = ChainBetweenDAGAndOperatorNotAllowedRule()
        for text in (rule.description, rule.title):
            assert isinstance(text, str)

    def test_valid_check(self, mock_list_files):
        # Chaining two operators inside a DAG context manager is not flagged.
        source_lines = [
            "with DAG('my_dag') as dag:",
            " dummy1 = DummyOperator(task_id='dummy1')",
            " dummy2 = DummyOperator(task_id='dummy2')",
            " dummy1 >> dummy2",
        ]

        with create_temp_file(mock_list_files, source_lines):
            reported = ChainBetweenDAGAndOperatorNotAllowedRule().check()
            assert len(reported) == 0

    def test_invalid_check(self, mock_list_files):
        # A DAG variable on the left of ``>>`` is reported on the shift line.
        self._assert_single_problem(
            mock_list_files,
            [
                "my_dag1 = DAG('my_dag')",
                "dummy = DummyOperator(task_id='dummy')",
                "my_dag1 >> dummy",
            ],
            3,
        )

    def test_invalid_check_no_var_rshift(self, mock_list_files):
        # An inline ``DAG(...)`` on the left of ``>>`` is reported.
        self._assert_single_problem(
            mock_list_files,
            ["DAG('my_dag') >> DummyOperator(task_id='dummy')"],
            1,
        )

    def test_invalid_check_no_var_lshift(self, mock_list_files):
        # An inline ``DAG(...)`` on the right of ``<<`` is reported on that line.
        self._assert_single_problem(
            mock_list_files,
            [
                "DummyOperator(",
                "task_id='dummy') << DAG('my_dag')",
            ],
            2,
        )

    def test_invalid_check_multiline(self, mock_list_files):
        # Backslash continuations in both the assignment and the shift.
        self._assert_single_problem(
            mock_list_files,
            [
                "dag = \\",
                " DAG('my_dag')",
                "dummy = DummyOperator(task_id='dummy')",
                "",
                "dag >> \\",
                "dummy",
            ],
            5,
        )

    def test_invalid_check_multiline_lshift(self, mock_list_files):
        # With ``<<`` the DAG name is on the continuation line, so that line is
        # the one reported.
        self._assert_single_problem(
            mock_list_files,
            [
                "dag = \\",
                " DAG('my_dag')",
                "dummy = DummyOperator(task_id='dummy')",
                "",
                "dummy << \\",
                "dag",
            ],
            6,
        )

0 comments on commit 6b7588f

Please sign in to comment.