Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 324 additions & 81 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ openpyxl = "3.1.5"
pandas = "2.3.3"
polars = "0.20.31"
pyarrow = "17.0.0"
pydantic = "1.10.19"
pydantic = "2.13.4"
pyspark = "3.5.2"
typing_extensions = "4.15.0"

Expand Down Expand Up @@ -80,7 +80,8 @@ black = "24.3.0"
astroid = "3.3.9"
isort = "5.13.2"
pylint = "3.3.9"
mypy = "1.11.2"
mypy = "1.20.2"
librt = "0.11.0" # mypy dependency
boto3-stubs = {extras = ["essential"], version = "1.26.72"}
botocore-stubs = "1.29.72"
pandas-stubs = "1.2.0.62"
Expand Down
16 changes: 8 additions & 8 deletions src/dve/core_engine/backends/base/auditing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from types import TracebackType
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union

from pydantic import ValidationError, validate_arguments
from pydantic import ValidationError, validate_call
from typing_extensions import Literal, get_origin

from dve.core_engine.models import (
Expand Down Expand Up @@ -98,8 +98,8 @@ def __init__(self, name: str, record_type: type[AuditRecord]):
def schema(self) -> dict[str, type]:
"""Determine python schema of auditor"""
return {
fld: str if get_origin(mdl.type_) == Literal else mdl.type_
for fld, mdl in self._record_type.__fields__.items()
fld: str if get_origin(mdl.annotation) == Literal else mdl.annotation # type: ignore
for fld, mdl in self._record_type.model_fields.items()
}

@staticmethod
Expand Down Expand Up @@ -195,7 +195,7 @@ def conv_to_iterable(recs: Union[AuditorType, AuditReturnType]) -> Iterable[dict
"""Convert AuditReturnType to iterable of dictionaries"""
raise NotImplementedError()

@validate_arguments
@validate_call
def add_processing_records(self, processing_records: list[ProcessingStatusRecord]):
"""Add an entry to the processing_status auditor."""
if self.pool:
Expand All @@ -207,7 +207,7 @@ def add_processing_records(self, processing_records: list[ProcessingStatusRecord
records=[dict(rec) for rec in processing_records]
)

@validate_arguments
@validate_call
def add_submission_statistics_records(self, sub_stats: list[SubmissionStatisticsRecord]):
"""Add an entry to the submission statistics auditor."""
if self.pool:
Expand All @@ -217,7 +217,7 @@ def add_submission_statistics_records(self, sub_stats: list[SubmissionStatistics
)
return self._submission_statistics.add_records(records=[dict(rec) for rec in sub_stats])

@validate_arguments
@validate_call
def add_transfer_records(self, transfer_records: list[TransferRecord]):
"""Add an entry to the transfers auditor"""
if self.pool:
Expand All @@ -226,7 +226,7 @@ def add_transfer_records(self, transfer_records: list[TransferRecord]):
)
return self._transfers.add_records(records=[dict(rec) for rec in transfer_records])

@validate_arguments
@validate_call
def add_new_submissions(
self,
submissions: list[SubmissionMetadata],
Expand All @@ -249,7 +249,7 @@ def add_new_submissions(
processing_status="received",
job_run_id=job_run_id,
**ts_info,
).dict(),
).model_dump(),
}
processing_status_recs.append(processing_rec)
if sub_info:
Expand Down
16 changes: 8 additions & 8 deletions src/dve/core_engine/backends/implementations/duckdb/auditing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
OrderCriteria,
)
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
PYTHON_TYPE_TO_DUCKDB_TYPE,
get_duckdb_type_from_annotation,
table_exists,
)
from dve.core_engine.backends.utilities import PYTHON_TYPE_TO_POLARS_TYPE
from dve.core_engine.backends.utilities import get_polars_type_from_annotation
from dve.core_engine.models import (
AuditRecord,
ProcessingStatusRecord,
Expand Down Expand Up @@ -62,18 +62,18 @@ def ddb_create_table_sql(self) -> str:
"""Generate create table sql script for auditor"""
_sql_expression = f"CREATE TABLE {self._name} ("
_sql_expression += ", ".join(
[f"{fld} {PYTHON_TYPE_TO_DUCKDB_TYPE.get(dtype)}" for fld, dtype in self.schema.items()]
[
f"{fld} {get_duckdb_type_from_annotation(dtype)}"
for fld, dtype in self.schema.items()
]
)
_sql_expression += ")"
return _sql_expression

@property
def polars_schema(self) -> dict[str, PolarsType]:
"""Get polars dataframe schema for auditor"""
return {
fld: PYTHON_TYPE_TO_POLARS_TYPE.get(dtype, pl.Utf8) # type: ignore
for fld, dtype in self.schema.items()
}
return {fld: get_polars_type_from_annotation(dtype) for fld, dtype in self.schema.items()}

def get_relation(self) -> DuckDBPyRelation:
"""Get a relation to interact with the auditor duckdb table"""
Expand Down Expand Up @@ -106,7 +106,7 @@ def conv_to_entity(self, recs: list[AuditRecord]) -> DuckDBPyRelation:
"""Convert a list of audit records to a relation"""
# pylint: disable=W0612
rec_df = pl.DataFrame( # type: ignore
[rec.dict() for rec in recs],
[rec.model_dump() for rec in recs],
schema=self.polars_schema,
)
return self._connection.sql("select * from rec_df")
Expand Down
17 changes: 7 additions & 10 deletions src/dve/core_engine/backends/implementations/duckdb/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from duckdb.typing import DuckDBPyType
from polars.datatypes.classes import DataTypeClass as PolarsType
from pydantic import BaseModel
from pydantic.fields import ModelField

import dve.parser.file_handling as fh
from dve.common.error_utils import (
Expand Down Expand Up @@ -96,8 +95,8 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
) -> DuckDBPyRelation:
"""Create DuckDB Relation from iterator of records"""
polars_schema: dict[str, PolarsType] = {
fld.name: get_polars_type_from_annotation(fld.type_)
for fld in stringify_model(schema).__fields__.values()
name: get_polars_type_from_annotation(fld.annotation)
for name, fld in stringify_model(schema).model_fields.items()
}
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
return self._connection.sql("select * from _lazy_df")
Expand Down Expand Up @@ -130,17 +129,15 @@ def apply_data_contract(
) as msg_writer:
for entity_name, relation in entities.items():
# get dtypes for all fields -> python data types or use with relation
entity_fields: dict[str, ModelField] = contract_metadata.schemas[
entity_name
].__fields__
entity_fields = contract_metadata.schemas[entity_name].model_fields
ddb_schema: dict[str, DuckDBPyType] = {
fld.name: get_duckdb_type_from_annotation(fld.annotation)
for fld in entity_fields.values()
name: get_duckdb_type_from_annotation(fld.annotation)
for name, fld in entity_fields.items()
}
ddb_schema[RECORD_INDEX_COLUMN_NAME] = get_duckdb_type_from_annotation(int)
polars_schema: dict[str, PolarsType] = {
fld.name: get_polars_type_from_annotation(fld.annotation)
for fld in entity_fields.values()
name: get_polars_type_from_annotation(fld.annotation)
for name, fld in entity_fields.items()
}
polars_schema[RECORD_INDEX_COLUMN_NAME] = get_polars_type_from_annotation(int)
if relation_is_empty(relation):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import date, datetime, time
from decimal import Decimal
from pathlib import Path
from typing import Any, ClassVar, Union
from typing import Any, ClassVar, Literal, Union
from urllib.parse import urlparse

import duckdb.typing as ddbtyp
Expand Down Expand Up @@ -125,8 +125,8 @@ def get_duckdb_type_from_annotation(type_annotation: Any) -> DuckDBPyType:
'optional' wrapper and return the inner type
- A subclass of `typing.TypedDict` with values typed using supported types. This
will parse the value types as Polars types and return a duckdb STRUCT.
- A dataclass or `pydantic.main.ModelMetaClass` with values typed using supported types.
This will parse the field types as Polars types and return a duckdb STRUCT.
- A dataclass or `pydantic.BaseModel` with values typed using supported types.
This will parse the field types as duckdb types and return a duckdb STRUCT.
- Any supported type, with a `typing_extensions.Annotated` wrapper.

Any `ClassVar` types within `TypedDict`s, dataclasses, or `pydantic` models will be
Expand All @@ -135,6 +135,14 @@ def get_duckdb_type_from_annotation(type_annotation: Any) -> DuckDBPyType:
"""
type_origin = get_origin(type_annotation)

if type_origin is Literal:
ddb_types = [get_duckdb_type_from_annotation(type(t)) for t in get_args(type_annotation)]
if not ddb_types or not all(t == ddb_types[0] for t in ddb_types):
raise ValueError(
f"Unable to determine a single concrete type for Literal. Got {type_annotation!r}"
)
return ddb_types[0]

# An `Optional` or `Union` type, check to ensure non-heterogenity.
if type_origin is Union:
python_type = _get_non_heterogenous_type(get_args(type_annotation))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def read_to_relation( # pylint: disable=unused-argument
}

ddb_schema: dict[str, SQLType] = {
fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore
for fld in schema.__fields__.values()
name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore
for name, fld in schema.model_fields.items()
}

reader_options["columns"] = ddb_schema
Expand Down Expand Up @@ -154,8 +154,8 @@ def read_to_relation( # pylint: disable=unused-argument
}

polars_types = {
fld.name: get_polars_type_from_annotation(fld.annotation) # type: ignore
for fld in schema.__fields__.values()
name: get_polars_type_from_annotation(fld.annotation) # type: ignore
for name, fld in schema.model_fields.items()
}
reader_options["dtypes"] = polars_types

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def read_to_relation( # pylint: disable=unused-argument
"""Returns a relation object from the source json"""

ddb_schema: dict[str, SQLType] = {
fld.name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore
for fld in schema.__fields__.values()
name: str(get_duckdb_type_from_annotation(fld.annotation)) # type: ignore
for name, fld in schema.model_fields.items()
}

return self.add_record_index(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def read_to_relation(self, resource: URI, entity_name: str, schema: type[BaseMod
)

polars_schema: dict[str, pl.DataType] = { # type: ignore
fld.name: get_polars_type_from_annotation(fld.annotation)
for fld in stringify_model(schema).__fields__.values()
name: get_polars_type_from_annotation(fld.annotation)
for name, fld in stringify_model(schema).model_fields.items()
}

_lazy_frame = self.add_record_index(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def conv_to_records(self, recs: DataFrame) -> Iterable[AuditRecord]:
def conv_to_entity(self, recs: list[AuditRecord]) -> DataFrame:
"""Convert the dataframe to an iterable of the related audit record"""
return self._spark.createDataFrame( # type: ignore
[rec.dict() for rec in recs], schema=self.spark_schema
[rec.model_dump() for rec in recs], schema=self.spark_schema
)

def add_records(self, records: Iterable[dict[str, Any]]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,5 @@ def write_entities_to_parquet(

def convert_submission_info(self, submission_info: SubmissionInfo) -> DataFrame:
return self.spark_session.createDataFrame( # type: ignore
[submission_info.dict()], schema=get_type_from_annotation(type(submission_info))
[submission_info.model_dump()], schema=get_type_from_annotation(type(submission_info))
)
49 changes: 28 additions & 21 deletions src/dve/core_engine/backends/implementations/spark/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from functools import wraps
from typing import Any, ClassVar, Optional, TypeVar, Union, overload
from typing import Any, ClassVar, Literal, Optional, TypeVar, Union, overload

from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException
from pydantic import BaseModel
from pydantic.types import ConstrainedDecimal
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql import functions as sf
from pyspark.sql import types as st
Expand Down Expand Up @@ -49,6 +48,7 @@
"""A wrapped function (Spark UDF) taking four args."""


# TODO - lets see if we can bin this off as it's a bit overkill

Check warning on line 51 in src/dve/core_engine/backends/implementations/spark/spark_helpers.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Complete the task associated to this "TODO" comment.

See more on https://sonarcloud.io/project/issues?id=NHSDigital_data-validation-engine&issues=AZ5k0gEnYz4-StK4c_jY&open=AZ5k0gEnYz4-StK4c_jY&pullRequest=107
@dataclass(frozen=True)
class DecimalConfig:
"""Configuration for a Python decimal to enable it to be mapped to a
Expand All @@ -61,24 +61,26 @@

"""

precision: int = 38
max_digits: int = 38
"""
The precision of the decimal. This is the total number of digits in the
decimal.

"""
scale: int = 18
decimal_places: int = 18
"""
The scale of the decimal. This is the number of digits to the right of the
decimal point.

"""

def __post_init__(self):
if not 0 < self.precision <= 38:
raise ValueError("Precision must be between 1 and 38 (inclusive)")
if not 0 <= self.scale <= self.precision:
raise ValueError("Scale must be between 0 and the precision (inclusive)")
if not 0 < self.max_digits <= 38:
raise ValueError("Max digits must be between 1 and 38 (inclusive)")
if not 0 <= self.decimal_places <= self.max_digits:
raise ValueError(
"Decimal Places must be between 0 and the specified number of digits (inclusive)"
)


DEFAULT_DECIMAL_CONFIG = DecimalConfig()
Expand All @@ -93,7 +95,9 @@
bytes: st.BinaryType(),
dt.date: st.DateType(),
dt.datetime: st.TimestampType(),
Decimal: st.DecimalType(DEFAULT_DECIMAL_CONFIG.precision, DEFAULT_DECIMAL_CONFIG.scale),
Decimal: st.DecimalType(
DEFAULT_DECIMAL_CONFIG.max_digits, DEFAULT_DECIMAL_CONFIG.decimal_places
),
}
"""A mapping of Python types to the equivalent Spark types."""

Expand Down Expand Up @@ -146,7 +150,7 @@
'optional' wrapper and return the inner type (Spark types are all nullable).
- A subclass of `typing.TypedDict` with values typed using supported types. This
will parse the value types as Spark types and return a Spark `StructType`.
- A dataclass or `pydantic.main.ModelMetaClass` with values typed using supported types.
- A dataclass or `pydantic.BaseModel` with values typed using supported types.
This will parse the field types as Spark types and return a Spark `StructType`.
- Any supported type, with a `typing_extensions.Annotated` wrapper.
- A `decimal.Decimal` wrapped with `typing_extensions.Annotated` with a `DecimalConfig`
Expand All @@ -160,6 +164,14 @@
"""
type_origin = get_origin(type_annotation)

if type_origin is Literal:
types = [get_type_from_annotation(type(t)) for t in get_args(type_annotation)]
if not types or not all(t == types[0] for t in types):
raise ValueError(
f"Unable to determine a single concrete type for Literal. Got {type_annotation!r}"
)
return types[0]

# An `Optional` or `Union` type, check to ensure non-heterogenity.
if type_origin is Union:
python_type = _get_non_heterogenous_type(get_args(type_annotation))
Expand All @@ -176,13 +188,13 @@
if python_type is not Decimal:
return get_type_from_annotation(python_type)

try: # Grab the decimal configuration from the list of other args.
configuration: DecimalConfig = next(
filter(lambda config: isinstance(config, DecimalConfig), other_args)
)
except StopIteration:
config_options = [arg for arg in other_args if hasattr(arg, "max_digits")]
if config_options:
configuration = config_options[0]
else:
configuration = DEFAULT_DECIMAL_CONFIG
return st.DecimalType(configuration.precision, configuration.scale)

return st.DecimalType(configuration.max_digits, configuration.decimal_places)

# Ensure that we have a concrete type at this point.
if not isinstance(type_annotation, type):
Expand Down Expand Up @@ -216,11 +228,6 @@

return st.StructType(fields)

if issubclass(type_annotation, ConstrainedDecimal):
precision = int(type_annotation.max_digits or 38)
scale = int(type_annotation.decimal_places or precision)
return st.DecimalType(precision, scale)

if type_annotation is list:
raise ValueError(
f"list must have type annotation (e.g. `list[str]`), got {type_annotation!r}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class PydanticCompatibleJSONEncoder(JSONEncoder):
def default(self, o: Any) -> Any:
"""Sets the format for given types for json encoding"""
if isinstance(o, BaseModel):
return o.dict()
return o.model_dump()
if isinstance(o, dt.date):
return o.isoformat()
return super().default(o)
Expand Down
Loading
Loading