1 change: 1 addition & 0 deletions docs/changelog/next_release/160.feature.rst
@@ -0,0 +1 @@
+Add `spark-dialect-extension <https://github.com/MobileTeleSystems/spark-dialect-extension/>`_
4 changes: 4 additions & 0 deletions syncmaster/worker/handlers/db/clickhouse.py
@@ -23,6 +23,10 @@ class ClickhouseHandler(DBHandler):
     transfer_dto: ClickhouseTransferDTO

     def connect(self, spark: SparkSession):
+        ClickhouseDialectRegistry = (
+            spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
+        )
+        ClickhouseDialectRegistry.register()
         self.connection = Clickhouse(
             host=self.connection_dto.host,
             port=self.connection_dto.port,
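Note: Spark JDBC dialects are registered on the JVM side, so the handler reaches through the py4j gateway (spark._jvm) before opening the connection. A minimal sketch of the same registration pattern in isolation, assuming the spark-dialect-extension JAR is already on the driver classpath (e.g. via spark.jars.packages) and that a SparkSession is active:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The registry class lives in the JVM; Python reaches it via the py4j gateway.
ClickhouseDialectRegistry = (
    spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
)
# Registers the custom ClickHouse dialect with Spark, replacing the default
# JDBC dialect behavior for ClickHouse connections.
ClickhouseDialectRegistry.register()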
6 changes: 4 additions & 2 deletions syncmaster/worker/spark.py
@@ -50,8 +50,10 @@ def get_packages(db_type: str) -> list[str]:
     if db_type == "oracle":
         return Oracle.get_packages()
     if db_type == "clickhouse":
-        # TODO: add https://github.com/MobileTeleSystems/spark-dialect-extension/ to spark jars
-        return Clickhouse.get_packages()
+        return [
+            "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
+            *Clickhouse.get_packages(),
+        ]
     if db_type == "mssql":
         return MSSQL.get_packages()
     if db_type == "mysql":
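The new list entry is a Maven coordinate in groupId:artifactId:version form; the _2.12 suffix pins the Scala binary version the JAR was built against. A hedged sketch of how such a package list is typically handed to Spark via spark.jars.packages; the session-building code below is illustrative, not the project's actual worker setup:

from pyspark.sql import SparkSession

packages = [
    "io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2",
    # ...plus the coordinates returned by Clickhouse.get_packages() (the JDBC driver).
]

spark = (
    SparkSession.builder
    # Spark resolves these Maven coordinates at startup and adds the JARs
    # to the driver and executor classpaths.
    .config("spark.jars.packages", ",".join(packages))
    .getOrCreate()
)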
6 changes: 6 additions & 0 deletions tests/test_integration/test_run_transfer/conftest.py
@@ -78,6 +78,7 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
         maven_packages.extend(Oracle.get_packages())

     if "clickhouse" in markers:
+        maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
         maven_packages.extend(Clickhouse.get_packages())

     if "mssql" in markers:
@@ -635,6 +636,11 @@ def prepare_clickhouse(
     clickhouse_for_conftest: ClickhouseConnectionDTO,
     spark: SparkSession,
 ):
+    ClickhouseDialectRegistry = (
+        spark._jvm.io.github.mtsongithub.doetl.sparkdialectextensions.clickhouse.ClickhouseDialectRegistry
+    )
+    ClickhouseDialectRegistry.register()
+
     clickhouse = clickhouse_for_conftest
     onetl_conn = Clickhouse(
         host=clickhouse.host,
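The test setup mirrors the worker in two places: the spark fixture pulls the dialect JAR only for tests marked clickhouse, and prepare_clickhouse registers the dialect before seeding data, so test tables are written the same way the worker writes them. A hedged sketch of the marker-driven package selection; the fixture plumbing below is illustrative, not the project's exact code:

import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def spark(request: pytest.FixtureRequest) -> SparkSession:
    # Inspect the markers on the requesting test to decide which JARs to pull.
    markers = {mark.name for mark in request.node.iter_markers()}
    maven_packages: list[str] = []
    if "clickhouse" in markers:
        # The dialect extension ships alongside the ClickHouse JDBC driver.
        maven_packages.append("io.github.mtsongithub.doetl:spark-dialect-extension_2.12:0.0.2")
    return (
        SparkSession.builder
        .config("spark.jars.packages", ",".join(maven_packages))
        .getOrCreate()
    )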
15 changes: 0 additions & 15 deletions tests/test_integration/test_run_transfer/test_clickhouse.py
@@ -6,7 +6,6 @@
 from onetl.connection import Clickhouse
 from onetl.db import DBReader
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col, date_trunc
 from sqlalchemy.ext.asyncio import AsyncSession

 from syncmaster.db.models import Connection, Group, Queue, Status, Transfer
@@ -117,8 +116,6 @@ async def test_run_transfer_postgres_to_clickhouse(
         table=f"{clickhouse.user}.target_table",
     )
     df = reader.run()
-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -169,11 +166,6 @@ async def test_run_transfer_postgres_to_clickhouse_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -222,8 +214,6 @@ async def test_run_transfer_clickhouse_to_postgres(
     )
     df = reader.run()

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
     for field in init_df.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

@@ -275,11 +265,6 @@ async def test_run_transfer_clickhouse_to_postgres_mixed_naming(
     assert df.columns != init_df_with_mixed_column_naming.columns
     assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]

-    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
-    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
-        "Registered At",
-        date_trunc("second", col("Registered At")),
-    )
     for field in init_df_with_mixed_column_naming.schema:
         df = df.withColumn(field.name, df[field.name].cast(field.dataType))

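All 15 deletions remove the same workaround: Spark's default ClickHouse JDBC dialect wrote timestamps with second precision (see the onetl docs linked in the removed comments), so the tests truncated the expected dataframe before comparing it with the data read back. With the dialect extension registered, sub-second precision survives the write and no normalization is needed. A self-contained illustration of what the removed lines did, assuming a local pyspark installation:

from datetime import datetime

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, date_trunc

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [(1, datetime(2024, 1, 1, 12, 0, 0, 123000))],
    ["ID", "REGISTERED_AT"],
)

# The removed workaround: drop sub-second precision from the expected data
# so it matched what the old dialect actually stored (12:00:00.123 -> 12:00:00).
truncated = df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
truncated.show(truncate=False)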