Skip to content

Commit

Permalink
Remove odbc dependency in microsoft.mssql provider (#15594)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish committed Apr 29, 2021
1 parent ff4b7c4 commit 5045419
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 32 deletions.
1 change: 0 additions & 1 deletion CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,6 @@ discord http
google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh,trino
hashicorp google
microsoft.azure google,oracle
microsoft.mssql odbc
mysql amazon,presto,trino,vertica
opsgenie http
postgres amazon
Expand Down
3 changes: 0 additions & 3 deletions airflow/providers/dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@
"google",
"oracle"
],
"microsoft.mssql": [
"odbc"
],
"mysql": [
"amazon",
"presto",
Expand Down
17 changes: 9 additions & 8 deletions airflow/providers/microsoft/mssql/operators/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterable, Mapping, Optional, Union
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Union

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from airflow.providers.odbc.hooks.odbc import OdbcHook
from airflow.utils.decorators import apply_defaults

if TYPE_CHECKING:
from airflow.hooks.dbapi import DbApiHook


class MsSqlOperator(BaseOperator):
"""
Expand Down Expand Up @@ -68,15 +70,14 @@ def __init__(
self.parameters = parameters
self.autocommit = autocommit
self.database = database
self._hook: Optional[Union[MsSqlHook, OdbcHook]] = None
self._hook: Optional[Union[MsSqlHook, 'DbApiHook']] = None

def get_hook(self) -> Optional[Union[MsSqlHook, OdbcHook]]:
def get_hook(self) -> Optional[Union[MsSqlHook, 'DbApiHook']]:
"""
Will retrieve hook as determined by Connection.
Will retrieve hook as determined by :meth:`~.Connection.get_hook` if one is defined, and
:class:`~.MsSqlHook` otherwise.
If conn_type is ``'odbc'``, will use
:py:class:`~airflow.providers.odbc.hooks.odbc.OdbcHook`.
Otherwise, :py:class:`~airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook` will be used.
For example, if the connection ``conn_type`` is ``'odbc'``, :class:`~.OdbcHook` will be used.
"""
if not self._hook:
conn = MsSqlHook.get_connection(conn_id=self.mssql_conn_id)
Expand Down
47 changes: 27 additions & 20 deletions tests/providers/microsoft/mssql/operators/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,41 @@

import unittest
from unittest import mock
from unittest.mock import MagicMock, Mock

from airflow import PY38
from airflow.models import Connection
from airflow.providers.odbc.hooks.odbc import OdbcHook
from airflow import PY38, AirflowException

if not PY38:
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator

ODBC_CONN = Connection(
conn_id='test-odbc',
conn_type='odbc',
)
PYMSSQL_CONN = Connection(
conn_id='test-pymssql',
conn_type='anything',
)


class TestMsSqlOperator:
@unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.")
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_get_hook(self, get_connection):
def test_get_hook_from_conn(self, get_connection):
"""
:class:`~.MsSqlOperator` should use the hook returned by :meth:`airflow.models.Connection.get_hook`
if one is returned.
This behavior is necessary in order to support usage of :class:`~.OdbcHook` with this operator.
Specifically we verify here that :meth:`~.MsSqlOperator.get_hook` returns the hook returned from a
call of ``get_hook`` on the object returned from :meth:`~.BaseHook.get_connection`.
"""
mock_hook = MagicMock()
get_connection.return_value.get_hook.return_value = mock_hook

op = MsSqlOperator(task_id='test', sql='')
assert op.get_hook() == mock_hook

@unittest.skipIf(PY38, "Mssql package not available when Python >= 3.8.")
@mock.patch('airflow.hooks.base.BaseHook.get_connection')
def test_get_hook_default(self, get_connection):
"""
Operator should use odbc hook if conn type is ``odbc`` and pymssql-based hook otherwise.
If :meth:`airflow.models.Connection.get_hook` does not return a hook (e.g. because of an invalid
conn type), then :class:`~.MsSqlHook` should be used.
"""
for conn, hook_class in [(ODBC_CONN, OdbcHook), (PYMSSQL_CONN, MsSqlHook)]:
get_connection.return_value = conn
op = MsSqlOperator(task_id='test', sql='', mssql_conn_id=conn.conn_id)
hook = op.get_hook()
assert hook.__class__ == hook_class
get_connection.return_value.get_hook.side_effect = Mock(side_effect=AirflowException())

op = MsSqlOperator(task_id='test', sql='')
assert op.get_hook().__class__.__name__ == 'MsSqlHook'

0 comments on commit 5045419

Please sign in to comment.