In [1]:
from typing import List, Optional, Union

import polars as pl
from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator, model_validator

class Feature(BaseModel):
    """Base description of one model input backed by a DataFrame column.

    Unknown keys are accepted (``extra="allow"``). When the raw input dict
    carries a ``"context"`` key holding a polars DataFrame, ``column_name``
    is checked against that frame's columns before field validation runs.
    """

    # Name of the DataFrame column this feature reads from.
    column_name: str
    # Name of the feature itself.
    name: str
    # Optional free-text description.
    description: Optional[str] = None

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    @model_validator(mode='before')
    @classmethod
    def validate_column_name(cls, values):
        """Ensure ``column_name`` exists in the optional ``context`` DataFrame.

        Args:
            values: Raw input handed to the model. Only dict input is
                inspected; anything else is passed through untouched so that
                pydantic reports the type problem itself.

        Returns:
            The unmodified ``values``.

        Raises:
            ValueError: If a DataFrame context is supplied and it has no
                column named ``column_name``.
        """
        # A "before" validator receives the raw input, which need not be a dict.
        if not isinstance(values, dict):
            return values

        column_name = values.get("column_name")
        context = values.get("context")
        if isinstance(context, pl.DataFrame) and column_name not in context.columns:
            raise ValueError(f"Column '{column_name}' not found in the DataFrame.")
        return values

class NumericalFeature(Feature):
    """Feature whose backing column must be boolean, integer, float or decimal."""

    @model_validator(mode='before')
    @classmethod
    def validate_numerical_column(cls, values):
        """Check that the context column exists and has a numeric dtype.

        Args:
            values: Raw input handed to the model. Non-dict input and input
                without a DataFrame under ``"context"`` are returned as-is.

        Returns:
            The unmodified ``values``.

        Raises:
            ValueError: If the column is missing from the context DataFrame
                or its dtype is not one of the accepted numeric dtypes.
        """
        # A "before" validator receives the raw input, which need not be a dict.
        if not isinstance(values, dict):
            return values

        context = values.get("context")
        if not isinstance(context, pl.DataFrame):
            return values

        column_name = values.get("column_name")
        # Checked here too so the dtype lookup below never hits a missing column.
        if column_name not in context.columns:
            raise ValueError(f"Column '{column_name}' not found in the DataFrame.")

        numeric_dtypes = (
            pl.Boolean,
            pl.Int8,
            pl.Int16,
            pl.Int32,
            pl.Int64,
            pl.UInt8,
            pl.UInt16,
            pl.UInt32,
            pl.UInt64,
            pl.Float32,
            pl.Float64,
            pl.Decimal,
        )
        if context[column_name].dtype not in numeric_dtypes:
            raise ValueError(
                f"Column '{column_name}' must be of a numeric type (Boolean, Integer, Unsigned Integer, Float, or Decimal)."
            )
        return values

class EmbeddingFeature(Feature):
    """Feature whose backing column must be a list of 32- or 64-bit floats."""

    @model_validator(mode='before')
    @classmethod
    def validate_embedding_column(cls, values):
        """Check that the context column exists and is a float list column.

        Args:
            values: Raw input handed to the model. Non-dict input and input
                without a DataFrame under ``"context"`` are returned as-is.

        Returns:
            The unmodified ``values``.

        Raises:
            ValueError: If the column is missing from the context DataFrame
                or its dtype is neither ``pl.List(pl.Float32)`` nor
                ``pl.List(pl.Float64)``.
        """
        # A "before" validator receives the raw input, which need not be a dict.
        if not isinstance(values, dict):
            return values

        context = values.get("context")
        if not isinstance(context, pl.DataFrame):
            return values

        column_name = values.get("column_name")
        # Checked here too so the dtype lookup below never hits a missing column.
        if column_name not in context.columns:
            raise ValueError(f"Column '{column_name}' not found in the DataFrame.")

        embedding_dtypes = (pl.List(pl.Float32), pl.List(pl.Float64))
        if context[column_name].dtype not in embedding_dtypes:
            raise ValueError(f"Column '{column_name}' must be of type pl.List(pl.Float32) or pl.List(pl.Float64).")
        return values

