[DOP-13845] - add _get_schema_json
maxim-lixakov committed Apr 26, 2024
1 parent c14ecf9 commit e1cf26f
Showing 2 changed files with 54 additions and 38 deletions.
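In short, this commit deduplicates the schema-resolution logic of `Avro.parse_column` and `Avro.serialize_column` into a single `_get_schema_json` helper, makes `requests` an optional dependency (imported lazily, only when `schema_url` is used), and adds an explicit `_check_spark_version_for_serialization` guard for the Spark 3.x requirement. Along the way it removes the old `self.self.schema_url` typo in `serialize_column`, which also assigned the raw `requests` Response object instead of its text.

For orientation, a minimal usage sketch of the affected API — the DataFrame and Avro schema below are hypothetical, assuming a running Spark 3.x session; this code is not part of the commit:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, struct
    from onetl.file.format import Avro

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("Alice", 1)], ["name", "id"])

    avro_schema = {
        "type": "record",
        "name": "Person",
        "fields": [
            {"name": "name", "type": "string"},
            {"name": "id", "type": "long"},
        ],
    }
    avro = Avro(schema_dict=avro_schema)

    # pack all columns into one struct column, then round-trip through Avro binary
    combined = df.withColumn("combined", struct([col(c) for c in df.columns]))
    serialized = combined.select(avro.serialize_column("combined"))
    parsed = serialized.select(avro.parse_column("combined"))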
57 changes: 38 additions & 19 deletions onetl/file/format/avro.py
@@ -6,8 +6,6 @@
 import logging
 from typing import TYPE_CHECKING, ClassVar, Dict, Optional
 
-import requests
-
 try:
     from pydantic.v1 import Field, validator
 except (ImportError, AttributeError):
@@ -236,22 +234,20 @@ def parse_column(self, column: str | Column) -> Column:
         from pyspark.sql.avro.functions import from_avro
         from pyspark.sql.functions import col
 
-        self.check_if_supported(SparkSession.getActiveSession())
+        spark = SparkSession.getActiveSession()
+        self.check_if_supported(spark)
+        self._check_spark_version_for_serialization(spark)
 
         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
         else:
             column_name, column = column, col(column).cast("binary")
 
-        if self.schema_dict:
-            schema_json = json.dumps(self.schema_dict)
-        elif self.schema_url:
-            response = requests.get(self.schema_url)  # noqa: S113
-            schema_json = response.text
-        else:
-            raise ValueError("No schema defined in Avro class instance.")
+        schema = self._get_schema_json()
+        if not schema:
+            raise ValueError("Avro.parse_column can be used only with defined `schema_dict` or `schema_url`")
 
-        return from_avro(column, schema_json).alias(column_name)
+        return from_avro(column, schema).alias(column_name)
 
     def serialize_column(self, column: str | Column) -> Column:
         """
@@ -298,24 +294,47 @@ def serialize_column(self, column: str | Column) -> Column:
         from pyspark.sql.avro.functions import to_avro
         from pyspark.sql.functions import col
 
-        self.check_if_supported(SparkSession._instantiatedSession)  # noqa: WPS437
+        spark = SparkSession.getActiveSession()
+        self.check_if_supported(spark)
+        self._check_spark_version_for_serialization(spark)
 
         if isinstance(column, Column):
             column_name = column._jc.toString()  # noqa: WPS437
         else:
             column_name, column = column, col(column)
 
-        if self.schema_dict:
-            schema = json.dumps(self.schema_dict)
-        elif self.self.schema_url:
-            schema = requests.get(self.self.schema_url)  # noqa: S113
-        else:
-            schema = ""
-
+        schema = self._get_schema_json()
         return to_avro(column, schema).alias(column_name)
 
     @validator("schema_dict", pre=True)
     def _parse_schema_from_json(cls, value):
         if isinstance(value, (str, bytes)):
             return json.loads(value)
         return value
+
+    def _check_spark_version_for_serialization(self, spark: SparkSession):
+        spark_version = get_spark_version(spark)
+        if spark_version.major < 3:
+            class_name = self.__class__.__name__
+            error_msg = (
+                f"`{class_name}.parse_column` and `{class_name}.serialize_column` are available "
+                f"only with Spark 3.x and above, but the current Spark version is {spark_version}."
+            )
+            raise ValueError(error_msg)
+
+    def _get_schema_json(self) -> str:
+        if self.schema_dict:
+            return json.dumps(self.schema_dict)
+        elif self.schema_url:
+            try:
+                import requests
+
+                response = requests.get(self.schema_url)  # noqa: S113
+                return response.text
+            except ImportError:
+                raise ImportError(
+                    "The 'requests' library is required to use 'schema_url' but is not installed. "
+                    "Install it with 'pip install requests' or avoid using 'schema_url'.",
+                )
+        else:
+            return ""
35 changes: 16 additions & 19 deletions
@@ -4,6 +4,8 @@
 Do not test all the possible options and combinations, we are not testing Spark here.
 """
 
+import contextlib
+
 import pytest
 
 from onetl._util.spark import get_spark_version
@@ -124,39 +126,34 @@ def test_avro_writer(
     assert_equal_df(read_df, df, order_by="id")
 
 
-@pytest.mark.parametrize(
-    "path, options",
-    [
-        ("without_compression", {}),
-        ("with_compression", {"compression": "snappy"}),
-    ],
-    ids=["without_compression", "with_compression"],
-)
 @pytest.mark.parametrize("column_type", [str, col])
 def test_avro_serialize_and_parse_column(
     spark,
     local_fs_file_df_connection_with_path,
     file_df_dataframe,
-    path,
     avro_schema,
-    options,
     column_type,
 ):
     from pyspark.sql.functions import struct
     from pyspark.sql.types import BinaryType
 
     spark_version = get_spark_version(spark)
-    if spark_version < Version("2.4"):
-        pytest.skip("Avro from_avro, to_avro are supported on Spark 3.x+ only")
+    if spark_version.major < 3:
+        msg = (
+            f"`Avro.parse_column` and `Avro.serialize_column` are available "
+            f"only with Spark 3.x and above, but the current Spark version is {spark_version}."
+        )
+        context_manager = pytest.raises(ValueError, match=msg)
+    else:
+        context_manager = contextlib.nullcontext()
 
     df = file_df_dataframe
-    avro = Avro(schema_dict=avro_schema, **options)
+    avro = Avro(schema_dict=avro_schema)
 
     combined_df = df.withColumn("combined", struct([col(c) for c in df.columns]))
-    serialized_df = combined_df.select(avro.serialize_column(column_type("combined")))
-
-    assert isinstance(serialized_df.schema["combined"].dataType, BinaryType)
-
-    parsed_df = serialized_df.select(avro.parse_column(column_type("combined")))
-
-    assert combined_df.select(column_type("combined")).collect() == parsed_df.collect()
+    with context_manager:
+        serialized_df = combined_df.select(avro.serialize_column(column_type("combined")))
+        assert isinstance(serialized_df.schema["combined"].dataType, BinaryType)
+        parsed_df = serialized_df.select(avro.parse_column(column_type("combined")))
+        assert combined_df.select("combined").collect() == parsed_df.collect()
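The conditional expectation above — `pytest.raises(...)` on old Spark, `contextlib.nullcontext()` otherwise — is a standard pytest idiom for covering both failing and passing configurations in one parametrized test. A self-contained sketch of the pattern, with a toy function and names that are not from this repository:

    import contextlib

    import pytest

    def parse_positive(value: int) -> int:
        if value < 0:
            raise ValueError("value must be non-negative")
        return value

    @pytest.mark.parametrize("value", [-1, 5])
    def test_parse_positive(value):
        # expect an error for the invalid input, a no-op context for the valid one
        ctx = pytest.raises(ValueError, match="non-negative") if value < 0 else contextlib.nullcontext()
        with ctx:
            assert parse_positive(value) == value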
