Skip to content

Commit

Permalink
Restrict direct usage of driver params via extras for JDBC connection (
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Jun 12, 2023
1 parent a906d73 commit 0edbe91
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 74 deletions.
14 changes: 14 additions & 0 deletions airflow/providers/jdbc/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@
Changelog
---------

4.0.0
.....

Breaking changes
~~~~~~~~~~~~~~~~

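Driver parameters (driver path and driver class) set in the connection extra are no longer used by default, and the extra keys are now "driver_path" and "driver_class" instead of "drv_path" and "drv_clsname". To configure driver parameters, you can use the following methods: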

1. Supply them as constructor arguments when instantiating the hook.
2. Set the "driver_path" and/or "driver_class" parameters in the "hook_params" dictionary when creating the hook using SQL operators.
3. Set the "driver_path" and/or "driver_class" extra in the connection and correspondingly enable the "allow_driver_path_in_extra" and/or "allow_driver_class_in_extra" options in the "providers.jdbc" section of the Airflow configuration.
4. Patch the "JdbcHook.default_driver_path" and/or "JdbcHook.default_driver_class" values in the "local_settings.py" file.


3.4.0
.....

Expand Down
111 changes: 81 additions & 30 deletions airflow/providers/jdbc/hooks/jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ class JdbcHook(DbApiHook):
JDBC URL, username and password will be taken from the predefined connection.
Note that the whole JDBC URL must be specified in the "host" field in the DB.
Raises an airflow error if the given connection id doesn't exist.
To configure driver parameters, you can use the following methods:
1. Supply them as constructor arguments when instantiating the hook.
2. Set the "driver_path" and/or "driver_class" parameters in the "hook_params" dictionary when
creating the hook using SQL operators.
3. Set the "driver_path" and/or "driver_class" extra in the connection and correspondingly enable
the "allow_driver_path_in_extra" and/or "allow_driver_class_in_extra" options in the
"providers.jdbc" section of the Airflow configuration. If you're enabling these options in Airflow
configuration, you should make sure that you trust the users who can edit connections in the UI
to not use it maliciously.
4. Patch the ``JdbcHook.default_driver_path`` and/or ``JdbcHook.default_driver_class`` values in the
"local_settings.py" file.
See :doc:`/connections/jdbc` for full documentation.
:param args: passed to DbApiHook
:param driver_path: path to the JDBC driver jar file. See above for more info
:param driver_class: name of the JDBC driver class. See above for more info
:param kwargs: passed to DbApiHook
"""

conn_name_attr = "jdbc_conn_id"
Expand All @@ -39,57 +58,89 @@ class JdbcHook(DbApiHook):
hook_name = "JDBC Connection"
supports_autocommit = True

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
"""Get connection widgets to add to connection form."""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField
default_driver_path: str | None = None
default_driver_class: str | None = None

return {
"drv_path": StringField(lazy_gettext("Driver Path"), widget=BS3TextFieldWidget()),
"drv_clsname": StringField(lazy_gettext("Driver Class"), widget=BS3TextFieldWidget()),
}
def __init__(
    self,
    *args,
    driver_path: str | None = None,
    driver_class: str | None = None,
    **kwargs,
) -> None:
    """Create the hook and remember any explicitly supplied driver settings.

    :param args: positional arguments forwarded unchanged to the parent initializer
    :param driver_path: driver path given by the caller; stored on ``_driver_path``
    :param driver_class: driver class given by the caller; stored on ``_driver_class``
    :param kwargs: keyword arguments forwarded unchanged to the parent initializer
    """
    super().__init__(*args, **kwargs)
    # Stored under private names; the values are only kept here, not validated.
    self._driver_class, self._driver_path = driver_class, driver_path

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
    """Return the custom field behaviour for the JDBC connection form.

    The "port" and "schema" fields are hidden and "host" is relabelled to
    "Connection URL".

    :return: mapping with the ``hidden_fields`` list and the ``relabeling`` mapping
    """
    # "hidden_fields" is listed exactly once. A repeated key in a dict literal is
    # silently overwritten by the later entry, so a second list would be dead code.
    return {
        "hidden_fields": ["port", "schema"],
        "relabeling": {"host": "Connection URL"},
    }

def _get_field(self, extras: dict, field_name: str):
"""Get field from extra.
@property
def connection_extra_lower(self) -> dict:
"""
``connection.extra_dejson`` but where keys are converted to lower case.
This first checks the short name, then check for prefixed name for
backward compatibility.
This is used internally for case-insensitive access of jdbc params.
"""
backcompat_prefix = "extra__jdbc__"
if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
"when using this method."
)
if field_name in extras:
return extras[field_name] or None
prefixed_name = f"{backcompat_prefix}{field_name}"
return extras.get(prefixed_name) or None
conn = self.get_connection(getattr(self, self.conn_name_attr))
return {k.lower(): v for k, v in conn.extra_dejson.items()}

@property
def driver_path(self) -> str | None:
    """Return the JDBC driver path to use.

    The value comes from the connection extra ``driver_path`` when
    ``allow_driver_path_in_extra`` is enabled in the ``providers.jdbc`` config
    section, otherwise from the value stored on the hook, otherwise from
    ``default_driver_path``. The resolved value is stored back on ``_driver_path``.
    """
    return self._resolve_driver_param("driver_path")


@property
def driver_class(self) -> str | None:
    """Return the JDBC driver class to use.

    The value comes from the connection extra ``driver_class`` when
    ``allow_driver_class_in_extra`` is enabled in the ``providers.jdbc`` config
    section, otherwise from the value stored on the hook, otherwise from
    ``default_driver_class``. The resolved value is stored back on ``_driver_class``.
    """
    return self._resolve_driver_param("driver_class")


def _resolve_driver_param(self, name: str) -> str | None:
    """Resolve one driver setting; shared by ``driver_path`` and ``driver_class``.

    :param name: either ``"driver_path"`` or ``"driver_class"``; it selects the extra key,
        the ``allow_<name>_in_extra`` config option, the ``_<name>`` instance attribute
        and the ``default_<name>`` class attribute
    :return: the resolved value, or ``None`` when nothing is configured
    """
    from airflow.configuration import conf

    stored_attr = f"_{name}"
    extra_value = self.connection_extra_lower.get(name)
    if extra_value:
        if conf.getboolean("providers.jdbc", f"allow_{name}_in_extra", fallback=False):
            # An enabled extra takes precedence over whatever is stored on the hook.
            setattr(self, stored_attr, extra_value)
        else:
            # A single mapping argument lets the message refer to the name repeatedly.
            self.log.warning(
                "You have supplied '%(name)s' via connection extra but it will not be used. In order "
                "to use '%(name)s' from extra you must set airflow config setting "
                "`allow_%(name)s_in_extra = True` in section `providers.jdbc`. Alternatively you may "
                "specify it via '%(name)s' parameter of the hook constructor or via 'hook_params' "
                "dictionary with key '%(name)s' if using SQL operators.",
                {"name": name},
            )
    if not getattr(self, stored_attr):
        # Nothing usable so far: fall back to the class-level default, which may be None.
        setattr(self, stored_attr, getattr(self, f"default_{name}"))
    return getattr(self, stored_attr)

def get_conn(self) -> jaydebeapi.Connection:
    """Open a JDBC connection through ``jaydebeapi``.

    The URL, login and password come from the Airflow connection named by the
    hook's connection-id attribute. The driver class and jar path(s) come from
    the ``driver_class`` and ``driver_path`` properties.

    :return: the connection object returned by ``jaydebeapi.connect``
    """
    conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
    host: str = conn.host
    login: str = conn.login
    psw: str = conn.password

    # Read each property once: every access re-reads the connection extra and may
    # log a warning, so repeated reads would duplicate that work and the log line.
    driver_class = self.driver_class
    driver_path = self.driver_path

    # Each keyword is passed exactly once; a repeated keyword argument is a SyntaxError.
    return jaydebeapi.connect(
        jclassname=driver_class,
        url=str(host),
        driver_args=[str(login), str(psw)],
        # Several jars may be given as one comma-separated string.
        jars=driver_path.split(",") if driver_path else None,
    )

Expand Down
1 change: 1 addition & 0 deletions airflow/providers/jdbc/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ description: |
suspended: false
versions:
- 4.0.0
- 3.4.0
- 3.3.0
- 3.2.1
Expand Down
19 changes: 16 additions & 3 deletions docs/apache-airflow-providers-jdbc/connections/jdbc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ Port (optional)
Extra (optional)
Specify the extra parameters (as a JSON dictionary) that can be used in the JDBC connection. In addition to the standard Python parameters, the following parameters are supported:

* ``conn_prefix`` - Used to build the connection url in ``JdbcOperator``, added in front of host (``conn_prefix`` ``host`` [: ``port`` ] / ``schema``)
* ``drv_clsname`` - Full qualified Java class name of the JDBC driver. For ``JdbcOperator``.
* ``drv_path`` - Jar filename or sequence of filenames for the JDBC driver libs. For ``JdbcOperator``.
- ``driver_class``
* Fully qualified Java class name of the JDBC driver. For ``JdbcOperator``.
Note that this is only considered if ``allow_driver_class_in_extra`` is set to True in airflow config section
``providers.jdbc`` (by default it is not considered). Note: if setting this config from env vars, use
``AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_CLASS_IN_EXTRA=true``.

- ``driver_path``
* Jar filename or sequence of filenames for the JDBC driver libs. For ``JdbcOperator``.
Note that this is only considered if ``allow_driver_path_in_extra`` is set to True in airflow config section
``providers.jdbc`` (by default it is not considered). Note: if setting this config from env vars, use
``AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_PATH_IN_EXTRA=true``.

.. note::
Setting ``allow_driver_path_in_extra`` or ``allow_driver_class_in_extra`` to True allows users to set the driver
via the Airflow Connection's ``extra`` field. By default this is not allowed. If enabling this functionality,
you should make sure that you trust the users who can edit connections in the UI to not use it maliciously.
132 changes: 91 additions & 41 deletions tests/providers/jdbc/hooks/test_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,33 @@
from __future__ import annotations

import json
import os
import logging
from unittest import mock
from unittest.mock import Mock, patch

import pytest
from pytest import param

from airflow.models import Connection
from airflow.providers.jdbc.hooks.jdbc import JdbcHook
from airflow.utils import db

jdbc_conn_mock = Mock(name="jdbc_conn")


def get_hook(hook_params=None, conn_params=None):
    """Build a ``JdbcHook`` whose ``get_connection`` returns a canned ``Connection``.

    :param hook_params: keyword arguments for the ``JdbcHook`` constructor (default: none)
    :param conn_params: ``Connection`` keyword arguments that override the built-in defaults
    :return: the hook, with ``get_connection`` replaced by a ``Mock``
    """
    connection_kwargs = {
        "login": "login",
        "password": "password",
        "host": "host",
        "schema": "schema",
        "port": 1234,
    }
    # Caller-supplied values win over the defaults above.
    connection_kwargs.update(conn_params or {})
    connection = Connection(**connection_kwargs)

    hook = JdbcHook(**(hook_params or {}))
    hook.get_connection = Mock(return_value=connection)
    return hook


class TestJdbcHook:
def setup_method(self):
db.merge_conn(
Expand All @@ -41,8 +55,8 @@ def setup_method(self):
port=443,
extra=json.dumps(
{
"extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2",
"extra__jdbc__drv_clsname": "com.driver.main",
"driver_path": "/path1/test.jar,/path2/t.jar2",
"driver_class": "com.driver.main",
}
),
)
Expand Down Expand Up @@ -70,39 +84,75 @@ def test_jdbc_conn_get_autocommit(self, _):
jdbc_hook.get_autocommit(jdbc_conn)
jdbc_conn.jconn.getAutoCommit.assert_called_once_with()

@pytest.mark.parametrize(
"uri",
[
param(
"a://?extra__jdbc__drv_path=abc&extra__jdbc__drv_clsname=abc",
id="prefix",
),
param("a://?drv_path=abc&drv_clsname=abc", id="no-prefix"),
],
)
@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_backcompat_prefix_works(self, mock_connect, uri):
with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}):
hook = JdbcHook("my_conn")
hook.get_conn()
mock_connect.assert_called_with(
jclassname="abc",
url="",
driver_args=["None", "None"],
jars="abc".split(","),
)
def test_driver_hook_params(self):
    """Driver settings passed to the hook constructor are returned by the properties."""
    expected_path, expected_class = "Blah driver path", "Blah driver class"
    hook = get_hook(hook_params={"driver_path": expected_path, "driver_class": expected_class})
    assert hook.driver_path == expected_path
    assert hook.driver_class == expected_class

@patch("airflow.providers.jdbc.hooks.jdbc.jaydebeapi.connect")
def test_backcompat_prefix_both_prefers_short(self, mock_connect):
with patch.dict(
os.environ,
{"AIRFLOW_CONN_MY_CONN": "a://?drv_path=non-prefixed&extra__jdbc__drv_path=prefixed"},
):
hook = JdbcHook("my_conn")
hook.get_conn()
mock_connect.assert_called_with(
jclassname=None,
url="",
driver_args=["None", "None"],
jars="non-prefixed".split(","),
)
def test_driver_in_extra_not_used(self):
    """With no config override here, constructor values are returned, not the extra ones."""
    extra = json.dumps({"driver_path": "ExtraDriverPath", "driver_class": "ExtraDriverClass"})
    hook = get_hook(
        conn_params={"extra": extra},
        hook_params={"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"},
    )
    assert hook.driver_path == "ParamDriverPath"
    assert hook.driver_class == "ParamDriverClass"

def test_driver_extra_raises_warning_by_default(self, caplog):
    """Driver settings given only via extra are ignored and a warning is logged.

    No config override is applied here, so the hook should warn and return ``None``
    for both ``driver_path`` and ``driver_class``.
    """
    # No ``logger=`` argument, so the level applies to the root logger. The name used
    # before ("airflow.providers.jdbc.hooks.test_jdbc") is not a logger the hook writes to.
    with caplog.at_level(logging.WARNING):
        driver_path = get_hook(conn_params={"extra": '{"driver_path": "Blah driver path"}'}).driver_path
        assert "You have supplied 'driver_path' via connection extra but it will not be used" in caplog.text
        assert driver_path is None

        driver_class = get_hook(conn_params={"extra": '{"driver_class": "Blah driver class"}'}).driver_class
        assert "You have supplied 'driver_class' via connection extra but it will not be used" in caplog.text
        assert driver_class is None

@mock.patch.dict(
    "os.environ",
    {
        "AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_PATH_IN_EXTRA": "TRUE",
        "AIRFLOW__PROVIDERS_JDBC__ALLOW_DRIVER_CLASS_IN_EXTRA": "TRUE",
    },
)
def test_driver_extra_works_when_allow_driver_extra(self):
    """With both allow-options set via env vars, the extra-supplied values are returned."""
    connection_extra = '{"driver_path": "Blah driver path", "driver_class": "Blah driver class"}'
    hook = get_hook(conn_params={"extra": connection_extra})
    assert hook.driver_path == "Blah driver path"
    assert hook.driver_class == "Blah driver class"

def test_default_driver_set(self):
    """Class-level defaults are returned when no other driver setting is supplied."""
    class_defaults = {
        "default_driver_path": "Blah driver path",
        "default_driver_class": "Blah driver class",
    }
    # Patch both class attributes together; they are restored when the block exits.
    with patch.multiple(JdbcHook, **class_defaults):
        hook = get_hook()
        assert hook.driver_path == "Blah driver path"
        assert hook.driver_class == "Blah driver class"

def test_driver_none_by_default(self):
    """With nothing configured, both driver properties resolve to ``None``."""
    hook = get_hook()
    # Same order as before: path first, then class.
    for resolved in (hook.driver_path, hook.driver_class):
        assert resolved is None

def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog):
    """An ignored extra value triggers the warning and the class default is returned.

    Checked separately for ``driver_path`` and ``driver_class``, each with only its
    own class-level default patched.
    """
    # No ``logger=`` argument, so the level applies to the root logger. The name used
    # before ("airflow.providers.jdbc.hooks.test_jdbc") is not a logger the hook writes to.
    with patch.object(JdbcHook, "default_driver_path", "Blah driver path"), caplog.at_level(logging.WARNING):
        driver_path = get_hook(conn_params={"extra": '{"driver_path": "Blah driver path2"}'}).driver_path
        assert "have supplied 'driver_path' via connection extra but it will not be used" in caplog.text
        assert driver_path == "Blah driver path"

    with patch.object(JdbcHook, "default_driver_class", "Blah driver class"), caplog.at_level(logging.WARNING):
        driver_class = get_hook(conn_params={"extra": '{"driver_class": "Blah driver class2"}'}).driver_class
        assert "have supplied 'driver_class' via connection extra but it will not be used" in caplog.text
        assert driver_class == "Blah driver class"

0 comments on commit 0edbe91

Please sign in to comment.