Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ class ValidationColumnConfig(SingleColumnConfig):

target_columns: list[str]
validator_type: ValidatorType
validator_params: ValidatorParamsT
validator_params: Annotated[ValidatorParamsT, Discriminator("validator_type")]
batch_size: int = Field(default=10, ge=1, description="Number of records to process in each batch")
column_type: Literal["validation"] = "validation"

Expand All @@ -429,6 +429,17 @@ def required_columns(self) -> list[str]:
def side_effect_columns(self) -> list[str]:
return []

@model_validator(mode="before")
@classmethod
def inject_validator_type_into_params(cls, data: dict) -> dict:
    """Copy the column's validator_type into validator_params so the
    discriminated union on validator_params can resolve the concrete type.

    Only touches dict payloads that have a validator_type and a dict of
    params without their own discriminator; everything else passes through.
    """
    if not isinstance(data, dict):
        return data
    vtype = data.get("validator_type")
    params = data.get("validator_params")
    if vtype and isinstance(params, dict) and "validator_type" not in params:
        data["validator_params"] = {"validator_type": vtype, **params}
    return data


class SeedDatasetColumnConfig(SingleColumnConfig):
"""Configuration for columns sourced from seed datasets.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from data_designer.config.mcp import ToolConfig
from data_designer.config.models import ModelConfig
from data_designer.config.processor_types import ProcessorConfigT
from data_designer.config.sampler_constraints import ColumnConstraintT
from data_designer.config.sampler_constraints import ColumnConstraintInputT
from data_designer.config.seed import SeedConfig


Expand All @@ -39,6 +39,6 @@ class DataDesignerConfig(ExportableConfigBase):
model_configs: list[ModelConfig] | None = None
tool_configs: list[ToolConfig] | None = None
seed_config: SeedConfig | None = None
constraints: list[ColumnConstraintT] | None = None
constraints: list[ColumnConstraintInputT] | None = None
profilers: list[ColumnProfilerConfigT] | None = None
processors: list[Annotated[ProcessorConfigT, Field(discriminator="processor_type")]] | None = None
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from __future__ import annotations

from abc import ABC, abstractmethod
import math
from abc import ABC
from enum import Enum
from typing import Annotated, Any, Literal

from pydantic import Discriminator, Field, Tag
from typing_extensions import TypeAlias

from data_designer.config.base import ConfigBase
Expand All @@ -24,29 +27,62 @@ class InequalityOperator(str, Enum):


class Constraint(ConfigBase, ABC):
target_column: str
"""Base class for sampler constraints. Use a concrete subclass, not this class directly."""

@property
@abstractmethod
def constraint_type(self) -> ConstraintType: ...
target_column: str = Field(description="Name of the sampler column this constraint applies to")
constraint_type: ConstraintType = Field(description="Constraint type discriminator")


class ScalarInequalityConstraint(Constraint):
    """Constrains a sampler column against a fixed scalar value via an inequality operator."""

    rhs: float = Field(description="Scalar value to compare against")
    operator: InequalityOperator = Field(description="Comparison operator")
    constraint_type: Literal[ConstraintType.SCALAR_INEQUALITY] = Field(
        default=ConstraintType.SCALAR_INEQUALITY,
        description="Constraint type discriminator, always 'scalar_inequality' for this constraint",
    )

@property
def constraint_type(self) -> ConstraintType:

class ColumnInequalityConstraint(Constraint):
    """Constrains a sampler column relative to another column via an inequality operator."""

    rhs: str = Field(description="Name of the other column to compare against")
    operator: InequalityOperator = Field(description="Comparison operator")
    constraint_type: Literal[ConstraintType.COLUMN_INEQUALITY] = Field(
        default=ConstraintType.COLUMN_INEQUALITY,
        description="Constraint type discriminator, always 'column_inequality' for this constraint",
    )


# Plain union for engine-layer type hints on already-validated constraint instances.
ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint


def resolve_constraint_input_type(value: Any) -> ConstraintType | str | None:
    """Resolve the constraint type for both tagged and legacy config shapes."""
    if not isinstance(value, dict):
        # Already-built model (or arbitrary object): read its discriminator, if any.
        return getattr(value, "constraint_type", None)

    explicit = value.get("constraint_type")
    if explicit is not None:
        return explicit

    rhs = value.get("rhs")
    if isinstance(rhs, str):
        # A numeric-looking string is a scalar bound; anything else names a column.
        if _can_coerce_to_float(rhs):
            return ConstraintType.SCALAR_INEQUALITY
        return ConstraintType.COLUMN_INEQUALITY

    # rhs is required on both concrete types, so when it's missing we default to
    # SCALAR_INEQUALITY — Pydantic will surface a clear "rhs field required" error.
    # Non-string rhs values (ints/floats) are likewise scalar bounds.
    return ConstraintType.SCALAR_INEQUALITY

class ColumnInequalityConstraint(Constraint):
rhs: str
operator: InequalityOperator

@property
def constraint_type(self) -> ConstraintType:
return ConstraintType.COLUMN_INEQUALITY
def _can_coerce_to_float(value: str) -> bool:
try:
result = float(value)
except ValueError:
return False
return math.isfinite(result)


ColumnConstraintT: TypeAlias = ScalarInequalityConstraint | ColumnInequalityConstraint
# Discriminated union for deserialization — supports both tagged and legacy config shapes.
# resolve_constraint_input_type supplies the Tag: an explicit constraint_type wins; otherwise
# a numeric (or missing) rhs maps to scalar_inequality and a non-numeric string rhs to
# column_inequality.
ColumnConstraintInputT: TypeAlias = Annotated[
    Annotated[ScalarInequalityConstraint, Tag(ConstraintType.SCALAR_INEQUALITY)]
    | Annotated[ColumnInequalityConstraint, Tag(ConstraintType.COLUMN_INEQUALITY)],
    Discriminator(resolve_constraint_input_type),
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any
from typing import Any, Literal

from pydantic import Field, field_serializer, model_validator
from typing_extensions import Self, TypeAlias
Expand All @@ -29,6 +29,10 @@ class CodeValidatorParams(ConfigBase):
`sql:sqlite`, `sql:postgres`, `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`.
"""