class CategoricalFeature(Feature):
    """Feature whose backing column must be string, categorical, enum or integer."""

    @model_validator(mode='before')
    @classmethod
    def validate_categorical_column(cls, values):
        """Check that the context column exists and has a categorical-like dtype.

        Args:
            values: Raw input handed to the model. Non-dict input and input
                without a DataFrame under ``"context"`` are returned as-is.

        Returns:
            The unmodified ``values``.

        Raises:
            ValueError: If the column is missing from the context DataFrame
                or its dtype is not one of the accepted dtypes.
        """
        # A "before" validator receives the raw input, which need not be a dict.
        if not isinstance(values, dict):
            return values

        context = values.get("context")
        if not isinstance(context, pl.DataFrame):
            return values

        column_name = values.get("column_name")
        # Checked here too so the dtype lookup below never hits a missing column.
        if column_name not in context.columns:
            raise ValueError(f"Column '{column_name}' not found in the DataFrame.")

        categorical_dtypes = (
            pl.Utf8,
            pl.Categorical,
            pl.Enum,
            pl.Int8,
            pl.Int16,
            pl.Int32,
            pl.Int64,
            pl.UInt8,
            pl.UInt16,
            pl.UInt32,
            pl.UInt64,
        )
        if context[column_name].dtype not in categorical_dtypes:
            raise ValueError(
                f"Column '{column_name}' must be of type pl.Utf8, pl.Categorical, pl.Enum, or an integer type."
            )
        return values

class FeatureSet(BaseModel):
    """A group of features, split by kind. Every list defaults to empty."""

    # Features validated as numeric columns.
    numerical: List[NumericalFeature] = []
    # Features validated as float-list columns.
    embeddings: List[EmbeddingFeature] = []
    # Features validated as string / categorical / enum / integer columns.
    categorical: List[CategoricalFeature] = []

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

class InputConfig(BaseModel):
    """Top-level configuration holding the feature sets to validate."""

    feature_sets: List[FeatureSet]

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    def validate_with_dataframe(self, df: pl.DataFrame) -> None:
        """Re-validate every configured feature against ``df``.

        Each feature is dumped to a dict, ``df`` is attached under the
        ``"context"`` key, and the result is run through the feature's own
        class validators.

        Args:
            df: DataFrame whose columns and dtypes the features must match.

        Raises:
            pydantic.ValidationError: (a ``ValueError`` subclass) on the first
                feature whose column is missing or has an unsupported dtype.
        """
        for feature_set in self.feature_sets:
            for feature_type in ("numerical", "embeddings", "categorical"):
                for feature in getattr(feature_set, feature_type):
                    # "context" goes last so an extra field of the same name
                    # stored on the feature cannot replace the frame under test.
                    payload = {**feature.model_dump(), "context": df}
                    type(feature).model_validate(payload)

# Example usage
data = {
    "age": [25, 30, 35],
    "income": [50000.0, 60000.0, 70000.0],
    "gender": ["Male", "Female", "Male"],
    "is_employed": [True, False, True],
    "embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
}

df = pl.DataFrame(data)


def _report_validation(config, frame, label):
    """Validate ``config`` against ``frame`` and print the outcome for ``label``."""
    try:
        config.validate_with_dataframe(frame)
    except ValueError as e:
        print(f"Validation failed for {label}: {str(e)}")
    else:
        print(f"Validation successful for {label}!")


# Example 1: Valid feature set
numerical_feature_1 = NumericalFeature(column_name="age", name="Age")
numerical_feature_2 = NumericalFeature(column_name="is_employed", name="Is Employed")
embedding_feature = EmbeddingFeature(column_name="embedding", name="Embedding")
categorical_feature = CategoricalFeature(column_name="gender", name="Gender")

feature_set_valid = FeatureSet(
    numerical=[numerical_feature_1, numerical_feature_2],
    embeddings=[embedding_feature],
    categorical=[categorical_feature],
)
input_config_valid = InputConfig(feature_sets=[feature_set_valid])
_report_validation(input_config_valid, df, "valid feature set")

# Example 2: Invalid numerical feature (string column declared as numerical)
numerical_feature_invalid = NumericalFeature(column_name="gender", name="Gender")

feature_set_invalid_numerical = FeatureSet(
    numerical=[numerical_feature_invalid],
    embeddings=[embedding_feature],
    categorical=[categorical_feature],
)
input_config_invalid_numerical = InputConfig(feature_sets=[feature_set_invalid_numerical])
_report_validation(input_config_invalid_numerical, df, "invalid numerical feature")

# Example 3: Invalid embedding feature (integer column declared as embedding)
embedding_feature_invalid = EmbeddingFeature(column_name="age", name="Age")

feature_set_invalid_embedding = FeatureSet(
    numerical=[numerical_feature_1, numerical_feature_2],
    embeddings=[embedding_feature_invalid],
    categorical=[categorical_feature],
)
input_config_invalid_embedding = InputConfig(feature_sets=[feature_set_invalid_embedding])
_report_validation(input_config_invalid_embedding, df, "invalid embedding feature")

# Example 4: Invalid categorical feature (float column declared as categorical)
categorical_feature_invalid = CategoricalFeature(column_name="income", name="Income")

feature_set_invalid_categorical = FeatureSet(
    numerical=[numerical_feature_1, numerical_feature_2],
    embeddings=[embedding_feature],
    categorical=[categorical_feature_invalid],
)
input_config_invalid_categorical = InputConfig(feature_sets=[feature_set_invalid_categorical])
_report_validation(input_config_invalid_categorical, df, "invalid categorical feature")

