diff --git a/dataframely/columns/enum.py b/dataframely/columns/enum.py
index d366610..c4b4f9e 100644
--- a/dataframely/columns/enum.py
+++ b/dataframely/columns/enum.py
@@ -3,7 +3,9 @@
 
 from __future__ import annotations
 
-from collections.abc import Sequence
+import enum
+from collections.abc import Iterable
+from inspect import isclass
 from typing import Any
 
 import polars as pl
@@ -22,7 +24,7 @@ class Enum(Column):
 
     def __init__(
         self,
-        categories: Sequence[str],
+        categories: pl.Series | Iterable[str] | type[enum.Enum],
         *,
         nullable: bool | None = None,
         primary_key: bool = False,
@@ -32,7 +34,8 @@ def __init__(
     ):
         """
         Args:
-            categories: The list of valid categories for the enum.
+            categories: The set of valid categories for the enum, or an existing Python
+                string-valued enum.
             nullable: Whether this column may contain null values.
                 Explicitly set `nullable=True` if you want your column to be nullable.
                 In a future release, `nullable=False` will be the default if `nullable`
@@ -63,7 +66,13 @@ def __init__(
             alias=alias,
             metadata=metadata,
         )
-        self.categories = list(categories)
+        if isclass(categories) and issubclass(categories, enum.Enum):
+            categories = pl.Series(
+                values=[getattr(v, "value", v) for v in categories.__members__.values()]
+            )
+        elif not isinstance(categories, pl.Series):
+            categories = pl.Series(values=categories)
+        self.categories = categories
 
     @property
     def dtype(self) -> pl.DataType:
@@ -72,7 +81,7 @@ def dtype(self) -> pl.DataType:
     def validate_dtype(self, dtype: PolarsDataType) -> bool:
         if not isinstance(dtype, pl.Enum):
             return False
-        return self.categories == dtype.categories.to_list()
+        return self.categories.equals(dtype.categories)
 
     def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
         category_lengths = [len(c) for c in self.categories]
@@ -92,5 +101,7 @@ def pyarrow_dtype(self) -> pa.DataType:
 
     def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
         return generator.sample_choice(
-            n, choices=self.categories, null_probability=self._null_probability
+            n,
+            choices=self.categories.to_list(),
+            null_probability=self._null_probability,
         ).cast(self.dtype)
diff --git a/tests/column_types/test_enum.py b/tests/column_types/test_enum.py
index f2389b3..cfddec0 100644
--- a/tests/column_types/test_enum.py
+++ b/tests/column_types/test_enum.py
@@ -1,6 +1,8 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
-
+import enum
+from collections.abc import Iterable
+from enum import Enum
 from typing import Any
 
 import polars as pl
@@ -61,3 +63,48 @@ def test_different_sequences(type1: type, type2: type) -> None:
     S = create_schema("test", {"x": dy.Enum(type1(allowed))})
     df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(type2(allowed)))})
     S.validate(df)
+
+
+def test_enum_of_enum_136() -> None:
+    class Categories(str, Enum):
+        a = "a"
+        b = "b"
+
+    assert pl.Enum(Categories) == dy.Enum(Categories).dtype
+
+
+def test_enum_of_series() -> None:
+    categories = pl.Series(["a", "b"])
+    assert pl.Enum(categories) == dy.Enum(categories).dtype
+
+
+def test_enum_of_iterable() -> None:
+    categories = (x for x in ["a", "b"])
+    assert pl.Enum(["a", "b"]) == dy.Enum(categories).dtype
+
+
+@pytest.mark.parametrize(
+    "categories1",
+    [
+        ["a", "b"],
+        ("a", "b"),
+        pl.Series(["a", "b"]),
+        Enum("Categories", {"a": "a", "b": "b"}),
+    ],
+)
+@pytest.mark.parametrize(
+    "categories2",
+    [
+        ["a", "b"],
+        ("a", "b"),
+        pl.Series(["a", "b"]),
+        Enum("Categories", {"a": "a", "b": "b"}),
+    ],
+)
+def test_sequences_and_enums(
+    categories1: pl.Series | Iterable[str] | type[enum.Enum],
+    categories2: pl.Series | Iterable[str] | type[enum.Enum],
+) -> None:
+    S = create_schema("test", {"x": dy.Enum(categories1)})
+    df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(categories2))})
+    S.validate(df)