Data validation Phase 1 #1239

Merged
36 commits merged on Dec 20, 2022

Commits
ddb8ea4
Saving work
utkarsharma2 Nov 13, 2022
8996d25
Added test case
utkarsharma2 Dec 5, 2022
595bd02
Add testcases for ColumnCheckOperator
utkarsharma2 Dec 7, 2022
8640e5e
Update data types supported by ColumnCheckOperator
utkarsharma2 Dec 7, 2022
d335c3c
Add task_id to operator
utkarsharma2 Dec 7, 2022
43d67bd
Add testcase for table dataset
utkarsharma2 Dec 7, 2022
cce4567
Add doc string to functions
utkarsharma2 Dec 7, 2022
488a668
Moved the test_ColumnCheckOperator.py to data_validation.py/test_Colu…
utkarsharma2 Dec 9, 2022
518cad3
Add SQLCheckOperator to validate tables via sql
utkarsharma2 Dec 9, 2022
57a5a00
Update python-sdk/src/astro/sql/operators/data_validations/SQLCheckOp…
utkarsharma2 Dec 14, 2022
d2f4dcf
Update python-sdk/src/astro/sql/operators/data_validations/SQLCheckOp…
utkarsharma2 Dec 14, 2022
f0140c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2022
a2be757
Update the dataframe check method
utkarsharma2 Dec 14, 2022
016461b
Change test execution method to run_dag() from operator.execute()
utkarsharma2 Dec 14, 2022
1220883
Override GoogleBaseHook with BigqueryHook (#1442)
utkarsharma2 Dec 20, 2022
3db2eab
Revoke region changes
utkarsharma2 Dec 20, 2022
12a9334
Add apache-airflow-providers-common-sql as dependency
utkarsharma2 Dec 20, 2022
bc037cb
Remove unwanted code
utkarsharma2 Dec 20, 2022
04d527e
Remove unwanted codes
utkarsharma2 Dec 20, 2022
52e3821
Add location
utkarsharma2 Dec 20, 2022
a18e8d8
Add google_cloud_platform connection
utkarsharma2 Dec 20, 2022
3905153
Update conn_id
utkarsharma2 Dec 20, 2022
a0e6773
Change return type
utkarsharma2 Dec 20, 2022
686a09d
Updated hook
utkarsharma2 Dec 20, 2022
2b91128
Update python-sdk/src/astro/sql/operators/data_validations/ColumnChec…
utkarsharma2 Dec 20, 2022
ec3b4fd
Update python-sdk/src/astro/sql/operators/data_validations/ColumnChec…
utkarsharma2 Dec 20, 2022
cd828d7
Update python-sdk/src/astro/sql/operators/data_validations/ColumnChec…
utkarsharma2 Dec 20, 2022
d1d5c28
Update python-sdk/src/astro/sql/operators/data_validations/ColumnChec…
utkarsharma2 Dec 20, 2022
a0543ad
Revert changes to check since nunique() don't count None
utkarsharma2 Dec 20, 2022
764e08a
Fix Deep Source
utkarsharma2 Dec 20, 2022
c25a736
Merge branch 'main' into DataValidation
utkarsharma2 Dec 20, 2022
9673af7
Refactored code to remove duplication
utkarsharma2 Dec 20, 2022
5864a07
Refactored code
utkarsharma2 Dec 20, 2022
493d651
Code refactor
utkarsharma2 Dec 20, 2022
8c71f39
Refactored ColumnCheckOperator operator
utkarsharma2 Dec 20, 2022
ec40143
Merge branch 'main' into DataValidation
utkarsharma2 Dec 20, 2022
6 changes: 6 additions & 0 deletions .github/ci-test-connections.yaml
@@ -99,3 +99,9 @@ connections:
     description: null
     extra:
       connection_string: $AZURE_WASB_CONN_STRING
+  - conn_id: gcp_conn_project
+    conn_type: google_cloud_platform
+    description: null
+    extra:
+      project: "astronomer-dag-authoring"
+      project_id: "astronomer-dag-authoring"
4 changes: 4 additions & 0 deletions python-sdk/conftest.py
@@ -66,7 +66,11 @@ def database_table_fixture(request):
     params = deepcopy(request.param)

     database_name = params["database"]
+    user_table = params.get("table", None)
     conn_id = DATABASE_NAME_TO_CONN_ID[database_name]
+    if user_table and user_table.conn_id:
+        conn_id = user_table.conn_id
+
     database = create_database(conn_id)
     table = params.get("table", Table(conn_id=database.conn_id, metadata=database.default_metadata))
     if not isinstance(table, TempTable):
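The fixture change above lets a parametrized test pass its own Table, whose conn_id then overrides the default connection for that database. A minimal sketch of how a test could opt into the new gcp_conn_project connection; the test name and assertion are illustrative, not part of this PR, and it assumes the fixture yields a (database, table) pair as elsewhere in the test suite:

import pytest

from astro.constants import Database
from astro.table import Table


@pytest.mark.parametrize(
    "database_table_fixture",
    [
        {
            "database": Database.BIGQUERY,
            # The user-supplied conn_id takes precedence over DATABASE_NAME_TO_CONN_ID.
            "table": Table(conn_id="gcp_conn_project"),
        }
    ],
    indirect=True,
    ids=["bigquery"],
)
def test_table_uses_project_scoped_connection(database_table_fixture):
    database, table = database_table_fixture
    assert table.conn_id == "gcp_conn_project"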
3 changes: 2 additions & 1 deletion python-sdk/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     "python-frontmatter",
     "smart-open",
     "SQLAlchemy>=1.3.18",
+    "apache-airflow-providers-common-sql"
 ]

 keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"]
@@ -95,7 +96,7 @@ all = [
     "protobuf<=3.20", # Google bigquery client require protobuf <= 3.20.0. We can remove the limitation when this limitation is removed
     "openlineage-airflow>=0.17.0",
     "apache-airflow-providers-microsoft-azure",
-    "azure-storage-blob",
+    "azure-storage-blob"
 ]
 doc = [
     "myst-parser>=0.17",
5 changes: 1 addition & 4 deletions python-sdk/src/astro/databases/__init__.py
@@ -24,10 +24,7 @@
 SUPPORTED_DATABASES = set(DEFAULT_CONN_TYPE_TO_MODULE_PATH.keys())


-def create_database(
-    conn_id: str,
-    table: BaseTable | None = None,
-) -> BaseDatabase:
+def create_database(conn_id: str, table: BaseTable | None = None) -> BaseDatabase:
     """
     Given a conn_id, return the associated Database class.

2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/base.py
@@ -68,7 +68,7 @@ class BaseDatabase(ABC):
     NATIVE_AUTODETECT_SCHEMA_CONFIG: Mapping[FileLocation, Mapping[str, list[FileType] | Callable]] = {}
     FILE_PATTERN_BASED_AUTODETECT_SCHEMA_SUPPORTED: set[FileLocation] = set()

-    def __init__(self, conn_id: str):
+    def __init__(self, conn_id: str, table: BaseTable | None = None): # skipcq: PYL-W0613
         self.conn_id = conn_id
         self.sql: str | ClauseElement = ""

2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/google/bigquery.py
@@ -115,7 +115,7 @@ def sql_type(self) -> str:
     @property
     def hook(self) -> BigQueryHook:
         """Retrieve Airflow hook to interface with the BigQuery database."""
-        return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False)
+        return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION)

     @property
     def sqlalchemy_engine(self) -> Engine:
8 changes: 8 additions & 0 deletions python-sdk/src/astro/sql/__init__.py
@@ -4,6 +4,14 @@

 from astro.sql.operators.append import AppendOperator, append
 from astro.sql.operators.cleanup import CleanupOperator, cleanup
+from astro.sql.operators.data_validations.ColumnCheckOperator import ( # skipcq: PY-W2000
+    ColumnCheckOperator,
+    column_check,
+)
+from astro.sql.operators.data_validations.SQLCheckOperator import ( # skipcq: PY-W2000
+    SQLCheckOperator,
+    sql_check,
+)
 from astro.sql.operators.dataframe import DataframeOperator, dataframe
 from astro.sql.operators.drop import DropTableOperator, drop_table
 from astro.sql.operators.export_file import ExportFileOperator, export_file
211 changes: 211 additions & 0 deletions python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py
@@ -0,0 +1,211 @@
from typing import Any, Dict, Optional, Union

import pandas
from airflow import AirflowException
from airflow.decorators.base import get_unique_task_id
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator

from astro.databases import create_database
from astro.table import BaseTable
from astro.utils.typing_compat import Context


class ColumnCheckOperator(SQLColumnCheckOperator):
    """
    Performs one or more of the templated checks in the column_checks dictionary.
    Checks are performed on a per-column basis specified by the column_mapping.
    Each check can take one or more of the following options:
    - equal_to: an exact value to equal, cannot be used with other comparison options
    - greater_than: value that result should be strictly greater than
    - less_than: value that results should be strictly less than
    - geq_to: value that results should be greater than or equal to
    - leq_to: value that results should be less than or equal to
    - tolerance: the percentage that the result may be off from the expected value

    :param dataset: the table or dataframe to run checks on
    :param column_mapping: the dictionary of columns and their associated checks, e.g.

    .. code-block:: python

        {
            "col_name": {
                "null_check": {
                    "equal_to": 0,
                },
                "min": {
                    "greater_than": 5,
                    "leq_to": 10,
                    "tolerance": 0.2,
                },
                "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01},
            }
        }
    """

    def __init__(
        self,
        dataset: Union[BaseTable, pandas.DataFrame],
        column_mapping: Dict[str, Dict[str, Any]],
        partition_clause: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs,
    ):
        for checks in column_mapping.values():
            for check, check_values in checks.items():
                self._column_mapping_validation(check, check_values)

        self.dataset = dataset
        self.column_mapping = column_mapping
        self.partition_clause = partition_clause
        self.kwargs = kwargs
        self.df = None

        dataset_qualified_name = ""
        dataset_conn_id = ""

        if isinstance(dataset, BaseTable):
            db = create_database(conn_id=self.dataset.conn_id) # type: ignore
            self.conn_id = self.dataset.conn_id
            dataset_qualified_name = db.get_table_qualified_name(table=self.dataset)
            dataset_conn_id = dataset.conn_id

        super().__init__(
            table=dataset_qualified_name,
            column_mapping=self.column_mapping,
            partition_clause=self.partition_clause,
            conn_id=dataset_conn_id,
            task_id=task_id if task_id is not None else get_unique_task_id("column_check"),
        )

    def get_db_hook(self) -> DbApiHook:
        """
        Get the database hook for the connection.

        :return: the database hook object.
        """
        db = create_database(conn_id=self.conn_id)
        if db.sql_type == "bigquery":
            return db.hook
        return super().get_db_hook()

    def execute(self, context: "Context"):
        if isinstance(self.dataset, BaseTable):
            return super().execute(context=context)
        elif isinstance(self.dataset, pandas.DataFrame):
            self.df = self.dataset
        else:
            raise ValueError("dataset can only be of type pandas.dataframe | Table object")

        self.process_checks()

    def get_check_result(self, check_name: str, column_name: str):
        """
        Get the check method results post validating the dataframe
        """
        if self.df is not None and column_name in self.df.columns:
            column_checks = {
                "null_check": lambda column: column.isna().sum(),
                "distinct_check": lambda column: len(column.unique()),
                "unique_check": lambda column: len(column) - len(column.unique()),
                "min": lambda column: column.min(),
                "max": lambda column: column.max(),
            }
            return column_checks[check_name](column=self.df[column_name])
        if self.df is None:
            raise ValueError("Dataframe is None")
        if column_name not in self.df.columns:
            raise ValueError(f"Dataframe doesn't have column {column_name}")

    def process_checks(self):
        """
        Process all the checks and print the result or raise an exception in the event of failed checks
        """
        failed_tests = []
        passed_tests = []

        # Iterating over columns
        for column in self.column_mapping:
            checks = self.column_mapping[column]

            # Iterating over checks
            for check_key, check_val in checks.items():
                tolerance = check_val.get("tolerance")
                result = self.get_check_result(check_key, column_name=column)
                check_val["result"] = result
                check_val["success"] = self._get_match(check_val, result, tolerance)
            failed_tests.extend(_get_failed_checks(checks, column))
            passed_tests.extend(_get_success_checks(checks, column))

        if len(failed_tests) > 0:
            raise AirflowException(f"The following tests have failed:" f"\n{''.join(failed_tests)}")
        if len(passed_tests) > 0:
            print(f"The following tests have passed:" f"\n{''.join(passed_tests)}")


def _get_failed_checks(checks, col=None):
    return [
        f"{get_checks_string(checks, col)} {check_values}\n"
        for check, check_values in checks.items()
        if not check_values["success"]
    ]


def _get_success_checks(checks, col=None):
    return [
        f"{get_checks_string(checks, col)} {check_values}\n"
        for check, check_values in checks.items()
        if check_values["success"]
    ]


def get_checks_string(check, col):
    if col:
        return f"Column: {col}\nCheck: {check},\nCheck Values:"
    return f"\tCheck: {check},\n\tCheck Values:"


def column_check(
    dataset: Union[BaseTable, pandas.DataFrame],
    column_mapping: Dict[str, Dict[str, Any]],
    partition_clause: Optional[str] = None,
    task_id: Optional[str] = None,
    **kwargs,
) -> ColumnCheckOperator:
    """
    Performs one or more of the templated checks in the column_checks dictionary.
    Checks are performed on a per-column basis specified by the column_mapping.
    Each check can take one or more of the following options:
    - equal_to: an exact value to equal, cannot be used with other comparison options
    - greater_than: value that result should be strictly greater than
    - less_than: value that results should be strictly less than
    - geq_to: value that results should be greater than or equal to
    - leq_to: value that results should be less than or equal to
    - tolerance: the percentage that the result may be off from the expected value

    :param dataset: dataframe or BaseTable that has to be validated
    :param column_mapping: the dictionary of columns and their associated checks, e.g.

    .. code-block:: python

        {
            "col_name": {
                "null_check": {
                    "equal_to": 0,
                },
                "min": {
                    "greater_than": 5,
                    "leq_to": 10,
                    "tolerance": 0.2,
                },
                "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01},
            }
        }
    """
    return ColumnCheckOperator(
        dataset=dataset,
        column_mapping=column_mapping,
        partition_clause=partition_clause,
        kwargs=kwargs,
        task_id=task_id,
    )
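Taken together, the new operator covers both SQL-backed tables and in-memory dataframes. A minimal usage sketch, not part of this PR; the table name, connection id, and column names are illustrative only:

import pandas as pd
from airflow.decorators import dag
from pendulum import datetime

from astro import sql as aql
from astro.table import Table


@dag(start_date=datetime(2022, 12, 1), schedule_interval=None, catchup=False)
def data_validation_example():
    # Table path: checks run as SQL through the inherited SQLColumnCheckOperator.
    aql.column_check(
        dataset=Table(name="imdb_movies", conn_id="sqlite_conn"),
        column_mapping={
            "title": {"null_check": {"equal_to": 0}},
            "rating": {"min": {"greater_than": 0}, "max": {"less_than": 11}},
        },
    )

    # Dataframe path: the same checks are evaluated in pandas via process_checks().
    aql.column_check(
        dataset=pd.DataFrame({"rating": [6.5, 7.2, 8.9]}),
        column_mapping={"rating": {"null_check": {"equal_to": 0}}},
    )


data_validation_example()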