Skip to content

Commit

Permalink
Validate host and schema for Spark JDBC Hook (#30223)
Browse files Browse the repository at this point in the history
The host and schema of a JDBC Hook connection should not contain '/' or '?',
as those characters delimit the end of those fields in the JDBC URL.
  • Loading branch information
potiuk committed Mar 22, 2023
1 parent 1290a17 commit d9dea5c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
4 changes: 4 additions & 0 deletions airflow/providers/apache/spark/hooks/spark_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ def _resolve_jdbc_connection(self) -> dict[str, Any]:
conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""}
try:
conn = self.get_connection(self._jdbc_conn_id)
if "/" in conn.host:
raise ValueError("The jdbc host should not contain a '/'")
if "?" in conn.schema:
raise ValueError("The jdbc schema should not contain a '?'")
if conn.port:
conn_data["url"] = f"{conn.host}:{conn.port}"
else:
Expand Down
34 changes: 34 additions & 0 deletions tests/providers/apache/spark/hooks/test_spark_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

import pytest

from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook
from airflow.utils import db
Expand Down Expand Up @@ -82,6 +84,30 @@ def setup_method(self):
extra='{"conn_prefix":"jdbc:postgresql://"}',
)
)
# Fixture connection with an invalid host: the '/' in "localhost/test" must
# make SparkJDBCHook raise ValueError when the connection is resolved.
db.merge_conn(
Connection(
conn_id="jdbc-invalid-host",
conn_type="postgres",
host="localhost/test",
schema="default",
port=5432,
login="user",
password="supersecret",
extra='{"conn_prefix":"jdbc:postgresql://"}',
)
)
# Fixture connection with an invalid schema: the '?' in "default?test=" must
# make SparkJDBCHook raise ValueError when the connection is resolved.
db.merge_conn(
Connection(
conn_id="jdbc-invalid-schema",
conn_type="postgres",
host="localhost",
schema="default?test=",
port=5432,
login="user",
password="supersecret",
extra='{"conn_prefix":"jdbc:postgresql://"}',
)
)

def test_resolve_jdbc_connection(self):
# Given
Expand Down Expand Up @@ -150,3 +176,11 @@ def test_build_jdbc_arguments_invalid(self):

# Expect Exception
hook._build_jdbc_application_arguments(hook._resolve_jdbc_connection())

def test_invalid_host(self):
    """A connection whose host contains '/' must be rejected with ValueError."""
    expected_message = "host should not contain a"
    with pytest.raises(ValueError, match=expected_message):
        SparkJDBCHook(jdbc_conn_id="jdbc-invalid-host", **self._config)

def test_invalid_schema(self):
    """A connection whose schema contains '?' must be rejected with ValueError."""
    expected_message = "schema should not contain a"
    with pytest.raises(ValueError, match=expected_message):
        SparkJDBCHook(jdbc_conn_id="jdbc-invalid-schema", **self._config)

0 comments on commit d9dea5c

Please sign in to comment.