Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Using decorators to perform schema validations on DataFrames #141

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions quinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,7 @@
"sort_columns",
"append_if_schema_identical",
"flatten_dataframe",
"validate_returned_schema",
"ensure_columns_present",
"ensure_columns_absent",
]
42 changes: 42 additions & 0 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,45 @@ def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str])
error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"
if extra_col_names:
raise DataFrameProhibitedColumnError(error_message)


def validate_returned_schema(required_schema: StructType, ignore_nullable: bool = False):
    """Build a decorator that schema-checks the DataFrame returned by a function.

    The decorated function is called with its original arguments. Its result is
    passed to ``validate_schema`` together with ``required_schema`` and
    ``ignore_nullable``; if that call does not raise, the result is returned
    unchanged.

    :param required_schema: Schema forwarded to ``validate_schema`` as the
        required schema.
    :param ignore_nullable: Forwarded to ``validate_schema``; defaults to
        ``False``.
    :return: A decorator to apply to a function that returns a DataFrame.
    :raises: Whatever ``validate_schema`` raises propagates to the caller.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    def inner_decorator(func):
        # ``wraps`` keeps the decorated function's __name__, __doc__ and
        # __wrapped__, so introspection and tracebacks still show the real
        # function instead of the generic "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Call the function that returns a DataFrame
            result_df = func(*args, **kwargs)

            # Validate the schema of the DataFrame
            validate_schema(result_df, required_schema, ignore_nullable)

            return result_df

        return wrapper

    return inner_decorator


def ensure_columns_present(required_col_names: list[str]):
    """Build a decorator that checks required columns on a returned DataFrame.

    The decorated function is called with its original arguments. Its result is
    passed to ``validate_presence_of_columns`` together with
    ``required_col_names``; if that call does not raise, the result is returned
    unchanged.

    :param required_col_names: Column names forwarded to
        ``validate_presence_of_columns``.
    :return: A decorator to apply to a function that returns a DataFrame.
    :raises: Whatever ``validate_presence_of_columns`` raises propagates to
        the caller.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    def inner_decorator(func):
        # ``wraps`` keeps the decorated function's __name__, __doc__ and
        # __wrapped__ instead of exposing the generic "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Call the function that returns a DataFrame
            result_df = func(*args, **kwargs)

            # Validate the presence of columns in the DataFrame
            validate_presence_of_columns(result_df, required_col_names)

            return result_df

        return wrapper

    return inner_decorator


def ensure_columns_absent(prohibited_col_names: list[str]):
    """Build a decorator that rejects prohibited columns on a returned DataFrame.

    The decorated function is called with its original arguments. Its result is
    passed to ``validate_absence_of_columns`` together with
    ``prohibited_col_names``; if that call does not raise, the result is
    returned unchanged.

    :param prohibited_col_names: Column names forwarded to
        ``validate_absence_of_columns``.
    :return: A decorator to apply to a function that returns a DataFrame.
    :raises: Whatever ``validate_absence_of_columns`` raises propagates to
        the caller.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    def inner_decorator(func):
        # ``wraps`` keeps the decorated function's __name__, __doc__ and
        # __wrapped__ instead of exposing the generic "wrapper".
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Call the function that returns a DataFrame
            result_df = func(*args, **kwargs)

            # Validate the absence of prohibited columns in the DataFrame
            validate_absence_of_columns(result_df, prohibited_col_names)

            return result_df

        return wrapper

    return inner_decorator
72 changes: 72 additions & 0 deletions tests/test_dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,75 @@ def it_does_nothing_when_no_unallowed_columns_are_present(spark):
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_absence_of_columns(source_df, ["favorite_color"])


@pytest.fixture(scope="session")
def sample_schema(spark):
    """Provide the reference schema: string ``col1`` and long ``col2``, both with nullable set to False."""
    col1_field = StructField("col1", StringType(), False)
    col2_field = StructField("col2", LongType(), False)
    return StructType([col1_field, col2_field])


@pytest.fixture(scope="session")
def sample_df(spark, sample_schema):
    """Provide a two-row DataFrame built explicitly with ``sample_schema``.

    Passing only column names lets Spark infer the schema, which marks every
    column as nullable. ``sample_schema`` declares both fields with nullable
    set to False, and the positive schema test validates with the default
    ``ignore_nullable=False``. Building the frame from the schema keeps the
    fixture and the schema in agreement.
    NOTE(review): assumes ``validate_schema`` compares nullability unless
    ``ignore_nullable`` is set - confirm against its implementation.
    """
    rows = [
        ("A", 1),
        ("B", 2),
    ]
    return spark.createDataFrame(rows, schema=sample_schema)


@pytest.fixture(scope="session")
def missing_col_df(spark):
    """Provide a single-row DataFrame that has ``col1`` but no ``col2``."""
    rows = [(1,)]
    column_names = ["col1"]
    return spark.createDataFrame(rows, column_names)


@pytest.fixture(scope="session")
def extra_col_df(spark):
    """Provide a single-row DataFrame with ``col1``, ``col2`` and an additional ``extra_col``."""
    rows = [("A", 1, "C")]
    column_names = ["col1", "col2", "extra_col"]
    return spark.createDataFrame(rows, column_names)


def test_validate_returned_schema_positive(sample_schema, sample_df):
    """The decorator does not raise for a conforming DataFrame and returns it unchanged."""

    @quinn.validate_returned_schema(sample_schema)
    def get_df():
        return sample_df

    # Also check the return value: a wrapper that forgot to return the
    # DataFrame would otherwise pass this test silently.
    assert get_df() is sample_df


def test_validate_returned_schema_negative(sample_schema, missing_col_df):
    """Calling a decorated function whose DataFrame lacks a required field raises DataFrameMissingStructFieldError."""
    schema_check = quinn.validate_returned_schema(sample_schema)

    @schema_check
    def produce_incomplete_df():
        return missing_col_df

    expected_error = quinn.DataFrameMissingStructFieldError
    with pytest.raises(expected_error):
        produce_incomplete_df()


def test_ensure_columns_present_positive(sample_df):
    """The decorator does not raise when all required columns exist and returns the DataFrame unchanged."""

    @quinn.ensure_columns_present(["col1", "col2"])
    def get_df():
        return sample_df

    # Also check the return value: a wrapper that forgot to return the
    # DataFrame would otherwise pass this test silently.
    assert get_df() is sample_df


def test_ensure_columns_present_negative(missing_col_df):
    """Calling a decorated function whose DataFrame lacks ``col2`` raises DataFrameMissingColumnError."""
    presence_check = quinn.ensure_columns_present(["col1", "col2"])

    @presence_check
    def produce_incomplete_df():
        return missing_col_df

    expected_error = quinn.DataFrameMissingColumnError
    with pytest.raises(expected_error):
        produce_incomplete_df()

def test_ensure_columns_absent_positive(sample_df):
    """The decorator does not raise when no prohibited column exists and returns the DataFrame unchanged."""

    @quinn.ensure_columns_absent(["extra_col"])
    def get_df():
        return sample_df

    # Also check the return value: a wrapper that forgot to return the
    # DataFrame would otherwise pass this test silently.
    assert get_df() is sample_df


def test_ensure_columns_absent_negative(extra_col_df):
    """Calling a decorated function whose DataFrame contains ``extra_col`` raises DataFrameProhibitedColumnError."""
    absence_check = quinn.ensure_columns_absent(["extra_col"])

    @absence_check
    def produce_df_with_prohibited_column():
        return extra_col_df

    expected_error = quinn.DataFrameProhibitedColumnError
    with pytest.raises(expected_error):
        produce_df_with_prohibited_column()