From 4b04898fc6607dff2d3aca3af04a51657a7af30e Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Tue, 10 Mar 2026 18:05:22 -0400 Subject: [PATCH 1/6] feat: normalize validator and constraint discriminators --- .../data_designer/config/column_configs.py | 13 ++++- .../config/data_designer_config.py | 4 +- .../config/sampler_constraints.py | 55 ++++++++++++------- .../data_designer/config/validator_params.py | 14 ++++- .../tests/config/test_columns.py | 23 ++++++++ .../tests/config/test_data_designer_config.py | 37 +++++++++++++ .../tests/config/test_sampler_constraints.py | 12 ++++ .../tests/config/test_validator_params.py | 8 ++- .../engine/sampling_gen/test_constraints.py | 14 ++++- 9 files changed, 154 insertions(+), 26 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 49dbb831..661bbfa2 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -412,7 +412,7 @@ class ValidationColumnConfig(SingleColumnConfig): target_columns: list[str] validator_type: ValidatorType - validator_params: ValidatorParamsT + validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")] batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch") column_type: Literal["validation"] = "validation" @@ -429,6 +429,17 @@ def required_columns(self) -> list[str]: def side_effect_columns(self) -> list[str]: return [] + @model_validator(mode="before") + @classmethod + def inject_validator_type_into_params(cls, data: dict) -> dict: + """Inject validator_type into validator_params for discriminated union resolution.""" + if isinstance(data, dict): + validator_type = data.get("validator_type") + validator_params = data.get("validator_params") + if validator_type and isinstance(validator_params, dict) and "validator_type" not in validator_params: + data["validator_params"] = {"validator_type": validator_type, **validator_params} + return data + class SeedDatasetColumnConfig(SingleColumnConfig): """Configuration for columns sourced from seed datasets. diff --git a/packages/data-designer-config/src/data_designer/config/data_designer_config.py b/packages/data-designer-config/src/data_designer/config/data_designer_config.py index e65a67de..0fc2a96e 100644 --- a/packages/data-designer-config/src/data_designer/config/data_designer_config.py +++ b/packages/data-designer-config/src/data_designer/config/data_designer_config.py @@ -13,7 +13,7 @@ from data_designer.config.mcp import ToolConfig from data_designer.config.models import ModelConfig from data_designer.config.processor_types import ProcessorConfigT -from data_designer.config.sampler_constraints import ColumnConstraintT +from data_designer.config.sampler_constraints import ColumnConstraintInputT from data_designer.config.seed import SeedConfig @@ -39,6 +39,6 @@ class DataDesignerConfig(ExportableConfigBase): model_configs: list[ModelConfig] | None = None tool_configs: list[ToolConfig] | None = None seed_config: SeedConfig | None = None - constraints: list[ColumnConstraintT] | None = None + constraints: list[ColumnConstraintInputT] | None = None profilers: list[ColumnProfilerConfigT] | None = None processors: list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]] | None = None diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 86dc2c09..6688cf79 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -3,9 +3,10 @@ from __future__ import annotations -from abc import ABC, abstractmethod from enum import Enum +from typing import Annotated, Any, Literal +from pydantic import Discriminator, Field, Tag from typing_extensions import TypeAlias from data_designer.config.base import ConfigBase @@ -23,30 +24,46 @@ class InequalityOperator(str, Enum): GE = "ge" -class Constraint(ConfigBase, ABC): - target_column: str - - @property - @abstractmethod - def constraint_type(self) -> ConstraintType: ... +class Constraint(ConfigBase): + target_column: str = Field(description="Name of the sampler column this constraint applies to") + constraint_type: ConstraintType = Field(description="Constraint type discriminator") class ScalarInequalityConstraint(Constraint): - rhs: float - operator: InequalityOperator - - @property - def constraint_type(self) -> ConstraintType: - return ConstraintType.SCALAR_INEQUALITY + rhs: float = Field(description="Scalar value to compare against") + operator: InequalityOperator = Field(description="Comparison operator") + constraint_type: Literal[ConstraintType.SCALAR_INEQUALITY] = Field( + default=ConstraintType.SCALAR_INEQUALITY, + description="Constraint type discriminator, always 'scalar_inequality' for this constraint", + ) class ColumnInequalityConstraint(Constraint): - rhs: str - operator: InequalityOperator - - @property - def constraint_type(self) -> ConstraintType: - return ConstraintType.COLUMN_INEQUALITY + rhs: str = Field(description="Name of the other column to compare against") + operator: InequalityOperator = Field(description="Comparison operator") + constraint_type: Literal[ConstraintType.COLUMN_INEQUALITY] = Field( + default=ConstraintType.COLUMN_INEQUALITY, + description="Constraint type discriminator, always 'column_inequality' for this constraint", + ) ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint + + +def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None: + """Resolve the constraint type for both tagged and legacy config shapes.""" + if isinstance(value, dict): + if (constraint_type := value.get("constraint_type")) is not None: + return constraint_type + + rhs = value.get("rhs") + return ConstraintType.COLUMN_INEQUALITY if isinstance(rhs, str) else ConstraintType.SCALAR_INEQUALITY + + return getattr(value, "constraint_type", None) + + +ColumnConstraintInputT: TypeAlias = Annotated[ + Annotated[ScalarInequalityConstraint, Tag(ConstraintType.SCALAR_INEQUALITY)] + | Annotated[ColumnInequalityConstraint, Tag(ConstraintType.COLUMN_INEQUALITY)], + Discriminator(resolve_constraint_input_type), +] diff --git a/packages/data-designer-config/src/data_designer/config/validator_params.py b/packages/data-designer-config/src/data_designer/config/validator_params.py index e08c3186..6ef38c27 100644 --- a/packages/data-designer-config/src/data_designer/config/validator_params.py +++ b/packages/data-designer-config/src/data_designer/config/validator_params.py @@ -4,7 +4,7 @@ from __future__ import annotations from enum import Enum -from typing import Any +from typing import Any, Literal from pydantic import Field, field_serializer, model_validator from typing_extensions import Self, TypeAlias @@ -29,6 +29,10 @@ class CodeValidatorParams(ConfigBase): `sql:sqlite`, `sql:postgres`, `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`. """ + validator_type: Literal[ValidatorType.CODE] = Field( + default=ValidatorType.CODE, + description="Validator type discriminator, always 'code' for this validator", + ) code_lang: CodeLang = Field(description="The language of the code to validate") @model_validator(mode="after") @@ -50,6 +54,10 @@ class LocalCallableValidatorParams(ConfigBase): the output will not be validated. """ + validator_type: Literal[ValidatorType.LOCAL_CALLABLE] = Field( + default=ValidatorType.LOCAL_CALLABLE, + description="Validator type discriminator, always 'local_callable' for this validator", + ) validation_function: Any = Field( description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data" ) @@ -81,6 +89,10 @@ class RemoteValidatorParams(ConfigBase): max_parallel_requests: The maximum number of parallel requests to make. Defaults to 4. """ + validator_type: Literal[ValidatorType.REMOTE] = Field( + default=ValidatorType.REMOTE, + description="Validator type discriminator, always 'remote' for this validator", + ) endpoint_url: str = Field(description="URL of the remote endpoint") output_schema: dict[str, Any] | None = Field( default=None, description="Expected schema for remote validator's output" diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 0740a277..5377cf26 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -267,6 +267,29 @@ def test_validation_column_config(): assert validation_column_config.batch_size == 5 +def test_validation_column_config_injects_validator_type_into_params_dict(): + validation_column_config = ValidationColumnConfig( + name="test_validation", + target_columns=["test_column"], + validator_type="code", + validator_params={"code_lang": "python"}, + ) + + assert isinstance(validation_column_config.validator_params, CodeValidatorParams) + assert validation_column_config.validator_params.validator_type == "code" + assert validation_column_config.validator_params.code_lang == CodeLang.PYTHON + + +def test_validation_column_config_schema_uses_validator_discriminator(): + schema = ValidationColumnConfig.model_json_schema() + validator_params = schema["properties"]["validator_params"] + + assert validator_params["discriminator"]["propertyName"] == "validator_type" + assert "code" in validator_params["discriminator"]["mapping"] + assert "local_callable" in validator_params["discriminator"]["mapping"] + assert "remote" in validator_params["discriminator"]["mapping"] + + def test_embedding_column_config(): embedding_column_config = EmbeddingColumnConfig( name="test_embedding", diff --git a/packages/data-designer-config/tests/config/test_data_designer_config.py b/packages/data-designer-config/tests/config/test_data_designer_config.py index 8fbb6956..2ba9d026 100644 --- a/packages/data-designer-config/tests/config/test_data_designer_config.py +++ b/packages/data-designer-config/tests/config/test_data_designer_config.py @@ -6,6 +6,8 @@ import yaml +from data_designer.config.data_designer_config import DataDesignerConfig + def test_data_designer_config_to_dict(stub_data_designer_config): assert isinstance(stub_data_designer_config.to_dict(), dict) @@ -27,3 +29,38 @@ def test_data_designer_config_to_json(stub_data_designer_config): assert result is None with open(tmp_file.name, "r") as f: assert json.loads(f.read()) == stub_data_designer_config.to_dict() + + +def test_data_designer_config_parses_constraint_type_from_legacy_shape(): + config = DataDesignerConfig.model_validate( + { + "columns": [ + { + "name": "age", + "column_type": "sampler", + "sampler_type": "uniform", + "params": {"low": 18, "high": 99}, + } + ], + "constraints": [ + {"target_column": "age", "operator": "lt", "rhs": 65}, + {"target_column": "age", "operator": "gt", "rhs": "minimum_age"}, + ], + } + ) + + serialized_constraints = [constraint.model_dump(mode="json") for constraint in config.constraints] + assert serialized_constraints == [ + { + "target_column": "age", + "operator": "lt", + "rhs": 65.0, + "constraint_type": "scalar_inequality", + }, + { + "target_column": "age", + "operator": "gt", + "rhs": "minimum_age", + "constraint_type": "column_inequality", + }, + ] diff --git a/packages/data-designer-config/tests/config/test_sampler_constraints.py b/packages/data-designer-config/tests/config/test_sampler_constraints.py index 3c127b8b..5e2d08a8 100644 --- a/packages/data-designer-config/tests/config/test_sampler_constraints.py +++ b/packages/data-designer-config/tests/config/test_sampler_constraints.py @@ -15,6 +15,12 @@ def test_scalar_inequality_constraint(): assert constraint.rhs == 1 assert constraint.operator == InequalityOperator.LT assert constraint.constraint_type == ConstraintType.SCALAR_INEQUALITY + assert constraint.model_dump() == { + "target_column": "test", + "rhs": 1.0, + "operator": "lt", + "constraint_type": "scalar_inequality", + } def test_column_inequality_constraint(): @@ -23,3 +29,9 @@ def test_column_inequality_constraint(): assert constraint.rhs == "test2" assert constraint.operator == InequalityOperator.LT assert constraint.constraint_type == ConstraintType.COLUMN_INEQUALITY + assert constraint.model_dump() == { + "target_column": "test", + "rhs": "test2", + "operator": "lt", + "constraint_type": "column_inequality", + } diff --git a/packages/data-designer-config/tests/config/test_validator_params.py b/packages/data-designer-config/tests/config/test_validator_params.py index ab74f14a..cfa1dc89 100644 --- a/packages/data-designer-config/tests/config/test_validator_params.py +++ b/packages/data-designer-config/tests/config/test_validator_params.py @@ -14,6 +14,7 @@ CodeValidatorParams, LocalCallableValidatorParams, RemoteValidatorParams, + ValidatorType, ) if TYPE_CHECKING: @@ -21,7 +22,9 @@ def test_code_validator_params(): - assert CodeValidatorParams(code_lang=CodeLang.PYTHON).code_lang == CodeLang.PYTHON + params = CodeValidatorParams(code_lang=CodeLang.PYTHON) + assert params.code_lang == CodeLang.PYTHON + assert params.validator_type == ValidatorType.CODE with pytest.raises(ValidationError): CodeValidatorParams(code_lang=CodeLang.RUBY) @@ -30,6 +33,7 @@ def test_code_validator_params(): def test_remote_validator_params(): stub_url = "https://example.com" params = RemoteValidatorParams(endpoint_url=stub_url) + assert params.validator_type == ValidatorType.REMOTE assert params.endpoint_url == stub_url assert params.output_schema is None assert params.timeout == 30.0 @@ -52,8 +56,10 @@ def stub_callback(df: pd.DataFrame) -> pd.DataFrame: return lazy.pd.DataFrame([{"is_valid": True, "confidence": "0.98"}]) params = LocalCallableValidatorParams(validation_function=stub_callback) + assert params.validator_type == ValidatorType.LOCAL_CALLABLE assert params.validation_function == stub_callback assert params.output_schema is None params_model_dump = params.model_dump(mode="json") + assert params_model_dump["validator_type"] == "local_callable" assert params_model_dump["validation_function"] == "stub_callback" diff --git a/packages/data-designer-engine/tests/engine/sampling_gen/test_constraints.py b/packages/data-designer-engine/tests/engine/sampling_gen/test_constraints.py index dbc26cbf..be29a5e0 100644 --- a/packages/data-designer-engine/tests/engine/sampling_gen/test_constraints.py +++ b/packages/data-designer-engine/tests/engine/sampling_gen/test_constraints.py @@ -81,7 +81,12 @@ def test_constraint_scenarios( "balance", "gt", 0, - {"target_column": "balance", "operator": "gt", "rhs": 0}, + { + "target_column": "balance", + "operator": "gt", + "rhs": 0, + "constraint_type": "scalar_inequality", + }, ), ( "column_inequality_serialization", @@ -89,7 +94,12 @@ def test_constraint_scenarios( "balance", "gt", "credit", - {"target_column": "balance", "operator": "gt", "rhs": "credit"}, + { + "target_column": "balance", + "operator": "gt", + "rhs": "credit", + "constraint_type": "column_inequality", + }, ), ], ) From b16e9705f209236d66ce58f4e045b3e574c8dbf9 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 13 Mar 2026 15:15:56 -0400 Subject: [PATCH 2/6] docs: add docstring and comment to Constraint base class Address Greptile review feedback: - Add docstring to Constraint noting it should not be instantiated directly - Add comment explaining the rhs fallback behavior in the resolver --- .../src/data_designer/config/sampler_constraints.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 6688cf79..41121274 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -25,6 +25,8 @@ class InequalityOperator(str, Enum): class Constraint(ConfigBase): + """Base class for sampler constraints. Use a concrete subclass, not this class directly.""" + target_column: str = Field(description="Name of the sampler column this constraint applies to") constraint_type: ConstraintType = Field(description="Constraint type discriminator") @@ -56,6 +58,8 @@ def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None: if (constraint_type := value.get("constraint_type")) is not None: return constraint_type + # rhs is required on both concrete types, so when it's missing we default to + # SCALAR_INEQUALITY — Pydantic will surface a clear "rhs field required" error. rhs = value.get("rhs") return ConstraintType.COLUMN_INEQUALITY if isinstance(rhs, str) else ConstraintType.SCALAR_INEQUALITY From 42b73b44d7003034394b108e3270d89349655c6c Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 13 Mar 2026 15:18:35 -0400 Subject: [PATCH 3/6] refactor: restore ABC on Constraint base class --- .../src/data_designer/config/sampler_constraints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 41121274..17c8eaef 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -3,6 +3,7 @@ from __future__ import annotations +from abc import ABC from enum import Enum from typing import Annotated, Any, Literal @@ -24,7 +25,7 @@ class InequalityOperator(str, Enum): GE = "ge" -class Constraint(ConfigBase): +class Constraint(ConfigBase, ABC): """Base class for sampler constraints. Use a concrete subclass, not this class directly.""" target_column: str = Field(description="Name of the sampler column this constraint applies to") From 61acfdf67113f309ad985e96705ef9f80f761711 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 13 Mar 2026 15:20:34 -0400 Subject: [PATCH 4/6] refactor: add explicit None guard in constraint resolver --- .../src/data_designer/config/sampler_constraints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 17c8eaef..19058668 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -61,7 +61,8 @@ def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None: # rhs is required on both concrete types, so when it's missing we default to # SCALAR_INEQUALITY — Pydantic will surface a clear "rhs field required" error. - rhs = value.get("rhs") + if (rhs := value.get("rhs")) is None: + return ConstraintType.SCALAR_INEQUALITY return ConstraintType.COLUMN_INEQUALITY if isinstance(rhs, str) else ConstraintType.SCALAR_INEQUALITY return getattr(value, "constraint_type", None) From 3305e4683eb1479b1f066f82b13fec8882584d82 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 13 Mar 2026 15:36:29 -0400 Subject: [PATCH 5/6] Fix legacy numeric sampler constraint detection --- .../src/data_designer/config/sampler_constraints.py | 12 +++++++++++- .../tests/config/test_data_designer_config.py | 9 ++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index 19058668..e4d3479b 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -63,11 +63,21 @@ def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None: # SCALAR_INEQUALITY — Pydantic will surface a clear "rhs field required" error. if (rhs := value.get("rhs")) is None: return ConstraintType.SCALAR_INEQUALITY - return ConstraintType.COLUMN_INEQUALITY if isinstance(rhs, str) else ConstraintType.SCALAR_INEQUALITY + if isinstance(rhs, str): + return ConstraintType.SCALAR_INEQUALITY if _can_coerce_to_float(rhs) else ConstraintType.COLUMN_INEQUALITY + return ConstraintType.SCALAR_INEQUALITY return getattr(value, "constraint_type", None) +def _can_coerce_to_float(value: str) -> bool: + try: + float(value) + except ValueError: + return False + return True + + ColumnConstraintInputT: TypeAlias = Annotated[ Annotated[ScalarInequalityConstraint, Tag(ConstraintType.SCALAR_INEQUALITY)] | Annotated[ColumnInequalityConstraint, Tag(ConstraintType.COLUMN_INEQUALITY)], diff --git a/packages/data-designer-config/tests/config/test_data_designer_config.py b/packages/data-designer-config/tests/config/test_data_designer_config.py index 2ba9d026..0c8cf0bd 100644 --- a/packages/data-designer-config/tests/config/test_data_designer_config.py +++ b/packages/data-designer-config/tests/config/test_data_designer_config.py @@ -31,7 +31,7 @@ def test_data_designer_config_to_json(stub_data_designer_config): assert json.loads(f.read()) == stub_data_designer_config.to_dict() -def test_data_designer_config_parses_constraint_type_from_legacy_shape(): +def test_data_designer_config_parses_constraint_type_from_legacy_shape() -> None: config = DataDesignerConfig.model_validate( { "columns": [ @@ -44,6 +44,7 @@ def test_data_designer_config_parses_constraint_type_from_legacy_shape(): ], "constraints": [ {"target_column": "age", "operator": "lt", "rhs": 65}, + {"target_column": "age", "operator": "le", "rhs": "65"}, {"target_column": "age", "operator": "gt", "rhs": "minimum_age"}, ], } @@ -57,6 +58,12 @@ def test_data_designer_config_parses_constraint_type_from_legacy_shape(): "rhs": 65.0, "constraint_type": "scalar_inequality", }, + { + "target_column": "age", + "operator": "le", + "rhs": 65.0, + "constraint_type": "scalar_inequality", + }, { "target_column": "age", "operator": "gt", From 2170a2f7e16195d811bfe3e9f51d088df408813a Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 13 Mar 2026 17:23:49 -0400 Subject: [PATCH 6/6] fix: address PR review feedback from nabinchha - Guard _can_coerce_to_float against inf/nan strings - Add -> None return type annotations to test functions - Add clarifying comments to ColumnConstraintT vs ColumnConstraintInputT - Add tests for tagged constraint round-trip and missing rhs validation --- .../config/sampler_constraints.py | 7 ++- .../tests/config/test_columns.py | 4 +- .../tests/config/test_data_designer_config.py | 60 +++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py index e4d3479b..98eb3421 100644 --- a/packages/data-designer-config/src/data_designer/config/sampler_constraints.py +++ b/packages/data-designer-config/src/data_designer/config/sampler_constraints.py @@ -3,6 +3,7 @@ from __future__ import annotations +import math from abc import ABC from enum import Enum from typing import Annotated, Any, Literal @@ -50,6 +51,7 @@ class ColumnInequalityConstraint(Constraint): ) +# Plain union for engine-layer type hints on already-validated constraint instances. ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint @@ -72,12 +74,13 @@ def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None: def _can_coerce_to_float(value: str) -> bool: try: - float(value) + result = float(value) except ValueError: return False - return True + return math.isfinite(result) +# Discriminated union for deserialization — supports both tagged and legacy config shapes. ColumnConstraintInputT: TypeAlias = Annotated[ Annotated[ScalarInequalityConstraint, Tag(ConstraintType.SCALAR_INEQUALITY)] | Annotated[ColumnInequalityConstraint, Tag(ConstraintType.COLUMN_INEQUALITY)], diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 5377cf26..239ff086 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -267,7 +267,7 @@ def test_validation_column_config(): assert validation_column_config.batch_size == 5 -def test_validation_column_config_injects_validator_type_into_params_dict(): +def test_validation_column_config_injects_validator_type_into_params_dict() -> None: validation_column_config = ValidationColumnConfig( name="test_validation", target_columns=["test_column"], @@ -280,7 +280,7 @@ def test_validation_column_config_injects_validator_type_into_params_dict(): assert validation_column_config.validator_params.code_lang == CodeLang.PYTHON -def test_validation_column_config_schema_uses_validator_discriminator(): +def test_validation_column_config_schema_uses_validator_discriminator() -> None: schema = ValidationColumnConfig.model_json_schema() validator_params = schema["properties"]["validator_params"] diff --git a/packages/data-designer-config/tests/config/test_data_designer_config.py b/packages/data-designer-config/tests/config/test_data_designer_config.py index 0c8cf0bd..0e12f45c 100644 --- a/packages/data-designer-config/tests/config/test_data_designer_config.py +++ b/packages/data-designer-config/tests/config/test_data_designer_config.py @@ -4,6 +4,7 @@ import json import tempfile +import pytest import yaml from data_designer.config.data_designer_config import DataDesignerConfig @@ -71,3 +72,62 @@ def test_data_designer_config_parses_constraint_type_from_legacy_shape() -> None "constraint_type": "column_inequality", }, ] + + +def test_data_designer_config_parses_constraint_type_from_tagged_shape() -> None: + config = DataDesignerConfig.model_validate( + { + "columns": [ + { + "name": "age", + "column_type": "sampler", + "sampler_type": "uniform", + "params": {"low": 18, "high": 99}, + } + ], + "constraints": [ + {"target_column": "age", "operator": "lt", "rhs": 65.0, "constraint_type": "scalar_inequality"}, + { + "target_column": "age", + "operator": "gt", + "rhs": "minimum_age", + "constraint_type": "column_inequality", + }, + ], + } + ) + + serialized_constraints = [constraint.model_dump(mode="json") for constraint in config.constraints] + assert serialized_constraints == [ + { + "target_column": "age", + "operator": "lt", + "rhs": 65.0, + "constraint_type": "scalar_inequality", + }, + { + "target_column": "age", + "operator": "gt", + "rhs": "minimum_age", + "constraint_type": "column_inequality", + }, + ] + + +def test_data_designer_config_constraint_missing_rhs_raises_validation_error() -> None: + with pytest.raises(Exception): + DataDesignerConfig.model_validate( + { + "columns": [ + { + "name": "age", + "column_type": "sampler", + "sampler_type": "uniform", + "params": {"low": 18, "high": 99}, + } + ], + "constraints": [ + {"target_column": "age", "operator": "lt"}, + ], + } + )