Skip to content

Commit

Permalink
Add DBApiHook check for 2.0 migration (#12730)
Browse files Browse the repository at this point in the history
* Add DBApiHook check for 2.0 migration

Adds a check that ensures that any hook that implements the
run, get_pandas_df or get_records functions inherits from DbApiHook
rather than directly from BaseHook

* exception for grpc_hook

* fix plugin

* fix plugin

* fix plugin

* py2 compliance and add full lineage

* black

* fix
  • Loading branch information
dimberman committed Dec 11, 2020
1 parent 539be00 commit 2e1f813
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
97 changes: 97 additions & 0 deletions airflow/upgrade/rules/db_api_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.hooks.base_hook import BaseHook
from airflow.upgrade.rules.base_rule import BaseRule


def check_get_pandas_df(cls):
    """Probe ``cls`` for a working ``get_pandas_df`` implementation.

    An instance is created with ``cls.__new__`` (so ``__init__`` is never run)
    and ``get_pandas_df`` is invoked on it with a dummy SQL string.

    :param cls: the hook class to probe.
    :return: ``None`` when the call raises ``NotImplementedError``; otherwise
        (the call returned, or failed with any other ``Exception``) the message
        produced by ``return_error_string``.
    """
    method = "get_pandas_df"
    try:
        instance = cls.__new__(cls)
        instance.get_pandas_df("fake SQL")
        # Reaching this line means the method ran to completion.
        return return_error_string(cls, method)
    except NotImplementedError:
        # Presumably the un-overridden stub inherited from the base hook.
        return None
    except Exception:
        # Any other failure is still reported as an implementation.
        return return_error_string(cls, method)


def check_run(cls):
    """Probe ``cls`` for a working ``run`` implementation.

    An instance is created with ``cls.__new__`` (so ``__init__`` is never run)
    and ``run`` is invoked on it with a dummy SQL string.

    :param cls: the hook class to probe.
    :return: ``None`` when the call raises ``NotImplementedError``; otherwise
        (the call returned, or failed with any other ``Exception``) the message
        produced by ``return_error_string``.
    """
    method = "run"
    try:
        instance = cls.__new__(cls)
        instance.run("fake SQL")
        # Reaching this line means the method ran to completion.
        return return_error_string(cls, method)
    except NotImplementedError:
        # Presumably the un-overridden stub inherited from the base hook.
        return None
    except Exception:
        # Any other failure is still reported as an implementation.
        return return_error_string(cls, method)


def check_get_records(cls):
    """Probe ``cls`` for a working ``get_records`` implementation.

    An instance is created with ``cls.__new__`` (so ``__init__`` is never run)
    and ``get_records`` is invoked on it with a dummy SQL string.

    :param cls: the hook class to probe.
    :return: ``None`` when the call raises ``NotImplementedError``; otherwise
        (the call returned, or failed with any other ``Exception``) the message
        produced by ``return_error_string``.
    """
    method = "get_records"
    try:
        instance = cls.__new__(cls)
        instance.get_records("fake SQL")
        # Reaching this line means the method ran to completion.
        return return_error_string(cls, method)
    except NotImplementedError:
        # Presumably the un-overridden stub inherited from the base hook.
        return None
    except Exception:
        # Any other failure is still reported as an implementation.
        return return_error_string(cls, method)


def return_error_string(cls, method):
    """Build the upgrade-check message for a hook that implements a DB method.

    :param cls: the offending hook class; rendered into the message via
        ``str.format``.
    :param method: name of the DB-style method the class implements.
    :return: a message telling the user to inherit from ``DbApiHook``.
    """
    # The class is importable from ``airflow.hooks.dbapi_hook``; the message
    # previously advertised ``airflow.hooks.db_api_hook``, which is not the
    # module the hook is imported from.
    return (
        "Class {} incorrectly implements the function {} while inheriting from BaseHook. "
        "Please make this class inherit from airflow.hooks.dbapi_hook.DbApiHook instead".format(
            cls, method
        )
    )


def get_all_non_dbapi_children():
    """Collect every ``BaseHook`` descendant that is not reached through ``DbApiHook``.

    The direct subclasses of ``BaseHook`` are taken as the starting generation,
    skipping any class whose ``__name__`` is ``"DbApiHook"``; each following
    generation is gathered through ``__subclasses__()``.

    :return: list of classes in breadth-first order (direct children first),
        each class appearing at most once.
    """
    frontier = [
        child for child in BaseHook.__subclasses__() if child.__name__ != "DbApiHook"
    ]
    seen = set()
    ordered = []
    while frontier:
        next_generation = []
        for child in frontier:
            # A class with several hook parents is reachable along more than
            # one path; record it (and walk its subtree) only the first time so
            # callers do not get duplicate entries.
            if child in seen:
                continue
            seen.add(child)
            ordered.append(child)
            next_generation.extend(child.__subclasses__())
        frontier = next_generation
    return ordered


class DbApiRule(BaseRule):
    """Upgrade rule reporting hooks that implement DB methods outside DbApiHook."""

    title = "Hooks that run DB functions must inherit from DBApiHook"

    description = (
        "Hooks that run DB functions must inherit from DBApiHook instead of BaseHook"
    )

    def check(self):
        """Run every DB-method probe against every non-DbApiHook hook class.

        :return: list of error messages, grouped per class in the order
            ``get_pandas_df``, ``run``, ``get_records``; empty when no probe
            reports a problem.
        """
        # Order matters: it fixes the order of messages for each class.
        probes = (check_get_pandas_df, check_run, check_get_records)
        messages = []
        for hook_cls in get_all_non_dbapi_children():
            for probe in probes:
                message = probe(hook_cls)
                if message:
                    messages.append(message)
        return messages
71 changes: 71 additions & 0 deletions tests/upgrade/rules/test_db_api_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import TestCase

from airflow.hooks.base_hook import BaseHook
from airflow.hooks.dbapi_hook import DbApiHook
from airflow.upgrade.rules.db_api_functions import DbApiRule


class MyHook(BaseHook):
    """Fixture hook deriving directly from BaseHook that defines ``run`` and ``get_pandas_df``."""

    def run(self, sql):
        # No-op: only the presence of the method matters to the rule.
        return None

    def get_pandas_df(self, sql):
        # No-op: only the presence of the method matters to the rule.
        return None

    def get_conn(self):
        return None


class GrandChildHook(MyHook):
    """Fixture hook one level below MyHook, with required constructor arguments."""

    def __init__(self, foo, bar):
        self.foo, self.bar = foo, bar

    def get_records(self, sql):
        # No-op: only the presence of the method matters to the rule.
        return None


class ProperDbApiHook(DbApiHook):
    """Fixture hook that defines the DB methods while inheriting from DbApiHook."""

    def bulk_dump(self, table, tmp_file):
        pass

    def bulk_load(self, table, tmp_file):
        pass

    # The original ``*kwargs`` only collected extra positional arguments, so
    # keyword calls were rejected with TypeError. Accepting both ``*args`` and
    # ``**kwargs`` keeps existing positional calls working and allows keywords.
    def get_records(self, sql, *args, **kwargs):
        pass

    def run(self, sql, *args, **kwargs):
        pass

    def get_pandas_df(self, sql, *args, **kwargs):
        pass


class TestSqlHookCheck(TestCase):
    """Checks which fixture hooks ``DbApiRule`` reports."""

    def test_fails_on_incorrect_hook(self):
        """MyHook yields 2 messages, GrandChildHook 3, ProperDbApiHook none."""
        failures = DbApiRule().check()

        def mentioning(fragment):
            # Messages embed the class, so substring matching selects by class.
            return [failure for failure in failures if fragment in failure]

        self.assertEqual(len(mentioning("MyHook")), 2)
        self.assertEqual(len(mentioning("GrandChild")), 3)
        self.assertEqual(len(mentioning("ProperDbApiHook")), 0)

0 comments on commit 2e1f813

Please sign in to comment.