diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py index 7708684e60cf5..fb219fdebe25e 100644 --- a/airflow/providers/apache/druid/hooks/druid.py +++ b/airflow/providers/apache/druid/hooks/druid.py @@ -156,6 +156,10 @@ class DruidDbApiHook(DbApiHook): This hook is purely for users to query druid broker. For ingestion, please use druidHook. + + :param context: Optional query context parameters to pass to the SQL endpoint. + Example: ``{"sqlFinalizeOuterSketches": True}`` + See: https://druid.apache.org/docs/latest/querying/sql-query-context/ """ conn_name_attr = "druid_broker_conn_id" @@ -164,6 +168,10 @@ class DruidDbApiHook(DbApiHook): hook_name = "Druid" supports_autocommit = False + def __init__(self, context: dict | None = None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.context = context or {} + def get_conn(self) -> connect: """Establish a connection to druid broker.""" conn = self.get_connection(getattr(self, self.conn_name_attr)) @@ -174,6 +182,7 @@ def get_conn(self) -> connect: scheme=conn.extra_dejson.get("schema", "http"), user=conn.login, password=conn.password, + context=self.context, ) self.log.info("Get the connection to druid broker on %s using user %s", conn.host, conn.login) return druid_broker_conn diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py index 38f26f06cc380..04e68f765af7d 100644 --- a/tests/providers/apache/druid/hooks/test_druid.py +++ b/tests/providers/apache/druid/hooks/test_druid.py @@ -233,6 +233,38 @@ def get_connection(self, conn_id): self.db_hook = TestDruidDBApiHook + @patch("airflow.providers.apache.druid.hooks.druid.DruidDbApiHook.get_connection") + @patch("airflow.providers.apache.druid.hooks.druid.connect") + @pytest.mark.parametrize( + ("specified_context", "passed_context"), + [ + (None, {}), + ({"query_origin": "airflow"}, {"query_origin": "airflow"}), + ], + ) + def test_get_conn_with_context( + self, mock_connect, mock_get_connection, specified_context, passed_context + ): + get_conn_value = MagicMock() + get_conn_value.host = "test_host" + get_conn_value.conn_type = "https" + get_conn_value.login = "test_login" + get_conn_value.password = "test_password" + get_conn_value.port = 10000 + get_conn_value.extra_dejson = {"endpoint": "/test/endpoint", "schema": "https"} + mock_get_connection.return_value = get_conn_value + hook = DruidDbApiHook(context=specified_context) + hook.get_conn() + mock_connect.assert_called_with( + host="test_host", + port=10000, + path="/test/endpoint", + scheme="https", + user="test_login", + password="test_password", + context=passed_context, + ) + def test_get_uri(self): db_hook = self.db_hook() assert "druid://host:1000/druid/v2/sql" == db_hook.get_uri()