validator_type: Literal[ValidatorType.CODE] = Field(
default=ValidatorType.CODE,
description="Validator type discriminator, always 'code' for this validator",
)
code_lang: CodeLang = Field(description="The language of the code to validate")

@model_validator(mode="after")
Expand All @@ -50,6 +54,10 @@ class LocalCallableValidatorParams(ConfigBase):
the output will not be validated.
"""

validator_type: Literal[ValidatorType.LOCAL_CALLABLE] = Field(
default=ValidatorType.LOCAL_CALLABLE,
description="Validator type discriminator, always 'local_callable' for this validator",
)
validation_function: Any = Field(
description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
)
Expand Down Expand Up @@ -81,6 +89,10 @@ class RemoteValidatorParams(ConfigBase):
max_parallel_requests: The maximum number of parallel requests to make. Defaults to 4.
"""

validator_type: Literal[ValidatorType.REMOTE] = Field(
default=ValidatorType.REMOTE,
description="Validator type discriminator, always 'remote' for this validator",
)
endpoint_url: str = Field(description="URL of the remote endpoint")
output_schema: dict[str, Any] | None = Field(
default=None, description="Expected schema for remote validator's output"
Expand Down
23 changes: 23 additions & 0 deletions packages/data-designer-config/tests/config/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,29 @@ def test_validation_column_config():
assert validation_column_config.batch_size == 5


def test_validation_column_config_injects_validator_type_into_params_dict() -> None:
    # A bare params dict (no discriminator) should be upgraded via the column's validator_type.
    config = ValidationColumnConfig(
        name="test_validation",
        target_columns=["test_column"],
        validator_type="code",
        validator_params={"code_lang": "python"},
    )

    params = config.validator_params
    assert isinstance(params, CodeValidatorParams)
    assert params.validator_type == "code"
    assert params.code_lang == CodeLang.PYTHON


def test_validation_column_config_schema_uses_validator_discriminator() -> None:
    # The JSON schema should expose validator_params as a discriminated union.
    schema = ValidationColumnConfig.model_json_schema()
    discriminator = schema["properties"]["validator_params"]["discriminator"]

    assert discriminator["propertyName"] == "validator_type"
    for expected_tag in ("code", "local_callable", "remote"):
        assert expected_tag in discriminator["mapping"]


def test_embedding_column_config():
embedding_column_config = EmbeddingColumnConfig(
name="test_embedding",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import json
import tempfile

import pytest
import yaml

from data_designer.config.data_designer_config import DataDesignerConfig


def test_data_designer_config_to_dict(stub_data_designer_config):
assert isinstance(stub_data_designer_config.to_dict(), dict)
Expand All @@ -27,3 +30,104 @@ def test_data_designer_config_to_json(stub_data_designer_config):
assert result is None
with open(tmp_file.name, "r") as f:
assert json.loads(f.read()) == stub_data_designer_config.to_dict()


def test_data_designer_config_parses_constraint_type_from_legacy_shape() -> None:
    # Legacy constraints carry no constraint_type; the discriminator must infer it from rhs.
    columns = [
        {
            "name": "age",
            "column_type": "sampler",
            "sampler_type": "uniform",
            "params": {"low": 18, "high": 99},
        }
    ]
    constraints = [
        {"target_column": "age", "operator": "lt", "rhs": 65},
        {"target_column": "age", "operator": "le", "rhs": "65"},
        {"target_column": "age", "operator": "gt", "rhs": "minimum_age"},
    ]

    config = DataDesignerConfig.model_validate({"columns": columns, "constraints": constraints})

    dumped = [constraint.model_dump(mode="json") for constraint in config.constraints]
    assert dumped == [
        {"target_column": "age", "operator": "lt", "rhs": 65.0, "constraint_type": "scalar_inequality"},
        {"target_column": "age", "operator": "le", "rhs": 65.0, "constraint_type": "scalar_inequality"},
        {"target_column": "age", "operator": "gt", "rhs": "minimum_age", "constraint_type": "column_inequality"},
    ]


def test_data_designer_config_parses_constraint_type_from_tagged_shape() -> None:
    # Explicitly tagged constraints should round-trip unchanged.
    payload = {
        "columns": [
            {
                "name": "age",
                "column_type": "sampler",
                "sampler_type": "uniform",
                "params": {"low": 18, "high": 99},
            }
        ],
        "constraints": [
            {"target_column": "age", "operator": "lt", "rhs": 65.0, "constraint_type": "scalar_inequality"},
            {"target_column": "age", "operator": "gt", "rhs": "minimum_age", "constraint_type": "column_inequality"},
        ],
    }

    config = DataDesignerConfig.model_validate(payload)

    dumped = [constraint.model_dump(mode="json") for constraint in config.constraints]
    assert dumped == [
        {"target_column": "age", "operator": "lt", "rhs": 65.0, "constraint_type": "scalar_inequality"},
        {"target_column": "age", "operator": "gt", "rhs": "minimum_age", "constraint_type": "column_inequality"},
    ]


def test_data_designer_config_constraint_missing_rhs_raises_validation_error() -> None:
    # A constraint with no rhs should still be routed to a concrete type, where
    # pydantic reports the missing required field.
    payload = {
        "columns": [
            {
                "name": "age",
                "column_type": "sampler",
                "sampler_type": "uniform",
                "params": {"low": 18, "high": 99},
            }
        ],
        "constraints": [
            {"target_column": "age", "operator": "lt"},
        ],
    }

    with pytest.raises(Exception):
        DataDesignerConfig.model_validate(payload)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def test_scalar_inequality_constraint():
assert constraint.rhs == 1
assert constraint.operator == InequalityOperator.LT
assert constraint.constraint_type == ConstraintType.SCALAR_INEQUALITY
assert constraint.model_dump() == {
"target_column": "test",
"rhs": 1.0,
"operator": "lt",
"constraint_type": "scalar_inequality",
}


def test_column_inequality_constraint():
Expand All @@ -23,3 +29,9 @@ def test_column_inequality_constraint():
assert constraint.rhs == "test2"
assert constraint.operator == InequalityOperator.LT
assert constraint.constraint_type == ConstraintType.COLUMN_INEQUALITY
assert constraint.model_dump() == {
"target_column": "test",
"rhs": "test2",
"operator": "lt",
"constraint_type": "column_inequality",
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
CodeValidatorParams,
LocalCallableValidatorParams,
RemoteValidatorParams,
ValidatorType,
)

if TYPE_CHECKING:
import pandas as pd


def test_code_validator_params():
assert CodeValidatorParams(code_lang=CodeLang.PYTHON).code_lang == CodeLang.PYTHON
params = CodeValidatorParams(code_lang=CodeLang.PYTHON)
assert params.code_lang == CodeLang.PYTHON
assert params.validator_type == ValidatorType.CODE

with pytest.raises(ValidationError):
CodeValidatorParams(code_lang=CodeLang.RUBY)
Expand All @@ -30,6 +33,7 @@ def test_code_validator_params():
def test_remote_validator_params():
stub_url = "https://example.com"
params = RemoteValidatorParams(endpoint_url=stub_url)
assert params.validator_type == ValidatorType.REMOTE
assert params.endpoint_url == stub_url
assert params.output_schema is None
assert params.timeout == 30.0
Expand All @@ -52,8 +56,10 @@ def stub_callback(df: pd.DataFrame) -> pd.DataFrame:
return lazy.pd.DataFrame([{"is_valid": True, "confidence": "0.98"}])

params = LocalCallableValidatorParams(validation_function=stub_callback)
assert params.validator_type == ValidatorType.LOCAL_CALLABLE
assert params.validation_function == stub_callback
assert params.output_schema is None

params_model_dump = params.model_dump(mode="json")
assert params_model_dump["validator_type"] == "local_callable"
assert params_model_dump["validation_function"] == "stub_callback"
Loading
Loading