[DOP-15564] Avoid urlencoding JDBC params
dolfinus committed Apr 26, 2024
1 parent 96ce940 commit 95e6069
Showing 24 changed files with 395 additions and 182 deletions.
2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -31,6 +31,8 @@ services:
       - 5433:5432
     networks:
       - onetl
+    sysctls:
+      - net.ipv6.conf.all.disable_ipv6=1
 
   clickhouse:
     image: ${CLICKHOUSE_IMAGE:-clickhouse/clickhouse-server:latest-alpine}
1 change: 1 addition & 0 deletions docs/changelog/next_release/268.feature.rst
@@ -0,0 +1 @@
+Allow passing JDBC connection extra params without urlencode.
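
To illustrate the user-facing effect with a hypothetical Clickhouse connection (host, credentials and the socket_timeout value below are invented): extras are no longer serialized into the JDBC URL, where special characters would need urlencoding, but are handed to Spark as separate JDBC options through the jdbc_params property introduced in the files below.

# Illustrative sketch only; assumes onetl is installed and the Clickhouse JDBC driver
# is available to the Spark session.
from pyspark.sql import SparkSession
from onetl.connection import Clickhouse

spark = SparkSession.builder.getOrCreate()

clickhouse = Clickhouse(
    host="somehost",
    port=8123,
    user="someuser",
    password="***",
    database="somedb",
    extra={"socket_timeout": "120000"},  # any driver-specific key/value pair
    spark=spark,
)

# Before this commit the extra ended up inside the URL:
#   jdbc:clickhouse://somehost:8123/somedb?socket_timeout=120000
# After it the URL stays clean and the extra travels as a plain Spark option:
#   clickhouse.jdbc_url    -> "jdbc:clickhouse://somehost:8123/somedb"
#   clickhouse.jdbc_params -> {"user": "someuser", "password": "***",
#                              "driver": Clickhouse.DRIVER, "url": clickhouse.jdbc_url,
#                              "socket_timeout": "120000"}
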
13 changes: 8 additions & 5 deletions onetl/connection/db_connection/clickhouse/connection.py
@@ -162,13 +162,16 @@ def package(self) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        extra = self.extra.dict(by_alias=True)
-        parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items()))
-
         if self.database:
-            return f"jdbc:clickhouse://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?")
+            return f"jdbc:clickhouse://{self.host}:{self.port}/{self.database}"
 
-        return f"jdbc:clickhouse://{self.host}:{self.port}?{parameters}".rstrip("?")
+        return f"jdbc:clickhouse://{self.host}:{self.port}"
+
+    @property
+    def jdbc_params(self) -> dict:
+        result = super().jdbc_params
+        result.update(self.extra.dict(by_alias=True))
+        return result
 
     @staticmethod
     def _build_statement(
21 changes: 13 additions & 8 deletions onetl/connection/db_connection/greenplum/connection.py
@@ -250,15 +250,20 @@ def instance_url(self) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        extra = {
-            key: value
-            for key, value in self.extra.dict(by_alias=True).items()
-            if not (key.startswith("server.") or key.startswith("pool."))
-        }
-        extra["ApplicationName"] = extra.get("ApplicationName", self.spark.sparkContext.appName)
+        return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
 
-        parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items()))
-        return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?")
+    @property
+    def jdbc_params(self) -> dict:
+        result = super().jdbc_params
+        result.update(
+            {
+                key: value
+                for key, value in self.extra.dict(by_alias=True).items()
+                if not (key.startswith("server.") or key.startswith("pool."))
+            },
+        )
+        result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName)
+        return result
 
     @slot
     def read_source_as_df(
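
A standalone sketch of the filtering above (keys and values are hypothetical): options addressed to the Greenplum connector itself (server.* and pool.*) are kept out of the JDBC params, everything else is passed through, and ApplicationName falls back to the Spark application name.

# Pure-Python illustration of the key filter used by Greenplum.jdbc_params.
extra = {"server.port": "49152", "pool.maxSize": "10", "tcpKeepAlive": "true"}

jdbc_extra = {
    key: value
    for key, value in extra.items()
    if not (key.startswith("server.") or key.startswith("pool."))
}

assert jdbc_extra == {"tcpKeepAlive": "true"}
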
40 changes: 4 additions & 36 deletions onetl/connection/db_connection/jdbc_connection/connection.py
@@ -165,15 +165,15 @@ def write_df_to_target(
         options: JDBCWriteOptions | None = None,
     ) -> None:
         write_options = self.WriteOptions.parse(options)
-        jdbc_params = self.options_to_jdbc_params(write_options)
+        jdbc_properties = self._get_jdbc_properties(write_options, exclude={"if_exists"}, exclude_none=True)
 
         mode = (
             "overwrite"
             if write_options.if_exists == JDBCTableExistBehavior.REPLACE_ENTIRE_TABLE
             else write_options.if_exists.value
         )
         log.info("|%s| Saving data to a table %r", self.__class__.__name__, target)
-        df.write.jdbc(table=target, mode=mode, **jdbc_params)
+        df.write.format("jdbc").mode(mode).options(dbtable=target, **jdbc_properties).save()
         log.info("|%s| Table %r successfully written", self.__class__.__name__, target)
 
     @slot
@@ -196,38 +196,6 @@ def get_df_schema(
 
         return df.schema
 
-    def options_to_jdbc_params(
-        self,
-        options: JDBCReadOptions | JDBCWriteOptions,
-    ) -> dict:
-        # Have to replace the <partitionColumn> parameter with <column>
-        # since the method takes the named <column> parameter
-        # link to source below
-        # https://github.com/apache/spark/blob/2ef8ced27a6b0170a691722a855d3886e079f037/python/pyspark/sql/readwriter.py#L465
-
-        partition_column = getattr(options, "partition_column", None)
-        if partition_column:
-            options = options.copy(
-                update={"column": partition_column},
-                exclude={"partition_column"},
-            )
-
-        result = self._get_jdbc_properties(
-            options,
-            include=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS,
-            exclude={"if_exists"},
-            exclude_none=True,
-        )
-
-        result["properties"] = self._get_jdbc_properties(
-            options,
-            exclude=READ_TOP_LEVEL_OPTIONS | WRITE_TOP_LEVEL_OPTIONS | {"if_exists"},
-            exclude_none=True,
-        )
-
-        result["properties"].pop("partitioningMode", None)
-        return result
-
     @slot
     def get_min_max_values(
         self,
@@ -275,8 +243,8 @@ def _query_on_executor(
         query: str,
         options: JDBCReadOptions,
     ) -> DataFrame:
-        jdbc_params = self.options_to_jdbc_params(options)
-        return self.spark.read.jdbc(table=f"({query}) T", **jdbc_params)
+        jdbc_properties = self._get_jdbc_properties(options, exclude={"partitioning_mode"}, exclude_none=True)
+        return self.spark.read.format("jdbc").options(dbtable=f"({query}) T", **jdbc_properties).load()
 
     def _exclude_partition_options(
         self,
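
For context, a sketch of the two PySpark writer APIs involved (connection values are placeholders, not taken from the diff). DataFrameWriter.jdbc() needs a ready-made URL plus a nested properties dict, which is what the removed options_to_jdbc_params() used to assemble; the format("jdbc") path takes one flat options dict, so extra params never have to be folded, and urlencoded, into the URL string.

# Both snippets are illustrative; df is assumed to be an existing DataFrame.

# Old style: URL plus nested properties dict, extras squeezed into the URL string.
df.write.jdbc(
    url="jdbc:postgresql://somehost:5432/somedb?ApplicationName=abc",
    table="target_schema.target_table",
    mode="append",
    properties={"user": "someuser", "password": "***", "driver": "org.postgresql.Driver"},
)

# New style: everything is a flat option of the "jdbc" data source.
df.write.format("jdbc").mode("append").options(
    url="jdbc:postgresql://somehost:5432/somedb",
    dbtable="target_schema.target_table",
    user="someuser",
    password="***",
    driver="org.postgresql.Driver",
    ApplicationName="abc",
).save()
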
27 changes: 14 additions & 13 deletions onetl/connection/db_connection/jdbc_mixin/connection.py
@@ -76,6 +76,16 @@ class JDBCMixin(FrozenModel):
     def jdbc_url(self) -> str:
         """JDBC Connection URL"""
 
+    @property
+    def jdbc_params(self) -> dict:
+        """JDBC Connection params"""
+        return {
+            "user": self.user,
+            "password": self.password.get_secret_value() if self.password is not None else "",
+            "driver": self.DRIVER,
+            "url": self.jdbc_url,
+        }
+
     @slot
     def close(self):
         """
@@ -312,20 +322,12 @@ def _get_jdbc_properties(
         self,
         options: JDBCMixinOptions,
         **kwargs,
-    ) -> dict:
+    ) -> dict[str, str]:
         """
         Fills up human-readable Options class to a format required by Spark internal methods
         """
-
-        result = options.copy(
-            update={
-                "user": self.user,
-                "password": self.password.get_secret_value() if self.password is not None else "",
-                "driver": self.DRIVER,
-                "url": self.jdbc_url,
-            },
-        ).dict(by_alias=True, **kwargs)
-
+        result = self.jdbc_params
+        result.update(options.dict(by_alias=True, **kwargs))
         return stringify(result)
 
     def _options_to_connection_properties(self, options: JDBCMixinOptions):
@@ -339,8 +341,7 @@ def _options_to_connection_properties(self, options: JDBCMixinOptions):
         * https://github.com/apache/spark/blob/v2.3.0/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala#L248-L255
         """
 
-        jdbc_properties = self._get_jdbc_properties(options, exclude_unset=True)
-
+        jdbc_properties = self._get_jdbc_properties(options, exclude_none=True)
         jdbc_utils_package = self.spark._jvm.org.apache.spark.sql.execution.datasources.jdbc  # type: ignore
         jdbc_options = jdbc_utils_package.JDBCOptions(
             self.jdbc_url,
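
Roughly, the new pieces compose as sketched below (plain Python with hypothetical, Postgres-flavoured values): jdbc_params supplies the credential/driver/url base, per-call options are layered on top, and stringify(), an existing onetl helper, presumably coerces every value to str, matching the tightened dict[str, str] return type.

# Plain-Python sketch of the merge performed by _get_jdbc_properties.
jdbc_params = {
    "user": "someuser",
    "password": "***",
    "driver": "org.postgresql.Driver",                # cls.DRIVER of the concrete connection
    "url": "jdbc:postgresql://somehost:5432/somedb",  # self.jdbc_url
}

call_options = {"fetchsize": 100_000, "queryTimeout": 30}  # options.dict(by_alias=True, exclude_none=True)

result = dict(jdbc_params)
result.update(call_options)
result = {key: str(value) for key, value in result.items()}  # stand-in for stringify()

# result is now one flat dict[str, str], ready to be passed to Spark's "jdbc" data source.
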
11 changes: 7 additions & 4 deletions onetl/connection/db_connection/mssql/connection.py
@@ -200,11 +200,14 @@ def package(cls) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        prop = self.extra.dict(by_alias=True)
-        prop["databaseName"] = self.database
-        parameters = ";".join(f"{k}={v}" for k, v in sorted(prop.items()))
+        return f"jdbc:sqlserver://{self.host}:{self.port}"
 
-        return f"jdbc:sqlserver://{self.host}:{self.port};{parameters}"
+    @property
+    def jdbc_params(self) -> dict:
+        result = super().jdbc_params
+        result.update(self.extra.dict(by_alias=True))
+        result["databaseName"] = self.database
+        return result
 
     @property
     def instance_url(self) -> str:
15 changes: 9 additions & 6 deletions onetl/connection/db_connection/mysql/connection.py
@@ -138,11 +138,14 @@ def package(cls) -> str:
         return "com.mysql:mysql-connector-j:8.3.0"
 
     @property
-    def jdbc_url(self):
-        prop = self.extra.dict(by_alias=True)
-        parameters = "&".join(f"{k}={v}" for k, v in sorted(prop.items()))
-
+    def jdbc_url(self) -> str:
         if self.database:
-            return f"jdbc:mysql://{self.host}:{self.port}/{self.database}?{parameters}"
+            return f"jdbc:mysql://{self.host}:{self.port}/{self.database}"
 
-        return f"jdbc:mysql://{self.host}:{self.port}?{parameters}"
+        return f"jdbc:mysql://{self.host}:{self.port}"
+
+    @property
+    def jdbc_params(self) -> dict:
+        result = super().jdbc_params
+        result.update(self.extra.dict(by_alias=True))
+        return result
13 changes: 8 additions & 5 deletions onetl/connection/db_connection/oracle/connection.py
@@ -226,13 +226,16 @@ def package(cls) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        extra = self.extra.dict(by_alias=True)
-        parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items()))
-
         if self.sid:
-            return f"jdbc:oracle:thin:@{self.host}:{self.port}:{self.sid}?{parameters}".rstrip("?")
+            return f"jdbc:oracle:thin:@{self.host}:{self.port}:{self.sid}"
 
-        return f"jdbc:oracle:thin:@//{self.host}:{self.port}/{self.service_name}?{parameters}".rstrip("?")
+        return f"jdbc:oracle:thin:@//{self.host}:{self.port}/{self.service_name}"
+
+    @property
+    def jdbc_params(self) -> dict:
+        result = super().jdbc_params
+        result.update(self.extra.dict(by_alias=True))
+        return result
 
     @property
     def instance_url(self) -> str:
15 changes: 11 additions & 4 deletions onetl/connection/db_connection/postgres/connection.py
@@ -20,6 +20,10 @@ class PostgresExtra(GenericOptions):
     # allows automatic conversion from text to target column type during write
     stringtype: str = "unspecified"
 
+    # avoid closing connections from server side
+    # while connector is moving data to executors before insert
+    tcpKeepAlive: str = "true"  # noqa: N815
+
     class Config:
         extra = "allow"
 
@@ -142,11 +146,14 @@ def package(cls) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        extra = self.extra.dict(by_alias=True)
-        extra["ApplicationName"] = extra.get("ApplicationName", self.spark.sparkContext.appName)
+        return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}"
 
-        parameters = "&".join(f"{k}={v}" for k, v in sorted(extra.items()))
-        return f"jdbc:postgresql://{self.host}:{self.port}/{self.database}?{parameters}".rstrip("?")
+    @property
+    def jdbc_params(self) -> dict[str, str]:
+        result = super().jdbc_params
+        result.update(self.extra.dict(by_alias=True))
+        result["ApplicationName"] = result.get("ApplicationName", self.spark.sparkContext.appName)
+        return result
 
     @property
     def instance_url(self) -> str:
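
Assuming the defaults above, a minimal Postgres connection would expose roughly the following (a sketch: host and credentials are placeholders, and 5432 is the usual default port).

# Hypothetical example; assumes an existing SparkSession named spark
# with the Postgres JDBC driver available.
from onetl.connection import Postgres

postgres = Postgres(
    host="somehost",
    user="someuser",
    password="***",
    database="somedb",
    spark=spark,
)

# postgres.jdbc_url    -> "jdbc:postgresql://somehost:5432/somedb"
# postgres.jdbc_params -> {"user": "someuser", "password": "***",
#                          "driver": Postgres.DRIVER,
#                          "url": "jdbc:postgresql://somehost:5432/somedb",
#                          "stringtype": "unspecified", "tcpKeepAlive": "true",
#                          "ApplicationName": spark.sparkContext.appName}
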
21 changes: 16 additions & 5 deletions onetl/connection/db_connection/teradata/connection.py
@@ -5,6 +5,7 @@
 import warnings
 from typing import ClassVar, Optional
 
+from onetl._internal import stringify
 from onetl._util.classproperty import classproperty
 from onetl._util.version import Version
 from onetl.connection.db_connection.jdbc_connection import JDBCConnection
@@ -162,12 +163,22 @@ def package(cls) -> str:
 
     @property
     def jdbc_url(self) -> str:
-        prop = self.extra.dict(by_alias=True)
+        # Teradata JDBC driver documentation specifically mentions that params from
+        # java.sql.DriverManager.getConnection(url, params) are used to only retrieve 'user' and 'password' values.
+        # Other params should be passed via url
+        properties = self.extra.dict(by_alias=True)
 
         if self.database:
-            prop["DATABASE"] = self.database
+            properties["DATABASE"] = self.database
 
-        prop["DBS_PORT"] = self.port
+        properties["DBS_PORT"] = self.port
 
-        conn = ",".join(f"{k}={v}" for k, v in sorted(prop.items()))
-        return f"jdbc:teradata://{self.host}/{conn}"
+        connection_params = []
+        for key, value in sorted(properties.items()):
+            string_value = stringify(value)
+            if "," in string_value:
+                connection_params.append(f"{key}='{string_value}'")
+            else:
+                connection_params.append(f"{key}={string_value}")
+
+        return f"jdbc:teradata://{self.host}/{','.join(connection_params)}"
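
A hedged illustration of the URL builder above (host and extras are invented; TMODE and LOGMECH are common Teradata driver options): parameters are sorted, a value containing a comma gets wrapped in single quotes, and everything else is appended verbatim.

# Standalone rerun of the loop above with made-up values.
properties = {"TMODE": "TERA", "LOGMECH": "LDAP", "DATABASE": "somedb", "DBS_PORT": 1025}

connection_params = []
for key, value in sorted(properties.items()):
    string_value = str(value)  # stand-in for onetl._internal.stringify
    if "," in string_value:
        connection_params.append(f"{key}='{string_value}'")
    else:
        connection_params.append(f"{key}={string_value}")

print(f"jdbc:teradata://somehost/{','.join(connection_params)}")
# jdbc:teradata://somehost/DATABASE=somedb,DBS_PORT=1025,LOGMECH=LDAP,TMODE=TERA
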
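
The integration tests below (for the Clickhouse, Greenplum, MSSQL and MySQL connections) all follow the same idea: each passes a value through extra that is bound to make the connection attempt fail in the test environment, so check() can only raise "Connection is unavailable" if the extra actually reached the JDBC driver, which is what the new jdbc_params path is meant to guarantee.
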
@@ -46,6 +46,21 @@ def test_clickhouse_connection_check_fail(spark):
         clickhouse.check()
 
 
+def test_clickhouse_connection_check_extra_is_passed(spark, processing):
+    clickhouse = Clickhouse(
+        host=processing.host,
+        port=processing.port,
+        user=processing.user,
+        password=processing.password,
+        database=processing.database,
+        spark=spark,
+        extra={"socket_timeout": "fail"},
+    )
+
+    with pytest.raises(RuntimeError, match="Connection is unavailable"):
+        clickhouse.check()
+
+
 @pytest.mark.parametrize("suffix", ["", ";"])
 def test_clickhouse_connection_sql(spark, processing, load_table_data, suffix):
     clickhouse = Clickhouse(
@@ -48,6 +48,21 @@ def test_greenplum_connection_check_fail(spark):
         greenplum.check()
 
 
+def test_greenplum_connection_check_extra_is_passed(spark, processing):
+    greenplum = Greenplum(
+        host=processing.host,
+        port=processing.port,
+        user=processing.user,
+        password=processing.password,
+        database=processing.database,
+        spark=spark,
+        extra={**processing.extra, "connectTimeout": "fail"},
+    )
+
+    with pytest.raises(RuntimeError, match="Connection is unavailable"):
+        greenplum.check()
+
+
 @pytest.mark.parametrize("suffix", ["", ";"])
 def test_greenplum_connection_fetch(spark, processing, load_table_data, suffix):
     greenplum = Greenplum(
@@ -55,6 +55,21 @@ def test_mssql_connection_check_fail(spark):
         mssql.check()
 
 
+def test_mssql_connection_check_check_extra_is_passed(spark, processing):
+    mssql = MSSQL(
+        host=processing.host,
+        port=processing.port,
+        user=processing.user,
+        password=processing.password,
+        database=processing.database,
+        spark=spark,
+        extra={"trustServerCertificate": "false"},
+    )
+
+    with pytest.raises(RuntimeError, match="Connection is unavailable"):
+        mssql.check()
+
+
 @pytest.mark.parametrize("suffix", ["", ";"])
 def test_mssql_connection_sql(spark, processing, load_table_data, suffix):
     mssql = MSSQL(
@@ -46,6 +46,21 @@ def test_mysql_connection_check_fail(spark):
         mysql.check()
 
 
+def test_mysql_connection_check_check_extra_is_passed(spark, processing):
+    mysql = MySQL(
+        host=processing.host,
+        port=processing.port,
+        user=processing.user,
+        password=processing.password,
+        database=processing.database,
+        spark=spark,
+        extra={"tcpKeepAlive": "fail"},
+    )
+
+    with pytest.raises(RuntimeError, match="Connection is unavailable"):
+        mysql.check()
+
+
 @pytest.mark.parametrize("suffix", ["", ";"])
 def test_mysql_connection_sql(spark, processing, load_table_data, suffix):
     mysql = MySQL(