Validation successful for valid feature set!
Validation failed for invalid numerical feature: 1 validation error for NumericalFeature
  Value error, Column 'gender' must be of a numeric type (Boolean, Integer, Unsigned Integer, Float, or Decimal). [type=value_error, input_value={'context': shape: (3, 5)...r', 'description': None}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.7/v/value_error
Validation failed for invalid embedding feature: 1 validation error for EmbeddingFeature
  Value error, Column 'age' must be of type pl.List(pl.Float32) or pl.List(pl.Float64). [type=value_error, input_value={'context': shape: (3, 5)...e', 'description': None}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.7/v/value_error
Validation failed for invalid categorical feature: 1 validation error for CategoricalFeature
  Value error, Column 'income' must be of type pl.Utf8, pl.Categorical, pl.Enum, or an integer type. [type=value_error

In [6]:
def convert_utf8_to_enum(df: pl.DataFrame, threshold: float = 0.5) -> pl.DataFrame:
    """Cast low-cardinality string columns of ``df`` to ``pl.Enum``.

    A ``pl.Utf8`` column is converted when the ratio of distinct values to
    rows is at most ``threshold``. Non-string columns are left alone, and a
    message is printed for every string column that is skipped.

    Args:
        df: Input frame. It is not modified; a new frame is returned.
        threshold: Maximum distinct-to-row ratio for conversion, strictly
            between 0 and 1.

    Returns:
        A DataFrame in which the qualifying string columns are Enum-typed.

    Raises:
        ValueError: If ``threshold`` is not strictly between 0 and 1.
    """
    if not 0 < threshold < 1:
        raise ValueError("Threshold must be between 0 and 1 (exclusive).")

    for column in df.columns:
        series = df[column]
        if series.dtype != pl.Utf8:
            continue

        if len(series) == 0:
            print(f"Column '{column}' is empty. Skipping conversion to Enum.")
            continue

        # The ratio still counts null as one distinct value, as before.
        unique_ratio = len(series.unique()) / len(series)
        if unique_ratio > threshold:
            print(f"Column '{column}' has a high ratio of unique values ({unique_ratio:.2f}). Skipping conversion to Enum.")
            continue

        # Enum categories cannot contain null, so drop it; maintain_order
        # makes the category order reproducible across runs.
        categories = series.drop_nulls().unique(maintain_order=True).to_list()
        if not categories:
            print(f"Column '{column}' contains only null values. Skipping conversion to Enum.")
            continue

        df = df.with_columns(series.cast(pl.Enum(categories)))

    return df

In [7]:
# Demo input: two repetitive string columns, an integer column and an all-null column.
data = dict(
    name=["Alice", "Bob", "Charlie", "Alice", "Bob"],
    age=[25, 30, 35, 25, 30],
    city=["New York", "London", "Paris", "New York", "London"],
    empty_col=[None] * 5,  # nulls only, same length as the other columns
)

df = pl.DataFrame(data)
print("Original DataFrame:", df, sep="\n")

# With threshold=0.6 both string columns (3 distinct values in 5 rows) qualify.
df_enum = convert_utf8_to_enum(df, threshold=0.6)
print("\nDataFrame with Enum columns:", df_enum, sep="\n")
print("\nColumn dtypes:", df_enum.dtypes, sep="\n")

Original DataFrame:
shape: (5, 4)
┌─────────┬─────┬──────────┬───────────┐
│ name    ┆ age ┆ city     ┆ empty_col │
│ ---     ┆ --- ┆ ---      ┆ ---       │
│ str     ┆ i64 ┆ str      ┆ null      │
╞═════════╪═════╪══════════╪═══════════╡
│ Alice   ┆ 25  ┆ New York ┆ null      │
│ Bob     ┆ 30  ┆ London   ┆ null      │
│ Charlie ┆ 35  ┆ Paris    ┆ null      │
│ Alice   ┆ 25  ┆ New York ┆ null      │
│ Bob     ┆ 30  ┆ London   ┆ null      │
└─────────┴─────┴──────────┴───────────┘

DataFrame with Enum columns:
shape: (5, 4)
┌─────────┬─────┬──────────┬───────────┐
│ name    ┆ age ┆ city     ┆ empty_col │
│ ---     ┆ --- ┆ ---      ┆ ---       │
│ enum    ┆ i64 ┆ enum     ┆ null      │
╞═════════╪═════╪══════════╪═══════════╡
│ Alice   ┆ 25  ┆ New York ┆ null      │
│ Bob     ┆ 30  ┆ London   ┆ null      │
│ Charlie ┆ 35  ┆ Paris    ┆ null      │
│ Alice   ┆ 25  ┆ New York ┆ null      │
│ Bob     ┆ 30  ┆ London   ┆ null      │
└─────────┴─────┴──────────┴───────────┘

Column dtypes:
[Enu