diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 25fb8a7..3595716 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -31,6 +31,7 @@ duckdb_read_parquet, duckdb_record_index, duckdb_write_parquet, + get_duckdb_cast_statement_from_annotation, get_duckdb_type_from_annotation, relation_is_empty, ) @@ -101,18 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument _lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable return self._connection.sql("select * from _lazy_df") - @staticmethod - def generate_ddb_cast_statement( - column_name: str, dtype: DuckDBPyType, null_flag: bool = False - ) -> str: - """Helper method to generate sql statements for casting datatypes (permissively). - Current duckdb python API doesn't play well with this currently. - """ - if not null_flag: - return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"' - return f'cast(NULL AS {dtype}) AS "{column_name}"' - - # pylint: disable=R0914 + # pylint: disable=R0914,R0915 def apply_data_contract( self, working_dir: URI, @@ -180,12 +170,16 @@ def apply_data_contract( casting_statements = [ ( - self.generate_ddb_cast_statement(column, dtype) + get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + + f""" AS "{column}" """ if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + else f"CAST(NULL AS {ddb_schema[column]}) AS {column}" ) - for column, dtype in ddb_schema.items() + for column, mdl_fld in entity_fields.items() ] + casting_statements.append( + f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301 + ) try: relation = relation.project(", ".join(casting_statements)) except Exception as err: # pylint: disable=broad-except diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index f5b0fe9..394cd01 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -313,3 +313,108 @@ def duckdb_record_index(cls): setattr(cls, "add_record_index", _add_duckdb_record_index) setattr(cls, "drop_record_index", _drop_duckdb_record_index) return cls + + +def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str: + """Cast to Duck DB type""" + return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})""" + + +def _ddb_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" + try: + sep_idx = field_name.index(".") + return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:] + except ValueError: + return f'"{field_name}"' + + +# pylint: disable=R0801,R0911,R0912 +def get_duckdb_cast_statement_from_annotation( + element_name: str, + type_annotation: Any, + parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 + time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", +) -> str: + """Generate casting statements for duckdb relations from type annotations""" + type_origin = get_origin(type_annotation) + + quoted_name = _ddb_safely_quote_name(element_name) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_duckdb_cast_statement_from_annotation( + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()]) + stmt = f"struct_pack({cast_exprs})" + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) + + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first + if issubclass(type_, datetime): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 + return stmt + if issubclass(type_, date): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 + return stmt + if issubclass(type_, time): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301 + return stmt + duck_type = get_duckdb_type_from_annotation(type_) + if duck_type: + stmt = f"trim({quoted_name})" + return _cast_as_ddb_type(stmt, type_) if parent_element else stmt + raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 07a4a04..ced985a 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -439,3 +439,103 @@ def spark_record_index(cls): setattr(cls, "add_record_index", _add_spark_record_index) setattr(cls, "drop_record_index", _drop_spark_record_index) return cls + + +def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column: + """Cast to spark type""" + return sf.expr(field_expr).cast(get_type_from_annotation(field_type)) + + +def _spark_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" + try: + sep_idx = field_name.index(".") + return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:] + except ValueError: + return f"`{field_name}`" + + +# pylint: disable=R0801 +def get_spark_cast_statement_from_annotation( + element_name: str, + type_annotation: Any, + parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 +): + """Generate casting statements for spark dataframes based on type annotations""" + type_origin = get_origin(type_annotation) + + quoted_name = _spark_safely_quote_name(element_name) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *_ = get_args(type_annotation) # pylint: disable=unused-variable + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_spark_cast_statement_from_annotation( + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()]) + stmt = f"struct({cast_exprs})" + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first + if issubclass(type_, dt.datetime): + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + if issubclass(type_, dt.date): + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + spark_type = get_type_from_annotation(type_) + if spark_type: + stmt = f"trim({quoted_name})" + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + raise ValueError(f"No equivalent Spark type for {type_annotation!r}") diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py index 3b0e55f..21cac56 100644 --- a/src/dve/pipeline/foundry_ddb_pipeline.py +++ b/src/dve/pipeline/foundry_ddb_pipeline.py @@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI: write_to.parent.mkdir(parents=True, exist_ok=True) write_to = write_to.as_posix() self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "processing_status.parquet"), ) self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "submission_statistics.parquet"), diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index 5c39e36..19e96e2 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -3,17 +3,73 @@ import datetime import tempfile from pathlib import Path -from typing import Any +from typing import Any, List import pytest import pyspark.sql.types as pst from duckdb import DuckDBPyRelation, DuckDBPyConnection +from pydantic import BaseModel from pyspark.sql import Row, SparkSession from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( _ddb_read_parquet, - duckdb_rel_to_dictionaries) + duckdb_rel_to_dictionaries, + get_duckdb_cast_statement_from_annotation, + get_duckdb_type_from_annotation) +@pytest.fixture +def casting_test_table(temp_ddb_conn): + _, conn = temp_ddb_conn + conn.sql("""CREATE TABLE test_casting ( + str_test VARCHAR, + int_test VARCHAR, + date_test VARCHAR, + timestamp_test VARCHAR, + list_int_field VARCHAR[], + basic_model STRUCT(str_field VARCHAR, date_field VARCHAR), + another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))""") + + conn.sql("""INSERT INTO test_casting + VALUES( + 'good_one', + '1', + '2024-11-13', + '2024-04-15 12:25:36', + ['1', '2', '3'], + {'str_field': 'test', 'date_field': '2024-12-11'}, + {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}), + ( + 'dodgy_dates', + '2', + '24-11-13', + '2024-4-15 12:25:36', + ['4', '5', '6'], + {'str_field': 'test', 'date_field': '202-1-11'}, + {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})""") + + + yield temp_ddb_conn + + conn.sql("DROP TABLE IF EXISTS test_casting") + + + +class BasicModel(BaseModel): + str_field: str + date_field: datetime.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: datetime.date + timestamp_test: datetime.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel class TempConnection: """ @@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None: self._connection = connection + @pytest.mark.parametrize( "outpath", [ @@ -94,4 +151,29 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, res.append(chunk) assert res == data + +# add decimal check +@pytest.mark.parametrize("field_name,field_type,cast_statement", + [("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"), + ("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"), + ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"), + ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), + ("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"), + ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), + ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) +def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement): + assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement + + +def test_use_cast_statements(casting_test_table): + _, conn = casting_test_table + test_rel = conn.sql("SELECT * from test_casting") + casting_statements = [ f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}" for fld in CastingRecord.__fields__.values()] + test_rel = test_rel.project(",".join(casting_statements)) + assert dict(zip(test_rel.columns, test_rel.dtypes)) == {fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = test_rel.pl()[1].to_dicts()[0] + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) diff --git a/tests/test_core_engine/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py similarity index 54% rename from tests/test_core_engine/test_spark_helpers.py rename to tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index a3f167d..7502673 100644 --- a/tests/test_core_engine/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -12,17 +12,56 @@ from pydantic.types import condecimal from pyspark.sql import DataFrame, SparkSession from pyspark.sql import types as st -from pyspark.sql.functions import col +from pyspark.sql.functions import col, expr +from pyspark.sql.types import ArrayType, DateType, LongType, StringType, StructField, StructType, TimestampType from typing_extensions import Annotated, TypedDict from dve.core_engine.backends.implementations.spark.spark_helpers import ( DecimalConfig, create_udf, + get_spark_cast_statement_from_annotation, get_type_from_annotation, object_to_spark_literal, ) -from ..fixtures import spark # pylint: disable=unused-import +from .....fixtures import spark # pylint: disable=unused-import + +@pytest.fixture +def casting_dataframe(spark): + data = [{"str_test": "good_one", "int_test": "1", "date_test": "2024-11-13", "timestamp_test": "2024-04-15 12:25:36", + "list_int_field":['1', '2', '3'], "basic_model": {'str_field': 'test', 'date_field': '2024-12-11'}, + "another_model": {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}}, + {"str_test": "dodgy_dates", "int_test": "2", "date_test": "24-11-13", "timestamp_test": "2024-4-15 12:25:36", + "list_int_field":['4', '5', '6'], "basic_model": {'str_field': 'test', 'date_field': '202-12-11'}, + "another_model": {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-05'}]}}] + + bm_schema = StructType([StructField("str_field", StringType()), StructField("date_field", StringType())]) + + schema = StructType([StructField("str_test", StringType()), StructField("int_test", StringType()), StructField("date_test", StringType()), + StructField("timestamp_test", StringType()), StructField("list_int_field", ArrayType(StringType())), + StructField("basic_model", bm_schema), + StructField("another_model", StructType([StructField("unique_id", StringType()), StructField("basic_models", ArrayType(bm_schema))]))]) + yield spark.createDataFrame(data, schema=schema) + + + + +class BasicModel(BaseModel): + str_field: str + date_field: dt.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: dt.date + timestamp_test: dt.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel EXPECTED_STRUCT = st.StructType( [ @@ -203,3 +242,26 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any): """ with pytest.raises(ValueError): object_to_spark_literal(obj) + +@pytest.mark.parametrize("field_name,field_type,expression,spark_type", + [("str_test", str, "trim(`str_test`)", StringType()), + ("int_test", int, "trim(`int_test`)", LongType()), + ("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()), + ("timestamp_test", dt.datetime, r"CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), + ("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)), + ("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])), + ("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))]) +def test_get_spark_cast_statement_from_annotation(field_name, field_type, expression, spark_type): + assert str(get_spark_cast_statement_from_annotation(field_name, field_type)) == str(expr(expression).cast(spark_type)) + + +def test_use_cast_statements(spark, casting_dataframe): + casting_statements = [ get_spark_cast_statement_from_annotation(fld.name, fld.annotation).alias(fld.name) for fld in CastingRecord.__fields__.values()] + cast_df = casting_dataframe.select(*casting_statements) + assert {fld.name: fld.dataType for fld in cast_df.schema} == {fld.name: get_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = [rw.asDict(True) for rw in cast_df.collect()][1] + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) + assert cast_df \ No newline at end of file