diff --git a/dataframely/_deprecation.py b/dataframely/_deprecation.py
index 3be8383..290564a 100644
--- a/dataframely/_deprecation.py
+++ b/dataframely/_deprecation.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import os
-import warnings
 from collections.abc import Callable
 from functools import wraps
 
@@ -26,14 +25,3 @@ def wrapper() -> None:
         return wrapper
 
     return decorator
-
-
-@skip_if(env="DATAFRAMELY_NO_FUTURE_WARNINGS")
-def warn_nullable_default_change() -> None:
-    warnings.warn(
-        "The 'nullable' argument was not explicitly set. In a future release, "
-        "'nullable=False' will be the default if 'nullable' is not specified. "
-        "Explicitly set 'nullable=True' if you want your column to be nullable.",
-        FutureWarning,
-        stacklevel=4,
-    )
diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py
index 0ab8e52..686f6c5 100644
--- a/dataframely/columns/_base.py
+++ b/dataframely/columns/_base.py
@@ -13,9 +13,6 @@
 import polars as pl
 
 from dataframely._compat import pa, sa, sa_TypeEngine
-from dataframely._deprecation import (
-    warn_nullable_default_change,
-)
 from dataframely._polars import PolarsDataType
 from dataframely.random import Generator
 
@@ -45,7 +42,7 @@ class Column(ABC):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         check: Check | None = None,
         alias: str | None = None,
@@ -55,8 +52,6 @@ def __init__(
         Args:
             nullable: Whether this column may contain null values. Explicitly set
                 `nullable=True` if you want your column to be nullable.
-                In a future release, `nullable=False` will be the default if `nullable`
-                is not specified.
             primary_key: Whether this column is part of the primary key of the schema.
                 If ``True``, ``nullable`` is automatically set to ``False``.
             check: A custom rule or multiple rules to run for this column. This can be:
@@ -80,13 +75,6 @@ def __init__(
         if nullable and primary_key:
             raise ValueError("Nullable primary key columns are not supported.")
 
-        if nullable is None:
-            if primary_key:
-                nullable = False
-            else:
-                warn_nullable_default_change()
-                nullable = True
-
         self.nullable = nullable
         self.primary_key = primary_key
         self.check = check
diff --git a/dataframely/columns/categorical.py b/dataframely/columns/categorical.py
index efcb41a..cdb25d9 100644
--- a/dataframely/columns/categorical.py
+++ b/dataframely/columns/categorical.py
@@ -21,7 +21,7 @@ class Categorical(Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         check: Check | None = None,
         alias: str | None = None,
diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py
index 82fd7d1..e87c86b 100644
--- a/dataframely/columns/datetime.py
+++ b/dataframely/columns/datetime.py
@@ -34,7 +34,7 @@ class Date(OrdinalMixin[dt.date], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: dt.date | None = None,
         min_exclusive: dt.date | None = None,
@@ -157,7 +157,7 @@ class Time(OrdinalMixin[dt.time], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: dt.time | None = None,
         min_exclusive: dt.time | None = None,
@@ -286,7 +286,7 @@ class Datetime(OrdinalMixin[dt.datetime], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: dt.datetime | None = None,
         min_exclusive: dt.datetime | None = None,
@@ -433,7 +433,7 @@ class Duration(OrdinalMixin[dt.timedelta], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: dt.timedelta | None = None,
         min_exclusive: dt.timedelta | None = None,
diff --git a/dataframely/columns/decimal.py b/dataframely/columns/decimal.py
index ae28b1e..b870bbc 100644
--- a/dataframely/columns/decimal.py
+++ b/dataframely/columns/decimal.py
@@ -27,7 +27,7 @@ def __init__(
         precision: int | None = None,
         scale: int = 0,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: decimal.Decimal | None = None,
         min_exclusive: decimal.Decimal | None = None,
diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py
index 665e6ad..f638450 100644
--- a/dataframely/columns/enum.py
+++ b/dataframely/columns/enum.py
@@ -26,7 +26,7 @@ def __init__(
         self,
         categories: pl.Series | Iterable[str] | type[enum.Enum],
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         check: Check | None = None,
         alias: str | None = None,
diff --git a/dataframely/columns/float.py b/dataframely/columns/float.py
index ed357a0..5462562 100644
--- a/dataframely/columns/float.py
+++ b/dataframely/columns/float.py
@@ -26,7 +26,7 @@ class _BaseFloat(OrdinalMixin[float], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         allow_inf_nan: bool = False,
         min: float | None = None,
diff --git a/dataframely/columns/integer.py b/dataframely/columns/integer.py
index 92384d4..2ffc0e4 100644
--- a/dataframely/columns/integer.py
+++ b/dataframely/columns/integer.py
@@ -24,7 +24,7 @@ class _BaseInteger(IsInMixin[int], OrdinalMixin[int], Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min: int | None = None,
         min_exclusive: int | None = None,
diff --git a/dataframely/columns/list.py b/dataframely/columns/list.py
index 273fc1f..8cf1203 100644
--- a/dataframely/columns/list.py
+++ b/dataframely/columns/list.py
@@ -31,7 +31,7 @@ def __init__(
         self,
         inner: Column,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         check: Check | None = None,
         alias: str | None = None,
diff --git a/dataframely/columns/string.py b/dataframely/columns/string.py
index 28f7d48..d98749c 100644
--- a/dataframely/columns/string.py
+++ b/dataframely/columns/string.py
@@ -22,7 +22,7 @@ class String(Column):
     def __init__(
         self,
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         min_length: int | None = None,
         max_length: int | None = None,
diff --git a/dataframely/columns/struct.py b/dataframely/columns/struct.py
index 712fbba..8626aff 100644
--- a/dataframely/columns/struct.py
+++ b/dataframely/columns/struct.py
@@ -29,7 +29,7 @@ def __init__(
         self,
         inner: dict[str, Column],
         *,
-        nullable: bool | None = None,
+        nullable: bool = False,
         primary_key: bool = False,
         check: Check | None = None,
         alias: str | None = None,
diff --git a/tests/column_types/test_array.py b/tests/column_types/test_array.py
index e15b481..4b53911 100644
--- a/tests/column_types/test_array.py
+++ b/tests/column_types/test_array.py
@@ -12,8 +12,8 @@
 @pytest.mark.parametrize(
     "inner",
     [
-        (dy.Int64()),
-        (dy.Integer()),
+        (dy.Int64(nullable=True)),
+        (dy.Integer(nullable=True)),
     ],
 )
 def test_integer_array(inner: Column) -> None:
@@ -29,12 +29,12 @@ def test_integer_array(inner: Column) -> None:
 
 
 def test_invalid_inner_type() -> None:
-    schema = create_schema("test", {"a": dy.Array(dy.Int64(), 1)})
+    schema = create_schema("test", {"a": dy.Array(dy.Int64(nullable=True), 1)})
     assert not schema.is_valid(pl.DataFrame({"a": [["1"], ["2"], ["3"]]}))
 
 
 def test_invalid_shape() -> None:
-    schema = create_schema("test", {"a": dy.Array(dy.Int64(), 2)})
+    schema = create_schema("test", {"a": dy.Array(dy.Int64(nullable=True), 2)})
     assert not schema.is_valid(
         pl.DataFrame(
             {"a": [[1], [2], [3]]},
@@ -49,52 +49,52 @@
     ("column", "dtype", "is_valid"),
     [
         (
-            dy.Array(dy.Int64(), 1),
+            dy.Array(dy.Int64(nullable=True), 1),
             pl.Array(pl.Int64(), 1),
             True,
         ),
         (
-            dy.Array(dy.String(), 1),
+            dy.Array(dy.String(nullable=True), 1),
             pl.Array(pl.Int64(), 1),
             False,
         ),
        (
-            dy.Array(dy.String(), 1),
+            dy.Array(dy.String(nullable=True), 1),
             pl.Array(pl.Int64(), 2),
             False,
         ),
         (
-            dy.Array(dy.Int64(), (1,)),
+            dy.Array(dy.Int64(nullable=True), (1,)),
             pl.Array(pl.Int64(), (1,)),
             True,
         ),
         (
-            dy.Array(dy.Int64(), (1,)),
+            dy.Array(dy.Int64(nullable=True), (1,)),
             pl.Array(pl.Int64(), (2,)),
             False,
         ),
         (
-            dy.Array(dy.String(), 1),
-            dy.Array(dy.String(), 1),
+            dy.Array(dy.String(nullable=True), 1),
+            dy.Array(dy.String(nullable=True), 1),
             False,
         ),
         (
-            dy.Array(dy.String(), 1),
+            dy.Array(dy.String(nullable=True), 1),
             dy.String(),
             False,
         ),
         (
-            dy.Array(dy.String(), 1),
+            dy.Array(dy.String(nullable=True), 1),
             pl.String(),
             False,
         ),
         (
-            dy.Array(dy.Array(dy.String(), 1), 1),
+            dy.Array(dy.Array(dy.String(nullable=True), 1), 1),
             pl.Array(pl.String(), (1, 1)),
             True,
         ),
         (
-            dy.Array(dy.String(), (1, 1)),
+            dy.Array(dy.String(nullable=True), (1, 1)),
             pl.Array(pl.Array(pl.String(), 1), 1),
             True,
         ),
@@ -105,7 +105,9 @@ def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> N
 
 
 def test_nested_arrays() -> None:
-    schema = create_schema("test", {"a": dy.Array(dy.Array(dy.Int64(), 1), 1)})
+    schema = create_schema(
+        "test", {"a": dy.Array(dy.Array(dy.Int64(nullable=True), 1), 1)}
+    )
     assert schema.is_valid(
         pl.DataFrame(
             {"a": [[[1]], [[2]], [[3]]]},
@@ -117,7 +119,9 @@ def test_nested_arrays() -> None:
 
 
 def test_nested_array() -> None:
-    schema = create_schema("test", {"a": dy.Array(dy.Array(dy.Int64(), 1), 1)})
+    schema = create_schema(
+        "test", {"a": dy.Array(dy.Array(dy.Int64(nullable=True), 1), 1)}
+    )
     assert schema.is_valid(
         pl.DataFrame(
             {"a": [[[1]], [[2]], [[3]]]},
@@ -147,7 +151,7 @@ def test_array_with_rules() -> None:
 def test_outer_nullability() -> None:
     schema = create_schema(
         "test",
-        {"nullable": dy.Array(inner=dy.Integer(), shape=1, nullable=True)},
+        {"nullable": dy.Array(inner=dy.Integer(nullable=True), shape=1, nullable=True)},
     )
     df = pl.DataFrame({"nullable": [None, None]})
     schema.validate(df, cast=True)
diff --git a/tests/column_types/test_datetime.py b/tests/column_types/test_datetime.py
index f3aa93f..7c5d96f 100644
--- a/tests/column_types/test_datetime.py
+++ b/tests/column_types/test_datetime.py
@@ -216,47 +216,47 @@ def test_args_resolution_valid(
     ("column", "values", "valid"),
     [
         (
-            dy.Date(min=dt.date(2020, 4, 1)),
+            dy.Date(min=dt.date(2020, 4, 1), nullable=True),
             [dt.date(2020, 3, 31), dt.date(2020, 4, 1), dt.date(9999, 12, 31)],
             {"min": [False, True, True]},
         ),
         (
-            dy.Date(min_exclusive=dt.date(2020, 4, 1)),
+            dy.Date(min_exclusive=dt.date(2020, 4, 1), nullable=True),
             [dt.date(2020, 3, 31), dt.date(2020, 4, 1), dt.date(9999, 12, 31)],
             {"min_exclusive": [False, False, True]},
         ),
         (
-            dy.Date(max=dt.date(2020, 4, 1)),
+            dy.Date(max=dt.date(2020, 4, 1), nullable=True),
             [dt.date(2020, 3, 31), dt.date(2020, 4, 1), dt.date(2020, 4, 2)],
             {"max": [True, True, False]},
         ),
         (
-            dy.Date(max_exclusive=dt.date(2020, 4, 1)),
+            dy.Date(max_exclusive=dt.date(2020, 4, 1), nullable=True),
             [dt.date(2020, 3, 31), dt.date(2020, 4, 1), dt.date(2020, 4, 2)],
             {"max_exclusive": [True, False, False]},
         ),
         (
-            dy.Time(min=dt.time(3)),
+            dy.Time(min=dt.time(3), nullable=True),
             [dt.time(2, 59), dt.time(3, 0, 0), dt.time(4)],
             {"min": [False, True, True]},
         ),
         (
-            dy.Time(min_exclusive=dt.time(3)),
+            dy.Time(min_exclusive=dt.time(3), nullable=True),
             [dt.time(2, 59), dt.time(3, 0, 0), dt.time(4)],
             {"min_exclusive": [False, False, True]},
         ),
         (
-            dy.Time(max=dt.time(11, 59, 59, 999999)),
+            dy.Time(max=dt.time(11, 59, 59, 999999), nullable=True),
             [dt.time(11), dt.time(12), dt.time(13)],
             {"max": [True, False, False]},
         ),
         (
-            dy.Time(max_exclusive=dt.time(12)),
+            dy.Time(max_exclusive=dt.time(12), nullable=True),
             [dt.time(11), dt.time(12), dt.time(13)],
             {"max_exclusive": [True, False, False]},
         ),
         (
-            dy.Datetime(min=dt.datetime(2020, 3, 1, hour=12)),
+            dy.Datetime(min=dt.datetime(2020, 3, 1, hour=12), nullable=True),
             [
                 dt.datetime(2020, 2, 29, hour=14),
                 dt.datetime(2020, 3, 1, hour=11),
@@ -267,7 +267,7 @@ def test_args_resolution_valid(
             {"min": [False, False, True, True, True]},
         ),
         (
-            dy.Datetime(min_exclusive=dt.datetime(2020, 3, 1, hour=12)),
+            dy.Datetime(min_exclusive=dt.datetime(2020, 3, 1, hour=12), nullable=True),
             [
                 dt.datetime(2020, 2, 29, hour=14),
                 dt.datetime(2020, 3, 1, hour=11),
@@ -278,7 +278,7 @@ def test_args_resolution_valid(
             {"min_exclusive": [False, False, False, True, True]},
         ),
         (
-            dy.Datetime(max=dt.datetime(2020, 3, 1, hour=12)),
+            dy.Datetime(max=dt.datetime(2020, 3, 1, hour=12), nullable=True),
             [
                 dt.datetime(2020, 2, 29, hour=14),
                 dt.datetime(2020, 3, 1, hour=11),
@@ -289,7 +289,7 @@ def test_args_resolution_valid(
             {"max": [True, True, True, False, False]},
         ),
         (
-            dy.Datetime(max_exclusive=dt.datetime(2020, 3, 1, hour=12)),
+            dy.Datetime(max_exclusive=dt.datetime(2020, 3, 1, hour=12), nullable=True),
             [
                 dt.datetime(2020, 2, 29, hour=14),
                 dt.datetime(2020, 3, 1, hour=11),
@@ -300,7 +300,7 @@ def test_args_resolution_valid(
             {"max_exclusive": [True, True, False, False, False]},
         ),
         (
-            dy.Duration(min=dt.timedelta(days=1, seconds=14400)),
+            dy.Duration(min=dt.timedelta(days=1, seconds=14400), nullable=True),
             [
                 dt.timedelta(seconds=13000),
                 dt.timedelta(days=1, seconds=14400),
@@ -309,7 +309,9 @@ def test_args_resolution_valid(
             {"min": [False, True, True]},
         ),
         (
-            dy.Duration(min_exclusive=dt.timedelta(days=1, seconds=14400)),
+            dy.Duration(
+                min_exclusive=dt.timedelta(days=1, seconds=14400), nullable=True
+            ),
             [
                 dt.timedelta(seconds=13000),
                 dt.timedelta(days=1, seconds=14400),
@@ -318,7 +320,7 @@ def test_args_resolution_valid(
             {"min_exclusive": [False, False, True]},
         ),
         (
-            dy.Duration(max=dt.timedelta(days=1, seconds=14400)),
+            dy.Duration(max=dt.timedelta(days=1, seconds=14400), nullable=True),
             [
                 dt.timedelta(seconds=13000),
                 dt.timedelta(days=1, seconds=14400),
@@ -327,7 +329,9 @@ def test_args_resolution_valid(
             {"max": [True, True, False]},
         ),
         (
-            dy.Duration(max_exclusive=dt.timedelta(days=1, seconds=14400)),
+            dy.Duration(
+                max_exclusive=dt.timedelta(days=1, seconds=14400), nullable=True
+            ),
             [
                 dt.timedelta(seconds=13000),
                 dt.timedelta(days=1, seconds=14400),
@@ -350,17 +354,17 @@ def test_validate_min_max(
     ("column", "values", "valid"),
     [
         (
-            dy.Date(resolution="1mo"),
+            dy.Date(resolution="1mo", nullable=True),
             [dt.date(2020, 1, 1), dt.date(2021, 1, 15), dt.date(2022, 12, 1)],
             {"resolution": [True, False, True]},
         ),
         (
-            dy.Time(resolution="1h"),
+            dy.Time(resolution="1h", nullable=True),
             [dt.time(12, 0), dt.time(13, 15), dt.time(14, 0, 5)],
             {"resolution": [True, False, False]},
         ),
         (
-            dy.Datetime(resolution="1d"),
+            dy.Datetime(resolution="1d", nullable=True),
             [
                 dt.datetime(2020, 4, 5),
                 dt.datetime(2021, 1, 1, 12),
@@ -369,7 +373,7 @@ def test_validate_min_max(
             {"resolution": [True, False, False]},
         ),
         (
-            dy.Duration(resolution="12h"),
+            dy.Duration(resolution="12h", nullable=True),
             [
                 dt.timedelta(hours=12),
                 dt.timedelta(days=2),
diff --git a/tests/column_types/test_decimal.py b/tests/column_types/test_decimal.py
index 46ff42a..11719e4 100644
--- a/tests/column_types/test_decimal.py
+++ b/tests/column_types/test_decimal.py
@@ -89,7 +89,10 @@ def test_non_decimal_dtype_fails(dtype: DataTypeClass) -> None:
     ],
 )
 def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]) -> None:
-    kwargs = {("min" if inclusive else "min_exclusive"): decimal.Decimal(3)}
+    kwargs = {
+        ("min" if inclusive else "min_exclusive"): decimal.Decimal(3),
+        "nullable": True,
+    }
     column = dy.Decimal(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -105,7 +108,10 @@ def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]) -> None:
     ],
 )
 def test_validate_max(inclusive: bool, valid: dict[str, list[bool]]) -> None:
-    kwargs = {("max" if inclusive else "max_exclusive"): decimal.Decimal(3)}
+    kwargs = {
+        ("max" if inclusive else "max_exclusive"): decimal.Decimal(3),
+        "nullable": True,
+    }
     column = dy.Decimal(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -158,6 +164,7 @@ def test_validate_range(
     kwargs = {
         ("min" if min_inclusive else "min_exclusive"): decimal.Decimal(0),
         ("max" if max_inclusive else "max_exclusive"): decimal.Decimal(2),
+        "nullable": True,
     }
     column = dy.Decimal(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [-1, 0, 1, 2, 3]})
diff --git a/tests/column_types/test_float.py b/tests/column_types/test_float.py
index d214d39..eda77f8 100644
--- a/tests/column_types/test_float.py
+++ b/tests/column_types/test_float.py
@@ -173,7 +173,7 @@ def test_validate_inf_nan(inf: Any, nan: Any) -> None:
 @pytest.mark.parametrize("inf", [np.inf, -np.inf, float("inf"), float("-inf")])
 @pytest.mark.parametrize("nan", [np.nan, float("nan"), float("NaN")])
 def test_validate_allow_inf_nan(inf: Any, nan: Any) -> None:
-    column = dy.Float(allow_inf_nan=True)
+    column = dy.Float(allow_inf_nan=True, nullable=True)
     lf = pl.LazyFrame({"a": pl.Series([inf, 2.0, nan, 4.0, 5.0])})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     assert len(actual.collect_schema().names()) == 0, (
diff --git a/tests/column_types/test_integer.py b/tests/column_types/test_integer.py
index 465c678..5194698 100644
--- a/tests/column_types/test_integer.py
+++ b/tests/column_types/test_integer.py
@@ -76,7 +76,7 @@ def test_non_integer_dtype_fails(dtype: DataTypeClass) -> None:
 @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES)
 @pytest.mark.parametrize("inclusive", [True, False])
 def test_validate_min(column_type: type[_BaseInteger], inclusive: bool) -> None:
-    kwargs = {("min" if inclusive else "min_exclusive"): 3}
+    kwargs = {("min" if inclusive else "min_exclusive"): 3, "nullable": True}
     column = column_type(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -88,7 +88,7 @@ def test_validate_min(column_type: type[_BaseInteger], inclusive: bool) -> None:
 @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES)
 @pytest.mark.parametrize("inclusive", [True, False])
 def test_validate_max(column_type: type[_BaseInteger], inclusive: bool) -> None:
-    kwargs = {("max" if inclusive else "max_exclusive"): 3}
+    kwargs = {("max" if inclusive else "max_exclusive"): 3, "nullable": True}
     column = column_type(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -103,7 +103,7 @@ def test_validate_min_zero(column_type: type[_BaseInteger], inclusive: bool) ->
     """Specific edge case where the minimum is `0`, which can lead to python bugs if
     we use `if value` instead of `if value is not None` somewhere."""
     key = "min" if inclusive else "min_exclusive"
-    kwargs = {key: 0}
+    kwargs = {key: 0, "nullable": True}
     column = column_type(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [-1]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -117,7 +117,7 @@ def test_validate_max_zero(column_type: type[_BaseInteger], inclusive: bool) ->
     """Specific edge case where the maximum is `0`, which can lead to python bugs if
     we use `if value` instead of `if value is not None` somewhere."""
     key = "max" if inclusive else "max_exclusive"
-    kwargs = {key: 0}
+    kwargs = {key: 0, "nullable": True}
     column = column_type(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
@@ -134,6 +134,7 @@ def test_validate_range(
     kwargs = {
         ("min" if min_inclusive else "min_exclusive"): 2,
         ("max" if max_inclusive else "max_exclusive"): 4,
+        "nullable": True,
     }
     column = column_type(**kwargs)  # type: ignore
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
@@ -151,7 +152,7 @@ def test_validate_range(
 
 @pytest.mark.parametrize("column_type", INTEGER_COLUMN_TYPES)
 def test_validate_is_in(column_type: type[_BaseInteger]) -> None:
-    column = column_type(is_in=[3, 5])
+    column = column_type(is_in=[3, 5], nullable=True)
     lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     expected = pl.LazyFrame({"is_in": [False, False, True, False, True]})
diff --git a/tests/column_types/test_list.py b/tests/column_types/test_list.py
index d7401fc..e860b85 100644
--- a/tests/column_types/test_list.py
+++ b/tests/column_types/test_list.py
@@ -62,7 +62,7 @@ def test_nested_lists() -> None:
 def test_list_with_pk() -> None:
     schema = create_schema(
         "test",
-        {"a": dy.List(dy.String(), primary_key=True)},
+        {"a": dy.List(dy.String(nullable=True), primary_key=True)},
     )
     df = pl.DataFrame({"a": [["ab"], ["a", "ab"], [None], ["a", "b"], ["a", "b"]]})
     _, failures = schema.filter(df)
diff --git a/tests/column_types/test_string.py b/tests/column_types/test_string.py
index 1003102..1846ab8 100644
--- a/tests/column_types/test_string.py
+++ b/tests/column_types/test_string.py
@@ -9,7 +9,7 @@
 
 
 def test_validate_min_length() -> None:
-    column = dy.String(min_length=2)
+    column = dy.String(min_length=2, nullable=True)
     lf = pl.LazyFrame({"a": ["foo", "x"]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     expected = pl.LazyFrame({"min_length": [True, False]})
@@ -17,7 +17,7 @@
 
 
 def test_validate_max_length() -> None:
-    column = dy.String(max_length=2)
+    column = dy.String(max_length=2, nullable=True)
     lf = pl.LazyFrame({"a": ["foo", "x"]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     expected = pl.LazyFrame({"max_length": [False, True]})
@@ -25,7 +25,7 @@
 
 
 def test_validate_regex() -> None:
-    column = dy.String(regex="[0-9][a-z]$")
+    column = dy.String(regex="[0-9][a-z]$", nullable=True)
     lf = pl.LazyFrame({"a": ["33x", "3x", "44"]})
     actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
     expected = pl.LazyFrame({"regex": [True, True, False]})
diff --git a/tests/columns/test_default_dtypes.py b/tests/columns/test_default_dtypes.py
index 278b7e8..2d95920 100644
--- a/tests/columns/test_default_dtypes.py
+++ b/tests/columns/test_default_dtypes.py
@@ -36,7 +36,7 @@
         (dy.UInt64(), pl.UInt64()),
         (dy.String(), pl.String()),
         (dy.List(dy.String()), pl.List(pl.String())),
-        (dy.Array(dy.String(), 1), pl.Array(pl.String(), 1)),
+        (dy.Array(dy.String(nullable=True), 1), pl.Array(pl.String(), 1)),
         (dy.Struct({"a": dy.String()}), pl.Struct({"a": pl.String()})),
         (dy.Enum(["a", "b"]), pl.Enum(["a", "b"])),
         (dy.Categorical(), pl.Categorical()),
diff --git a/tests/columns/test_matches.py b/tests/columns/test_matches.py
index babee98..b8731de 100644
--- a/tests/columns/test_matches.py
+++ b/tests/columns/test_matches.py
@@ -46,7 +46,11 @@
             dy.String(check=[lambda x: x == "a"]),
             False,
         ),
-        (dy.Array(dy.Int32(), shape=(2, 2)), dy.Array(dy.Int32(), shape=(2, 2)), True),
+        (
+            dy.Array(dy.Int32(nullable=True), shape=(2, 2)),
+            dy.Array(dy.Int32(nullable=True), shape=(2, 2)),
+            True,
+        ),
         (dy.List(dy.Int32()), dy.List(dy.Int32()), True),
         (
             dy.Struct({"a": dy.Int32(check=lambda expr: expr > 4)}),
diff --git a/tests/columns/test_pyarrow.py b/tests/columns/test_pyarrow.py
index e9114b5..d2ef1e5 100644
--- a/tests/columns/test_pyarrow.py
+++ b/tests/columns/test_pyarrow.py
@@ -1,6 +1,8 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import TypeVar
+
 import pytest
 from polars._typing import TimeUnit
 
@@ -17,9 +19,19 @@
 pytestmark = pytest.mark.with_optionals
 
 
+T = TypeVar("T", bound=dy.Column)
+
+
+def _nullable(column_type: type[T]) -> T:
+    # dy.Any doesn't have the `nullable` parameter.
+    if column_type == dy.Any:
+        return column_type()
+    return column_type(nullable=True)
+
+
 @pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES)
 def test_equal_to_polars_schema(column_type: type[Column]) -> None:
-    schema = create_schema("test", {"a": column_type()})
+    schema = create_schema("test", {"a": _nullable(column_type)})
     actual = schema.pyarrow_schema()
     expected = schema.create_empty().to_arrow().schema
     assert actual == expected
@@ -39,7 +51,7 @@ def test_equal_to_polars_schema(column_type: type[Column]) -> None:
     ],
 )
 def test_equal_polars_schema_enum(categories: list[str]) -> None:
-    schema = create_schema("test", {"a": dy.Enum(categories)})
+    schema = create_schema("test", {"a": dy.Enum(categories, nullable=True)})
     actual = schema.pyarrow_schema()
     expected = schema.create_empty().to_arrow().schema
     assert actual == expected
@@ -49,11 +61,14 @@ def test_equal_polars_schema_enum(categories: list[str]) -> None:
     "inner",
     [c() for c in ALL_COLUMN_TYPES]
     + [dy.List(t()) for t in ALL_COLUMN_TYPES]
-    + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [
+        dy.Array(t() if t == dy.Any else t(nullable=True), 1)
+        for t in NO_VALIDATION_COLUMN_TYPES
+    ]
     + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
 )
 def test_equal_polars_schema_list(inner: Column) -> None:
-    schema = create_schema("test", {"a": dy.List(inner)})
+    schema = create_schema("test", {"a": dy.List(inner, nullable=True)})
     actual = schema.pyarrow_schema()
     expected = schema.create_empty().to_arrow().schema
     assert actual == expected
@@ -61,10 +76,13 @@ def test_equal_polars_schema_list(inner: Column) -> None:
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.List(t()) for t in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.Struct({"a": t()}) for t in NO_VALIDATION_COLUMN_TYPES],
+    [_nullable(c) for c in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [
+        dy.Struct({"a": _nullable(t)}, nullable=True)
+        for t in NO_VALIDATION_COLUMN_TYPES
+    ],
 )
 @pytest.mark.parametrize(
     "shape",
@@ -83,13 +101,16 @@ def test_equal_polars_schema_array(inner: Column, shape: int | tuple[int, ...])
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in ALL_COLUMN_TYPES]
-    + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES]
-    + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.List(t()) for t in ALL_COLUMN_TYPES],
+    [_nullable(c) for c in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [
+        dy.Struct({"a": _nullable(t)}, nullable=True)
+        for t in NO_VALIDATION_COLUMN_TYPES
+    ],
 )
 def test_equal_polars_schema_struct(inner: Column) -> None:
-    schema = create_schema("test", {"a": dy.Struct({"a": inner})})
+    schema = create_schema("test", {"a": dy.Struct({"a": inner}, nullable=True)})
     actual = schema.pyarrow_schema()
     expected = schema.create_empty().to_arrow().schema
     assert actual == expected
@@ -110,10 +131,13 @@ def test_nullability_information_enum(nullable: bool) -> None:
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in ALL_COLUMN_TYPES]
-    + [dy.List(t()) for t in ALL_COLUMN_TYPES]
-    + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
+    [_nullable(c) for c in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [
+        dy.Struct({"a": _nullable(t)}, nullable=True)
+        for t in NO_VALIDATION_COLUMN_TYPES
+    ],
 )
 @pytest.mark.parametrize("nullable", [True, False])
 def test_nullability_information_list(inner: Column, nullable: bool) -> None:
@@ -123,10 +147,13 @@ def test_nullability_information_list(inner: Column, nullable: bool) -> None:
 
 @pytest.mark.parametrize(
     "inner",
-    [c() for c in ALL_COLUMN_TYPES]
-    + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES]
-    + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
-    + [dy.List(t()) for t in ALL_COLUMN_TYPES],
+    [_nullable(c) for c in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.List(_nullable(t), nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [dy.Array(_nullable(t), 1, nullable=True) for t in NO_VALIDATION_COLUMN_TYPES]
+    + [
+        dy.Struct({"a": _nullable(t)}, nullable=True)
+        for t in NO_VALIDATION_COLUMN_TYPES
+    ],
 )
 @pytest.mark.parametrize("nullable", [True, False])
 def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
@@ -135,11 +162,15 @@ def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
 
 
 def test_multiple_columns() -> None:
-    schema = create_schema("test", {"a": dy.Int32(nullable=False), "b": dy.Integer()})
+    schema = create_schema(
+        "test", {"a": dy.Int32(nullable=False), "b": dy.Integer(nullable=True)}
+    )
     assert str(schema.pyarrow_schema()).split("\n") == ["a: int32 not null", "b: int64"]
 
 
 @pytest.mark.parametrize("time_unit", ["ns", "us", "ms"])
 def test_datetime_time_unit(time_unit: TimeUnit) -> None:
-    schema = create_schema("test", {"a": dy.Datetime(time_unit=time_unit)})
+    schema = create_schema(
+        "test", {"a": dy.Datetime(time_unit=time_unit, nullable=True)}
+    )
     assert str(schema.pyarrow_schema()) == f"a: timestamp[{time_unit}]"
diff --git a/tests/columns/test_sample.py b/tests/columns/test_sample.py
index 3bcac16..d088847 100644
--- a/tests/columns/test_sample.py
+++ b/tests/columns/test_sample.py
@@ -176,20 +176,24 @@ def test_sample_enum(generator: Generator) -> None:
 
 
 def test_sample_list(generator: Generator) -> None:
-    column = dy.List(dy.String(regex="[abc]"), min_length=5, max_length=10)
+    column = dy.List(
+        dy.String(regex="[abc]"), nullable=True, min_length=5, max_length=10
+    )
     samples = sample_and_validate(column, generator, n=10_000)
     assert set(samples.list.len()) == set(range(5, 11)) | {None}
 
 
 def test_sample_array(generator: Generator) -> None:
-    column = dy.Array(dy.Bool(), (2, 3))
+    column = dy.Array(dy.Bool(nullable=True), (2, 3))
     samples = sample_and_validate(column, generator, n=10_000)
     assert samples.is_null().any()
     assert set(samples.arr.len()) == {2, None}
 
 
 def test_sample_struct(generator: Generator) -> None:
-    column = dy.Struct({"a": dy.String(regex="[abc]"), "b": dy.String(regex="[a-z]xx")})
+    column = dy.Struct(
+        {"a": dy.String(regex="[abc]"), "b": dy.String(regex="[a-z]xx")}, nullable=True
+    )
     samples = sample_and_validate(column, generator, n=10_000)
     assert samples.is_null().any()
     assert len(samples) == 10_000
diff --git a/tests/columns/test_sql_schema.py b/tests/columns/test_sql_schema.py
index 5b41425..abb54a5 100644
--- a/tests/columns/test_sql_schema.py
+++ b/tests/columns/test_sql_schema.py
@@ -149,7 +149,7 @@ def test_raise_for_array_column(dialect: Dialect) -> None:
     with pytest.raises(
         NotImplementedError, match="SQL column cannot have 'Array' type."
     ):
-        dy.Array(dy.String(), 1).sqlalchemy_dtype(dialect)
+        dy.Array(dy.String(nullable=True), 1).sqlalchemy_dtype(dialect)
 
 
 @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()])
@@ -157,7 +157,7 @@ def test_raise_for_struct_column(dialect: Dialect) -> None:
     with pytest.raises(
         NotImplementedError, match="SQL column cannot have 'Struct' type."
     ):
-        dy.Struct({"a": dy.String()}).sqlalchemy_dtype(dialect)
+        dy.Struct({"a": dy.String(nullable=True)}).sqlalchemy_dtype(dialect)
 
 
 @pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()])
diff --git a/tests/columns/test_str.py b/tests/columns/test_str.py
index 387af51..9520822 100644
--- a/tests/columns/test_str.py
+++ b/tests/columns/test_str.py
@@ -25,7 +25,7 @@ def test_string_representation_list() -> None:
 
 
 def test_string_representation_array() -> None:
-    column = dy.Array(dy.String(), 1)
+    column = dy.Array(dy.String(nullable=True), 1)
     assert str(column) == dy.Array.__name__.lower()
 
 
diff --git a/tests/schema/test_base.py b/tests/schema/test_base.py
index 92b3b1e..8311791 100644
--- a/tests/schema/test_base.py
+++ b/tests/schema/test_base.py
@@ -14,7 +14,7 @@
 class MySchema(dy.Schema):
     a = dy.Integer(primary_key=True)
     b = dy.String(primary_key=True)
-    c = dy.Float64()
+    c = dy.Float64(nullable=True)
     d = dy.Any(alias="e")
 
 
diff --git a/tests/schema/test_filter.py b/tests/schema/test_filter.py
index 5dbbe84..e6be485 100644
--- a/tests/schema/test_filter.py
+++ b/tests/schema/test_filter.py
@@ -17,7 +17,7 @@
 
 class MySchema(dy.Schema):
     a = dy.Int64(primary_key=True)
-    b = dy.String(max_length=3)
+    b = dy.String(max_length=3, nullable=True)
 
 
 @pytest.mark.parametrize(
@@ -106,7 +106,7 @@ def test_filter_failure(
 
 @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
 def test_filter_no_rules(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
-    schema = create_schema("test", {"a": dy.Int64()})
+    schema = create_schema("test", {"a": dy.Int64(nullable=True)})
     df = df_type({"a": [1, 2, 3]})
     df_valid, failures = schema.filter(df)
     assert isinstance(df_valid, pl.DataFrame)
diff --git a/tests/schema/test_repr.py b/tests/schema/test_repr.py
index 63958da..64a0c8e 100644
--- a/tests/schema/test_repr.py
+++ b/tests/schema/test_repr.py
@@ -9,7 +9,7 @@
 
 def test_repr_no_rules() -> None:
     class SchemaNoRules(dy.Schema):
-        a = dy.Integer()
+        a = dy.Integer(nullable=True)
 
     assert repr(SchemaNoRules) == textwrap.dedent("""\
         [Schema "SchemaNoRules"]
@@ -20,7 +20,7 @@ class SchemaNoRules(dy.Schema):
 
 def test_repr_only_column_rules() -> None:
     class SchemaColumnRules(dy.Schema):
-        a = dy.Integer(min=10)
+        a = dy.Integer(min=10, nullable=True)
 
     assert repr(SchemaColumnRules) == textwrap.dedent("""\
         [Schema "SchemaColumnRules"]
@@ -46,8 +46,8 @@ def test_repr_with_rules() -> None:
     assert repr(SchemaWithRules) == textwrap.dedent("""\
         [Schema "SchemaWithRules"]
         Columns:
-        - "a": Integer(nullable=True, min=10)
-        - "b2": String(nullable=False, primary_key=True, regex='^[A-Z]{3}$')
+        - "a": Integer(min=10)
+        - "b2": String(primary_key=True, regex='^[A-Z]{3}$')
         Rules:
         - "my_rule": [(col("a")) < (dyn int: 100)]
         - "my_group_rule": [(col("a").sum()) > (dyn int: 50)] grouped by ['a']
@@ -56,7 +56,7 @@
 
 def test_repr_enum() -> None:
     class SchemaNoRules(dy.Schema):
-        a = dy.Enum(["a"])
+        a = dy.Enum(["a"], nullable=True)
 
     assert repr(SchemaNoRules) == textwrap.dedent("""\
         [Schema "SchemaNoRules"]
diff --git a/tests/schema/test_sample.py b/tests/schema/test_sample.py
index 00ca956..d7e38a5 100644
--- a/tests/schema/test_sample.py
+++ b/tests/schema/test_sample.py
@@ -12,8 +12,8 @@
 
 
 class MySimpleSchema(dy.Schema):
-    a = dy.Int64()
-    b = dy.String()
+    a = dy.Int64(nullable=True)
+    b = dy.String(nullable=True)
 
 
 class PrimaryKeySchema(dy.Schema):
@@ -93,8 +93,8 @@ def _sampling_overrides(cls) -> dict[str, pl.Expr]:
 
 
 class MyAdvancedSchema(dy.Schema):
-    a = dy.Float64(min=20.0)
-    b = dy.String(regex=r"abc*")
+    a = dy.Float64(min=20.0, nullable=True)
+    b = dy.String(regex=r"abc*", nullable=True)
 
 
 # --------------------------------------- TESTS -------------------------------------- #
diff --git a/tests/schema/test_serialization.py b/tests/schema/test_serialization.py
index 46e1204..78e3bd6 100644
--- a/tests/schema/test_serialization.py
+++ b/tests/schema/test_serialization.py
@@ -53,7 +53,7 @@ def test_simple_serialization() -> None:
             {"a": dy.Int64()},
             rules={"test": GroupRule(pl.len() > 2, group_columns=["a"])},
         ),
-        create_schema("test", {"a": dy.Array(dy.Int64(), shape=(2, 2))}),
+        create_schema("test", {"a": dy.Array(dy.Int64(nullable=True), shape=(2, 2))}),
         create_schema("test", {"a": dy.List(dy.Int64(min=5))}),
         create_schema(
             "test",
diff --git a/tests/schema/test_validate.py b/tests/schema/test_validate.py
index c2c7a07..fc33b0f 100644
--- a/tests/schema/test_validate.py
+++ b/tests/schema/test_validate.py
@@ -15,7 +15,7 @@
 class MySchema(dy.Schema):
     a = dy.Int64(primary_key=True)
     b = dy.String(nullable=False, max_length=5)
-    c = dy.String()
+    c = dy.String(nullable=True)
 
 
 class MyComplexSchema(dy.Schema):
diff --git a/tests/test_deprecation.py b/tests/test_deprecation.py
index f77c6af..c6d53fa 100644
--- a/tests/test_deprecation.py
+++ b/tests/test_deprecation.py
@@ -1,46 +1,24 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
 
-import warnings
-from collections.abc import Callable
-
 import pytest
 
-import dataframely as dy
-
-# --------------------- Nullability default change ------------------------------#
-
-
-def deprecated_default_nullable() -> None:
-    """This function causes a FutureWarning because no value is specified for the
-    `nullable` argument to the Column constructor."""
-    dy.Integer()
-
-
-def test_warning_deprecated_default_nullable(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", "")
-    with pytest.warns(
-        FutureWarning, match="The 'nullable' argument was not explicitly set"
-    ):
-        deprecated_default_nullable()
-
+from dataframely._deprecation import skip_if
 
 # ------------------------- Common ---------------------------------#
 
 
-@pytest.mark.parametrize(
-    "deprecated_behavior",
-    [deprecated_default_nullable],
-)
 @pytest.mark.parametrize("env_var", ["1", "True", "true"])
-def test_future_warning_skip(
-    monkeypatch: pytest.MonkeyPatch, env_var: str, deprecated_behavior: Callable
-) -> None:
-    """FutureWarnings should be avoidable by setting an environment variable."""
-    monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", env_var)
-    # Elevates FutureWarning to an exception
-    with warnings.catch_warnings():
-        warnings.simplefilter("error", FutureWarning)
-        deprecated_behavior()
+def test_skip_if(monkeypatch: pytest.MonkeyPatch, env_var: str) -> None:
+    """The skip_if decorator should allow us to prevent execution of a wrapped
+    function."""
+    variable_name = "DATAFRAMELY_NO_FUTURE_WARNINGS"
+
+    @skip_if(variable_name)
+    def callable() -> None:
+        raise ValueError()
+
+    with pytest.raises(ValueError):
+        callable()
+    monkeypatch.setenv(variable_name, env_var)
+    callable()
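Taken together, the library changes above mean that a column is non-nullable unless `nullable=True` is passed explicitly. The sketch below illustrates the resulting behaviour for downstream schemas; it is an assumption-based example that only uses the public `dataframely`/`polars` API exercised by the tests in this diff (the schema and column names are hypothetical), not code from this change.

import polars as pl

import dataframely as dy


class UserSchema(dy.Schema):
    # Primary-key columns were already implicitly non-nullable; this is unchanged.
    user_id = dy.Int64(primary_key=True)
    # With the new default, omitting `nullable` now means `nullable=False`.
    name = dy.String(max_length=5)
    # Columns that may contain nulls must opt in explicitly.
    nickname = dy.String(nullable=True)


df = pl.DataFrame(
    {"user_id": [1, 2], "name": ["ada", None], "nickname": [None, "bob"]}
)

# `name` contains a null, so validation fails under the new default ...
assert not UserSchema.is_valid(df)
# ... while the null in the explicitly nullable `nickname` column is accepted.
assert UserSchema.is_valid(df.with_columns(pl.col("name").fill_null("x")))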
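The deprecation helper `warn_nullable_default_change` is removed, but `skip_if` stays behind and is now covered directly by `test_skip_if`. For readers without `dataframely/_deprecation.py` open, a decorator along the following lines matches the behaviour that test relies on; this is a hedged re-implementation based on the context lines visible in the hunk above, not the module's verbatim code.

import os
from collections.abc import Callable
from functools import wraps


def skip_if(env: str) -> Callable[[Callable[[], None]], Callable[[], None]]:
    """Skip calling the wrapped zero-argument function when `env` is set."""

    def decorator(func: Callable[[], None]) -> Callable[[], None]:
        @wraps(func)
        def wrapper() -> None:
            # Any non-empty value (e.g. "1", "True", "true") disables the call,
            # which is exactly what the parametrized test asserts.
            if os.environ.get(env):
                return
            func()

        return wrapper

    return decorator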