Skip to content

Commit

Permalink
[DOP-13846] - add bypass test and note for rootTag
Browse files Browse the repository at this point in the history
  • Loading branch information
maxim-lixakov committed May 2, 2024
1 parent 89f691f commit d49cd7d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
8 changes: 6 additions & 2 deletions onetl/file/format/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column:
This method assumes that the ``spark-xml`` package is installed and properly configured within your Spark environment.
.. note::
This method does not support XML strings whose root tag is not the one specified as the ``rowTag``. If your XML data includes a root tag that encapsulates multiple row tags, be sure to preprocess the XML string to remove or ignore the root tag before parsing.
Parameters
----------
column : str | Column
Expand Down Expand Up @@ -294,9 +299,8 @@ def parse_column(self, column: str | Column, schema: StructType) -> Column:

java_column = _to_java_column(column)
java_schema = spark._jsparkSession.parseDataType(schema.json()) # noqa: WPS437
filtered_options = {k: v for k, v in self.dict().items() if k in self.Config.known_options}
scala_options = spark._jvm.org.apache.spark.api.python.PythonUtils.toScalaMap( # noqa: WPS219, WPS437
filtered_options,
self.dict(),
)
jc = spark._jvm.com.databricks.spark.xml.functions.from_xml( # noqa: WPS219, WPS437
java_column,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from onetl.file.format import XML

try:
from tests.util.assert_df import assert_equal_df
from pyspark.sql.functions import col

from tests.util.assert_df import assert_equal_df, assert_subset_df
except ImportError:
pytest.skip("Missing pandas", allow_module_level=True)
pytest.skip("Missing pandas or pyspark", allow_module_level=True)

pytestmark = [pytest.mark.local_fs, pytest.mark.file_df_connection, pytest.mark.connection, pytest.mark.xml]

Expand Down Expand Up @@ -168,13 +170,18 @@ def test_xml_reader_with_attributes(
assert_equal_df(read_df, expected_xml_attributes_df, order_by="id")


@pytest.mark.parametrize("column_type", [str, col])
def test_xml_parse_column(
spark,
local_fs_file_df_connection_with_path_and_files,
expected_xml_attributes_df,
file_df_dataframe,
file_df_schema,
column_type,
):
from pyspark.sql.functions import expr
from pyspark.sql.types import StructType

from onetl.file.format import XML

spark_version = get_spark_version(spark)
Expand All @@ -188,8 +195,26 @@ def test_xml_parse_column(
xml_data = file.read()

df = spark.createDataFrame([(xml_data,)], ["xml_string"])
df.show(truncate=False)
xml = XML.parse({"rowTag": "item", "rootTag": "root"})
parsed_df = df.select(xml.parse_column("xml_string", schema=file_df_schema))

parsed_df.show(truncate=False)
# remove the <root> tag from the XML string
df = df.withColumn("xml_string", expr("regexp_replace(xml_string, '^<root>|</root>$', '')"))

xml = XML(row_tag="item")
parsed_df = df.select(xml.parse_column(column_type("xml_string"), schema=file_df_schema))
assert isinstance(parsed_df.schema, StructType)
transformed_df = parsed_df.select(
"xml_string.id",
"xml_string.str_value",
"xml_string.int_value",
"xml_string.date_value",
"xml_string.datetime_value",
"xml_string.float_value",
)
expected_df_selected = expected_xml_attributes_df.select(
"id",
"str_value",
"int_value",
"date_value",
"datetime_value",
"float_value",
)
assert_subset_df(transformed_df, expected_df_selected)

0 comments on commit d49cd7d

Please sign in to comment.