2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
@@ -31,8 +31,6 @@ jobs:
run: rustup show
- name: Cache Rust dependencies
uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
- name: Install repository
run: pixi run -e default postinstall
- name: pre-commit
run: pixi run pre-commit-run --color=always --show-diff-on-failure

4 changes: 2 additions & 2 deletions README.md
@@ -52,12 +52,12 @@ class HouseSchema(dy.Schema):
price = dy.Float64(nullable=False)

@dy.rule()
def reasonable_bathroom_to_bedroom_ratio() -> pl.Expr:
def reasonable_bathroom_to_bedroom_ratio(cls) -> pl.Expr:
ratio = pl.col("num_bathrooms") / pl.col("num_bedrooms")
return (ratio >= 1 / 3) & (ratio <= 3)

@dy.rule(group_by=["zip_code"])
def minimum_zip_code_count() -> pl.Expr:
def minimum_zip_code_count(cls) -> pl.Expr:
return pl.len() >= 2
```

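For reference, a minimal sketch of what the updated README example looks like when exercised end to end. The columns other than `price` and the `validate(..., cast=True)` call are assumptions based on the surrounding README context, not part of this hunk.

```python
import dataframely as dy
import polars as pl


class HouseSchema(dy.Schema):
    # Columns other than `price` are assumed from the README context.
    zip_code = dy.String(nullable=False)
    num_bedrooms = dy.Int64(nullable=False)
    num_bathrooms = dy.Int64(nullable=False)
    price = dy.Float64(nullable=False)

    @dy.rule()
    def reasonable_bathroom_to_bedroom_ratio(cls) -> pl.Expr:
        ratio = pl.col("num_bathrooms") / pl.col("num_bedrooms")
        return (ratio >= 1 / 3) & (ratio <= 3)

    @dy.rule(group_by=["zip_code"])
    def minimum_zip_code_count(cls) -> pl.Expr:
        return pl.len() >= 2


df = pl.DataFrame(
    {
        "zip_code": ["01234", "01234"],
        "num_bedrooms": [2, 3],
        "num_bathrooms": [1, 2],
        "price": [350_000.0, 420_000.0],
    }
)
validated = HouseSchema.validate(df, cast=True)  # assumed public Schema API
```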
19 changes: 12 additions & 7 deletions dataframely/_base_schema.py
@@ -12,7 +12,7 @@

import polars as pl

from ._rule import DtypeCastRule, GroupRule, Rule
from ._rule import DtypeCastRule, GroupRule, Rule, RuleFactory
from .columns import Column
from .exc import ImplementationError

@@ -81,7 +81,7 @@ class Metadata:
"""Utility class to gather columns and rules associated with a schema."""

columns: dict[str, Column] = field(default_factory=dict)
rules: dict[str, Rule] = field(default_factory=dict)
rules: dict[str, RuleFactory] = field(default_factory=dict)

def update(self, other: Self) -> None:
self.columns.update(other.columns)
@@ -102,7 +102,11 @@ def __new__(
result.update(mcs._get_metadata_recursively(base))
result.update(mcs._get_metadata(namespace))
namespace[_COLUMN_ATTR] = result.columns
namespace[_RULE_ATTR] = result.rules
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

# Assign rules retroactively as we only encounter rule factories in the result
rules = {name: factory.make(cls) for name, factory in result.rules.items()}
setattr(cls, _RULE_ATTR, rules)

# At this point, we already know all columns and custom rules. We want to run
# some checks...
@@ -111,7 +115,7 @@ def __new__(
# we assume that users cast dtypes, i.e. additional rules for dtype casting
# are also checked.
all_column_names = set(result.columns)
all_rule_names = set(_build_rules(result.rules, result.columns, with_cast=True))
all_rule_names = set(_build_rules(rules, result.columns, with_cast=True))
common_names = all_column_names & all_rule_names
if len(common_names) > 0:
common_list = ", ".join(sorted(f"'{col}'" for col in common_names))
@@ -121,7 +125,7 @@ def __new__(
)

# 2) Check that the columns referenced in the group rules exist.
for rule_name, rule in result.rules.items():
for rule_name, rule in rules.items():
if isinstance(rule, GroupRule):
missing_columns = set(rule.group_columns) - set(result.columns)
if len(missing_columns) > 0:
@@ -138,6 +142,7 @@ def __new__(
for attr, value in namespace.items():
if attr.startswith("__"):
continue

# Check for tuple of column (commonly caused by trailing comma)
if (
isinstance(value, tuple)
@@ -157,7 +162,7 @@ def __new__(
f"Did you forget to add parentheses?"
)

return super().__new__(mcs, name, bases, namespace, *args, **kwargs)
return cls

def __getattribute__(cls, name: str) -> Any:
val = super().__getattribute__(name)
@@ -182,7 +187,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
}.items():
if isinstance(value, Column):
result.columns[value.alias or attr] = value
if isinstance(value, Rule):
if isinstance(value, RuleFactory):
# We must ensure that custom rules do not clash with internal rules.
if attr == "primary_key":
raise ImplementationError(
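The key change in `__new__` above is ordering: the class object must exist before rules can be materialized, because each rule factory now receives the schema class. A minimal, self-contained sketch of that pattern (names such as `SimpleRuleFactory` and `SchemaMeta` are illustrative and not part of dataframely):

```python
from typing import Any, Callable


class SimpleRuleFactory:
    """Stand-in for dataframely's RuleFactory: defers binding to the schema class."""

    def __init__(self, fn: Callable[[type], str]) -> None:
        self.fn = fn

    def make(self, schema: type) -> str:
        # dataframely returns a Rule whose expression closes over `schema`;
        # a plain string keeps this sketch dependency-free.
        return self.fn(schema)


class SchemaMeta(type):
    def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> type:
        factories = {
            attr: value
            for attr, value in namespace.items()
            if isinstance(value, SimpleRuleFactory)
        }
        # Create the class first ...
        cls = super().__new__(mcs, name, bases, namespace)
        # ... then bind the rules retroactively, mirroring the hunk above.
        cls._rules = {attr: factory.make(cls) for attr, factory in factories.items()}
        return cls


class Example(metaclass=SchemaMeta):
    greet = SimpleRuleFactory(lambda schema: f"rule defined on {schema.__name__}")


print(Example._rules)  # {'greet': 'rule defined on Example'}
```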
48 changes: 40 additions & 8 deletions dataframely/_rule.py
@@ -15,13 +15,13 @@
else:
from typing_extensions import Self

ValidationFunction = Callable[[], pl.Expr]
ValidationFunction = Callable[[Any], pl.Expr]


class Rule:
"""Internal class representing validation rules."""

def __init__(self, expr: pl.Expr | ValidationFunction) -> None:
def __init__(self, expr: pl.Expr | Callable[[], pl.Expr]) -> None:
self._expr = expr

@property
@@ -71,7 +71,7 @@ class GroupRule(Rule):
"""Rule that is evaluated on a group of columns."""

def __init__(
self, expr: pl.Expr | ValidationFunction, group_columns: list[str]
self, expr: pl.Expr | Callable[[], pl.Expr], group_columns: list[str]
) -> None:
super().__init__(expr)
self.group_columns = group_columns
@@ -92,7 +92,41 @@ def __repr__(self) -> str:
return f"{super().__repr__()} grouped by {self.group_columns}"


def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction], Rule]:
# -------------------------------------- FACTORY ------------------------------------- #


class RuleFactory:
"""Factory class for rules created within schemas."""

def __init__(
self, validation_fn: Callable[[Any], pl.Expr], group_columns: list[str] | None
) -> None:
self.validation_fn = validation_fn
self.group_columns = group_columns

@classmethod
def from_rule(cls, rule: Rule) -> Self:
"""Create a rule factory from an existing rule."""
if isinstance(rule, GroupRule):
return cls(
validation_fn=lambda _: rule.expr,
group_columns=rule.group_columns,
)
return cls(validation_fn=lambda _: rule.expr, group_columns=None)

def make(self, schema: Any) -> Rule:
"""Create a new rule from this factory."""
if self.group_columns is not None:
return GroupRule(
expr=lambda: self.validation_fn(schema),
group_columns=self.group_columns,
)
return Rule(expr=lambda: self.validation_fn(schema))


def rule(
*, group_by: list[str] | None = None
) -> Callable[[ValidationFunction], RuleFactory]:
"""Mark a function as a rule to evaluate during validation.

The name of the function will be used as the name of the rule. The function should
@@ -128,10 +162,8 @@ def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction],
and (de-)serialization.
"""

def decorator(validation_fn: ValidationFunction) -> Rule:
if group_by is not None:
return GroupRule(expr=validation_fn, group_columns=group_by)
return Rule(expr=validation_fn)
def decorator(validation_fn: ValidationFunction) -> RuleFactory:
return RuleFactory(validation_fn=validation_fn, group_columns=group_by)

return decorator

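To make the indirection concrete, a hedged usage sketch of the new factory, based only on the code visible in this diff; `DummySchema` and the direct import from the private `dataframely._rule` module are for illustration only:

```python
import polars as pl

from dataframely._rule import GroupRule, Rule, RuleFactory


class DummySchema:  # stands in for a dy.Schema subclass
    pass


# What `@dy.rule(group_by=["invoice_id"])` now produces: a factory that holds the
# undecorated function, which still expects the schema class as its argument.
factory = RuleFactory(
    validation_fn=lambda cls: pl.col("is_main").sum() == 1,
    group_columns=["invoice_id"],
)

# The schema metaclass later binds the factory to the concrete schema class.
rule = factory.make(DummySchema)
assert isinstance(rule, GroupRule)

# Round-tripping an existing Rule (e.g. after deserialization) also yields a factory.
plain = Rule(expr=pl.col("price") > 0)
rebuilt = RuleFactory.from_rule(plain).make(DummySchema)
assert isinstance(rebuilt, Rule)
```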
66 changes: 0 additions & 66 deletions dataframely/mypy.py

This file was deleted.

7 changes: 5 additions & 2 deletions dataframely/schema.py
@@ -24,7 +24,7 @@
from ._native import format_rule_failures
from ._plugin import all_rules, all_rules_horizontal, all_rules_required
from ._polars import collect_if
from ._rule import Rule, rule_from_dict, with_evaluation_rules
from ._rule import Rule, RuleFactory, rule_from_dict, with_evaluation_rules
from ._serialization import (
SERIALIZATION_FORMAT_VERSION,
SchemaJSONDecoder,
@@ -1377,7 +1377,10 @@ def _schema_from_dict(data: dict[str, Any]) -> type[Schema]:
(Schema,),
{
**{name: column_from_dict(col) for name, col in data["columns"].items()},
**{name: rule_from_dict(rule) for name, rule in data["rules"].items()},
**{
name: RuleFactory.from_rule(rule_from_dict(rule))
for name, rule in data["rules"].items()
},
},
)

14 changes: 10 additions & 4 deletions dataframely/testing/factory.py
@@ -4,7 +4,7 @@
from typing import Any

from dataframely._filter import Filter
from dataframely._rule import Rule
from dataframely._rule import Rule, RuleFactory
from dataframely._typing import LazyFrame
from dataframely.collection import Collection
from dataframely.columns import Column
@@ -14,7 +14,7 @@
def create_schema(
name: str,
columns: dict[str, Column],
rules: dict[str, Rule] | None = None,
rules: dict[str, Rule | RuleFactory] | None = None,
) -> type[Schema]:
"""Dynamically create a new schema with the provided name.

@@ -23,12 +23,18 @@ def create_schema(
columns: The columns to set on the schema. When properly defining the schema,
this would be the annotations that define the column types.
rules: The custom non-column-specific validation rules. When properly defining
the schema, this would be the functions annotated with `@dy.rule`.
the schema, this would be the functions annotated with ``@dy.rule``.

Returns:
The dynamically created schema.
"""
return type(name, (Schema,), {**columns, **(rules or {})})
rule_factories = {
rule_name: (
rule if isinstance(rule, RuleFactory) else RuleFactory.from_rule(rule)
)
for rule_name, rule in (rules or {}).items()
}
return type(name, (Schema,), {**columns, **rule_factories})


def create_collection(
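A short, hedged sketch of the widened `create_schema` signature: plain `Rule` objects are now accepted and wrapped into `RuleFactory` before the schema type is built dynamically. The `validate` call at the end is an assumption about the public `Schema` API, not something changed in this PR.

```python
import polars as pl

import dataframely as dy
from dataframely._rule import Rule
from dataframely.testing.factory import create_schema

schema = create_schema(
    "PriceSchema",
    columns={"price": dy.Float64(nullable=False)},
    # A plain Rule still works; create_schema wraps it via RuleFactory.from_rule.
    rules={"positive_price": Rule(expr=pl.col("price") > 0)},
)

validated = schema.validate(pl.DataFrame({"price": [1.0, 2.5]}))  # assumed API
```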
18 changes: 9 additions & 9 deletions docs/guides/examples/real-world.ipynb
@@ -79,7 +79,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -91,11 +91,11 @@
" amount = dy.Decimal(nullable=False, min_exclusive=Decimal(0))\n",
"\n",
" @dy.rule()\n",
" def discharge_after_admission() -> pl.Expr:\n",
" def discharge_after_admission(cls) -> pl.Expr:\n",
" return pl.col(\"discharge_date\") >= pl.col(\"admission_date\")\n",
"\n",
" @dy.rule()\n",
" def received_at_after_discharge() -> pl.Expr:\n",
" def received_at_after_discharge(cls) -> pl.Expr:\n",
" return pl.col(\"received_at\").dt.date() >= pl.col(\"discharge_date\")"
]
},
@@ -318,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -328,7 +328,7 @@
" is_main = dy.Bool(nullable=False)\n",
"\n",
" @dy.rule(group_by=[\"invoice_id\"])\n",
" def exactly_one_main_diagnosis() -> pl.Expr:\n",
" def exactly_one_main_diagnosis(cls) -> pl.Expr:\n",
" return pl.col(\"is_main\").sum() == 1"
]
},
@@ -351,7 +351,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -368,11 +368,11 @@
" amount = dy.Decimal(nullable=False, min_exclusive=Decimal(0))\n",
"\n",
" @dy.rule()\n",
" def discharge_after_admission() -> pl.Expr:\n",
" def discharge_after_admission(cls) -> pl.Expr:\n",
" return pl.col(\"discharge_date\") >= pl.col(\"admission_date\")\n",
"\n",
" @dy.rule()\n",
" def received_at_after_discharge() -> pl.Expr:\n",
" def received_at_after_discharge(cls) -> pl.Expr:\n",
" return pl.col(\"received_at\").dt.date() >= pl.col(\"discharge_date\")\n",
"\n",
"\n",
@@ -381,7 +381,7 @@
" is_main = dy.Bool(nullable=False)\n",
"\n",
" @dy.rule(group_by=[\"invoice_id\"])\n",
" def exactly_one_main_diagnosis() -> pl.Expr:\n",
" def exactly_one_main_diagnosis(cls) -> pl.Expr:\n",
" return pl.col(\"is_main\").sum() == 1"
]
},
4 changes: 2 additions & 2 deletions docs/guides/faq.md
@@ -20,12 +20,12 @@ class UserSchema(dy.Schema):
email = dy.String(nullable=True) # Must be unique, or null.

@dy.rule(group_by=["username"])
def unique_username() -> pl.Expr:
def unique_username(cls) -> pl.Expr:
"""Username, a non-nullable field, must be total unique."""
return pl.len() == 1

@dy.rule()
def unique_email_or_null() -> pl.Expr:
def unique_email_or_null(cls) -> pl.Expr:
"""Email must be unique, if provided."""
return pl.col("email").is_null() | pl.col("email").is_unique()
```
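A brief illustration of the two FAQ rules side by side; `UserSchema.is_valid` is assumed from dataframely's public `Schema` API rather than shown in this diff:

```python
import polars as pl

unique_users = pl.DataFrame(
    {"username": ["ada", "bob"], "email": ["ada@example.com", None]}
)
duplicate_users = pl.DataFrame(
    {"username": ["ada", "ada"], "email": [None, None]}
)

print(UserSchema.is_valid(unique_users))     # True: usernames unique, emails unique or null
print(UserSchema.is_valid(duplicate_users))  # False: `unique_username` fails for "ada"
```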