Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decorators for functions #221

20 changes: 18 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,28 @@ quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"])

**validate_schema()**

Raises an exception unless `source_df` contains all the `StructFields` defined in the `required_schema`.
Raises an exception unless `source_df` contains all the `StructFields` defined in the `required_schema`. By default, `ignore_nullable` is set to False, so an exception will be raised even if the column names and data types match but the nullability conditions do not.

```python
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(required_schema, _df=source_df)
```

You can also set `ignore_nullable` to True so that validation covers only column names and data types, not nullability.

```python
quinn.validate_schema(required_schema, ignore_nullable=True, _df=source_df)
```

> [!TIP]
> This function can also be used as a decorator on functions that return a DataFrame, to validate the schema of the returned DataFrame. When it is used as a decorator, you don't need to pass the `_df` argument, because the validation is performed on the DataFrame returned by the decorated function.
>
> ```python
> @quinn.validate_schema(required_schema, ignore_nullable=True)
> def get_df():
>     return df
> ```


**validate_absence_of_columns()**

Raises an exception if `source_df` contains `age` or `cool` columns.
Expand Down
56 changes: 37 additions & 19 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,59 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
if missing_col_names:
raise DataFrameMissingColumnError(error_message)



def validate_schema(
df: DataFrame,
required_schema: StructType,
ignore_nullable: bool = False,
) -> None:
required_schema: StructType,
ignore_nullable: bool = False,
_df: DataFrame = None
) -> function:
"""Function that validate if a given DataFrame has a given StructType as its schema.
Implemented as a decorator factory so can be used both as a standalone function or as
a decorator to another function.

:param df: DataFrame to validate
:type df: DataFrame
:param required_schema: StructType required for the DataFrame
:type required_schema: StructType
:param ignore_nullable: (Optional) A flag for if nullable fields should be
ignored during validation
:type ignore_nullable: bool, optional
:param _df: DataFrame to validate, mandatory when called as a function. Not required
when called as a decorator
:type _df: DataFrame

:raises DataFrameMissingStructFieldError: if any StructFields from the required
schema are not included in the DataFrame schema
"""
_all_struct_fields = copy.deepcopy(df.schema)
_required_schema = copy.deepcopy(required_schema)

if ignore_nullable:
for x in _all_struct_fields:
x.nullable = None
def decorator(func):
def wrapper(*args, **kwargs):
df = func(*args, **kwargs)
_all_struct_fields = copy.deepcopy(df.schema)
_required_schema = copy.deepcopy(required_schema)

if ignore_nullable:
for x in _all_struct_fields:
x.nullable = None

for x in _required_schema:
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"

for x in _required_schema:
x.nullable = None
if missing_struct_fields:
raise DataFrameMissingStructFieldError(error_message)
else:
print("Success! DataFrame matches the required schema!")

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
return df
return wrapper

if missing_struct_fields:
raise DataFrameMissingStructFieldError(error_message)
if _df is None:
# This means the function is being used as a decorator
return decorator
else:
# This means the function is being called directly with a DataFrame
return decorator(lambda: _df)()


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def it_raises_when_struct_field_is_missing1():
]
)
with pytest.raises(quinn.DataFrameMissingStructFieldError) as excinfo:
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(required_schema, _df=source_df)

current_spark_version = semver.Version.parse(spark.version)
spark_330 = semver.Version.parse("3.3.0")
Expand All @@ -53,7 +53,7 @@ def it_does_nothing_when_the_schema_matches():
StructField("age", LongType(), True),
]
)
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(required_schema, _df=source_df)

def nullable_column_mismatches_are_ignored():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
Expand All @@ -64,7 +64,7 @@ def nullable_column_mismatches_are_ignored():
StructField("age", LongType(), False),
]
)
quinn.validate_schema(source_df, required_schema, ignore_nullable=True)
quinn.validate_schema(required_schema, ignore_nullable=True, _df=source_df)


def describe_validate_absence_of_columns():
Expand Down
Loading