diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py index 285b4657be6eb..21c8b6502b0bb 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py @@ -20,12 +20,15 @@ import ast import re from collections.abc import Callable, Iterable, Mapping, Sequence +from dataclasses import dataclass from functools import cached_property from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs +from airflow.providers.common.compat.openlineage.check import require_openlineage_version from airflow.providers.common.compat.sdk import ( AirflowException, AirflowFailException, + AirflowOptionalProviderFeatureException, AirflowSkipException, BaseHook, BaseOperator, @@ -127,6 +130,46 @@ def default_output_processor(results: list[Any], descriptions: list[Sequence[Seq return results +@dataclass +class SQLCheckResult: + """Record of a single SQL check result.""" + + name: str + """Unique name identifying this check.""" + + check_type: str + """Classification of the check, e.g. ``"not_null"``, ``"row_count"``, ``"unique"``.""" + + success: bool + """Whether the check found no issues (``True``) or found issues (``False``).""" + + severity: str = "error" + """How severe a failure of this check is: ``"error"`` (raises, default), ``"warn"`` (logs a warning only), + or ``"info"`` (informational, never causes task failure, like branching operator).""" + + column: str | None = None + """Column the check refers to. When set, the assertion targets a specific column rather than + the whole table. Should match the column name in the schema.""" + + table: str | None = None + """Table the check was performed against.""" + + expected: str | None = None + """The expected value or threshold, serialized as a string, e.g. ``"> 0"`` or ``"[10, 100]"``.""" + + actual: str | None = None + """The actual value observed during the check, serialized as a string.""" + + content: str | None = None + """The check body — typically the SQL expression used to perform the check.""" + + description: str | None = None + """Human-readable description of what the check verifies.""" + + params: dict | None = None + """Arbitrary key-value pairs with check-specific context, e.g. accept_none or partition_clause.""" + + class BaseSQLOperator(BaseOperator): """ This is a base class for generic SQL Operator to get a DB Hook. @@ -156,6 +199,7 @@ def __init__( self.database = database self.hook_params = hook_params or {} self.retry_on_failure = retry_on_failure + self.check_results: list[SQLCheckResult] = [] # Used by listeners @classmethod # TODO: can be removed once Airflow min version for this provider is 3.0.0 or higher @@ -300,14 +344,116 @@ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | database_specific_lineage = None if database_specific_lineage is None: - return operator_lineage + if not self.check_results: + return operator_lineage + try: + return self._attach_check_facets(operator_lineage) + except AirflowOptionalProviderFeatureException as err: + self.log.debug("OpenLineage could not attach check facets: %s", err) + return operator_lineage - return OperatorLineage( + merged = OperatorLineage( inputs=operator_lineage.inputs + database_specific_lineage.inputs, outputs=operator_lineage.outputs + database_specific_lineage.outputs, run_facets=merge_dicts(operator_lineage.run_facets, database_specific_lineage.run_facets), job_facets=merge_dicts(operator_lineage.job_facets, database_specific_lineage.job_facets), ) + if not self.check_results: + return merged + try: + return self._attach_check_facets(merged) + except AirflowOptionalProviderFeatureException as err: + self.log.debug("OpenLineage could not attach check facets: %s", err) + return merged + + @require_openlineage_version(client_min_version="1.47.0") + def _attach_check_facets(self, operator_lineage: OperatorLineage) -> OperatorLineage: + """ + Attach OpenLineage check-result facets to the given lineage object. + + Requires openlineage-python >= 1.47.0, which introduced the extended ``Assertion`` + and ``TestExecution`` schemas (``name``, ``description``, ``expected``, ``actual``, ``content``, + ``params`` fields). The decorator raises ``AirflowOptionalProviderFeatureException`` when + the client is absent or too old; callers are expected to catch that exception. + + Results with a ``table`` set are attached as``DataQualityAssertionsDatasetFacet`` on the matching + dataset (matched by suffix). Unmatched results, and results without a table, fall back + to a run-level ``TestRunFacet``. + """ + from openlineage.client.facet_v2 import data_quality_assertions_dataset, test_run + + by_table: dict[str | None, list[SQLCheckResult]] = {} + for r in self.check_results: + by_table.setdefault(r.table, []).append(r) + + run_level: list[SQLCheckResult] = list(by_table.pop(None, [])) + + for table, results in by_table.items(): + assertions = [ + data_quality_assertions_dataset.Assertion( + assertion=r.check_type, + success=r.success, + severity=r.severity, + column=r.column, + name=r.name, + description=r.description, + expected=r.expected, + actual=r.actual, + content=r.content, + contentType="sql", + params=r.params, + ) + for r in results + ] + table_lower = table.lower() # type: ignore[union-attr] + matched = False + for group in (operator_lineage.inputs, operator_lineage.outputs): + # Exact match takes priority over suffix match when both are present in the same group. + target = next( + (ds for ds in group if ds.name.lower() == table_lower), + None, + ) or next( + (ds for ds in group if ds.name.lower().endswith(f".{table_lower}")), + None, + ) + if target is not None: + target.facets = target.facets or {} + existing = target.facets.get("dataQualityAssertions") + facet_assertions = ( + existing.assertions + assertions if existing is not None else assertions + ) + target.facets["dataQualityAssertions"] = ( + data_quality_assertions_dataset.DataQualityAssertionsDatasetFacet( + assertions=facet_assertions + ) + ) + matched = True + if not matched: + run_level.extend(results) + + if run_level: + tests = [ + test_run.TestExecution( + name=r.name, + status="pass" if r.success else "fail", + severity=r.severity, + type=r.check_type, + description=r.description, + expected=r.expected, + actual=r.actual, + content=r.content, + contentType="sql", + params={"tested_column": r.column, "tested_table": r.table, **(r.params or {})}, + ) + for r in run_level + ] + operator_lineage.run_facets = operator_lineage.run_facets or {} + existing = operator_lineage.run_facets.get("test") + if existing is not None: + tests = existing.tests + tests + operator_lineage.run_facets["test"] = test_run.TestRunFacet(tests=tests) + + return operator_lineage class SQLExecuteQueryOperator(BaseSQLOperator): @@ -554,6 +700,9 @@ def execute(self, context: Context): self.column_mapping[column][check], result, tolerance ) + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results() + failed_tests = [ f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n" for col, checks in self.column_mapping.items() @@ -685,6 +834,105 @@ def _column_mapping_validation(self, check, check_values): "'less than or equal to', use geq_to or leq_to." ) + def _build_check_results(self) -> list[SQLCheckResult]: + try: + return [ + SQLCheckResult( + name=f"{col}.{check}", + check_type=self._get_column_check_type(check, check_values), + success=check_values.get("success", False), + column=col, + table=self.table, + expected=self._format_column_expected(check_values), + actual=str(check_values["result"]) if "result" in check_values else None, + content=self.column_checks[check].format(column=col), + description="Column-level statistical check against the configured threshold", + params={ + k: v + for k, v in { + "equal_to": check_values.get("equal_to"), + "greater_than": check_values.get("greater_than"), + "geq_to": check_values.get("geq_to"), + "less_than": check_values.get("less_than"), + "leq_to": check_values.get("leq_to"), + "tolerance": check_values.get("tolerance"), + "accept_none": self.accept_none, + "partition_clause": self.partition_clause, + "check_partition_clause": check_values.get("partition_clause"), + }.items() + if v is not None + } + or None, + ) + for col, checks in self.column_mapping.items() + for check, check_values in checks.items() + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + + @staticmethod + def _get_column_check_type(check: str, check_values: dict) -> str: + """ + Return a dbt-style check type for a column check based on the metric and comparison operators. + + Range operators (``greater_than``, ``geq_to``, ``less_than``, ``leq_to``) always yield + ``accepted_range``. ``equal_to`` with a non-zero tolerance also yields ``accepted_range`` + because the check expands to ``[value*(1-tol), value*(1+tol)]``. For ``equal_to``-only + checks without tolerance, ``null_check`` and ``unique_check`` use their semantic names + (``not_null``, ``unique``) when asserting the canonical zero value; any other target becomes + ``accepted_values``. Custom check names fall through unchanged. + """ + if any(k in check_values for k in ("greater_than", "geq_to", "less_than", "leq_to")): + return "accepted_range" + if "equal_to" in check_values: + if check_values.get("tolerance") and check_values["equal_to"] != 0: + return "accepted_range" + if check == "null_check" and check_values["equal_to"] == 0: + return "not_null" + if check == "unique_check" and check_values["equal_to"] == 0: + return "unique" + return "accepted_values" + return check + + @staticmethod + def _format_column_expected(check_values: dict) -> str: + """ + Return the expected string for a column check, expanding tolerance into actual bounds. + + Without tolerance: shows raw comparison operators (e.g. ``">5, <=10"``). + With tolerance: computes and shows the actual relaxed bounds using the same arithmetic + as ``_get_match``, so ``equal_to=5, tolerance=0.1`` becomes ``">= 4.5, <= 5.5"``. + """ + tol = check_values.get("tolerance") + if tol is None: + return ", ".join( + f"{op}{check_values[key]}" + for key, op in { + "equal_to": "", + "greater_than": ">", + "geq_to": ">=", + "less_than": "<", + "leq_to": "<=", + }.items() + if key in check_values + ) + tol = float(tol) # The operator already treats is as numeric in `_get_match`. + parts = [] + if "equal_to" in check_values: + v = check_values["equal_to"] + parts.append(f">= {v * (1 - tol)}") + parts.append(f"<= {v * (1 + tol)}") + if "greater_than" in check_values: + parts.append(f"> {check_values['greater_than'] * (1 - tol)}") + if "geq_to" in check_values: + parts.append(f">= {check_values['geq_to'] * (1 - tol)}") + if "less_than" in check_values: + parts.append(f"< {check_values['less_than'] * (1 + tol)}") + if "leq_to" in check_values: + parts.append(f"<= {check_values['leq_to'] * (1 + tol)}") + return ", ".join(parts) + class SQLTableCheckOperator(BaseSQLOperator): """ @@ -754,6 +1002,16 @@ def execute(self, context: Context): hook = self.get_db_hook() records = hook.get_records(self.sql) + if records: + self.log.info("Record:\n%s", records) + for row in records: + check, result = row + self.checks[check]["result"] = str(result) + self.checks[check]["success"] = _parse_boolean(str(result)) + + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(records) + if not records: # accept_none prevents an error from being thrown if there are no records in the table if self.accept_none: @@ -767,12 +1025,6 @@ def execute(self, context: Context): # Otherwise, we'll raise an exception self._raise_exception(f"The following query returned zero rows: {self.sql}") - self.log.info("Record:\n%s", records) - - for row in records: - check, result = row - self.checks[check]["success"] = _parse_boolean(str(result)) - failed_tests = [ f"\tCheck: {check},\n\tCheck Values: {check_values}\n" for check, check_values in self.checks.items() @@ -809,6 +1061,36 @@ def _generate_partition_clause(check_name): for check_name, value in self.checks.items() ) + def _build_check_results(self, records) -> list[SQLCheckResult]: + try: + return [ + SQLCheckResult( + name=check_name, + check_type="expression_is_true", + success=check_values.get("success", False), + table=self.table, + content=check_values.get("check_statement"), + expected="all truthy", + actual=check_values.get("result"), + severity="warn" if (not records and self.accept_none) else "error", + description="User-defined SQL expression must evaluate to true", + params={ + k: v + for k, v in { + "accept_none": self.accept_none, + "partition_clause": self.partition_clause, + "check_partition_clause": check_values.get("partition_clause"), + }.items() + if v is not None + } + or None, + ) + for check_name, check_values in self.checks.items() + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + class SQLCheckOperator(BaseSQLOperator): """ @@ -874,6 +1156,10 @@ def execute(self, context: Context): records = self.get_db_hook().get_first(self.sql, self.parameters) self.log.info("Record: %s", records) + + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(records) + if not records: self._raise_exception(f"The following query returned zero rows: {self.sql}") elif isinstance(records, dict) and not all(records.values()): @@ -883,6 +1169,31 @@ def execute(self, context: Context): self.log.info("Success.") + def _build_check_results(self, records) -> list[SQLCheckResult]: + try: + if not records: + success = False + elif isinstance(records, dict): + success = all(records.values()) + else: + success = all(records) + + return [ + SQLCheckResult( + name=self.task_id, + check_type="expression_is_true", + success=success, + expected="all truthy", + actual=str(records) if records else None, + content=self.sql, + description="All values in the first returned row must evaluate to true", + params={"parameters": self.parameters} if self.parameters else None, + ) + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + class SQLValueCheckOperator(BaseSQLOperator): """ @@ -956,6 +1267,10 @@ def check_value(self, records): def execute(self, context: Context): self.log.info("Executing SQL check: %s", self.sql) records = self.get_db_hook().get_first(self.sql, self.parameters) + + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(records) + self.check_value(records) def _to_float(self, records): @@ -973,6 +1288,54 @@ def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv): return [record == numeric_pass_value_conv for record in numeric_records] + def _evaluate_check_value(self, records) -> bool: + """Return whether the value check passes without raising.""" + if not records: + return False + pass_value_conv = _convert_to_float_if_possible(self.pass_value) + if not isinstance(pass_value_conv, float): + return all(self._get_string_matches(records, pass_value_conv)) + try: + numeric_records = self._to_float(records) + except (ValueError, TypeError): + return False + return all(self._get_numeric_matches(numeric_records, pass_value_conv)) + + def _build_check_results(self, records) -> list[SQLCheckResult]: + try: + expected_str = self.pass_value + check_type = "accepted_values" + + pass_value_conv = _convert_to_float_if_possible(self.pass_value) + if isinstance(pass_value_conv, float) and isinstance(self.tol, float): + expected_str = f">= {pass_value_conv * (1 - self.tol)}, <= {pass_value_conv * (1 + self.tol)}" + check_type = "accepted_range" + + return [ + SQLCheckResult( + name=self.task_id, + check_type=check_type, + success=self._evaluate_check_value(records), + expected=expected_str, + actual=str(records) if records else None, + content=self.sql, + description="All values in the first returned row must match the expected value", + params={ + k: v + for k, v in { + "pass_value": self.pass_value, + "tolerance": self.tol, + "parameters": self.parameters, + }.items() + if v is not None + } + or None, + ) + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + class SQLIntervalCheckOperator(BaseSQLOperator): """ @@ -1009,6 +1372,11 @@ class SQLIntervalCheckOperator(BaseSQLOperator): "relative_diff": lambda cur, ref: abs(cur - ref) / ref, } + ratio_formula_expressions = { + "max_over_min": "max({current}, {past}) / min({current}, {past})", + "relative_diff": "abs({current} - {past}) / {past}", + } + def __init__( self, *, @@ -1103,6 +1471,9 @@ def execute(self, context: Context): threshold, ) + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(all_tests_results) + failed_tests = [single for single in all_tests_results.values() if not single["success"]] if failed_tests: self.log.warning( @@ -1124,6 +1495,37 @@ def execute(self, context: Context): self.log.info("All tests have passed") + def _build_check_results(self, all_tests_results: dict[str, dict[str, Any]]) -> list[SQLCheckResult]: + try: + return [ + SQLCheckResult( + name=f"interval_{metric}", + check_type="accepted_range", + success=details["success"], + table=self.table, + expected=f"< {details['threshold']}", + content=self.ratio_formula_expressions.get(self.ratio_formula, self.ratio_formula).format( + current=details["current_metric"], past=details["past_metric"] + ), + actual=str(details["ratio"]) if details["ratio"] is not None else "0", + description="Ratio of current metric to historical baseline must be below the threshold", + params={ + "threshold": details["threshold"], + "days_back": self.days_back, + "ratio_formula_name": self.ratio_formula, + "ratio_formula": self.ratio_formula_expressions.get(self.ratio_formula), + "date_filter_column": self.date_filter_column, + "ignore_zero": self.ignore_zero, + "current_metric": details["current_metric"], + "past_metric": details["past_metric"], + }, + ) + for metric, details in all_tests_results.items() + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + class SQLThresholdCheckOperator(BaseSQLOperator): """ @@ -1202,6 +1604,10 @@ def execute(self, context: Context): } self.push(meta_data) + + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(result, meta_data) + if not meta_data["within_threshold"]: result = ( round(meta_data.get("result"), 2) # type: ignore[arg-type] @@ -1220,6 +1626,27 @@ def execute(self, context: Context): self.log.info("Test %s Successful.", self.task_id) + def _build_check_results(self, result: Any, meta_data: dict[str, Any]) -> list[SQLCheckResult]: + try: + return [ + SQLCheckResult( + name=self.task_id, + check_type="accepted_range", + success=meta_data["within_threshold"], + expected=f">= {meta_data['min_threshold']}, <= {meta_data['max_threshold']}", + actual=str(result), + content=self.sql, + description="SQL result must fall within the configured bounds", + params={ + "min_threshold": str(self.min_threshold), + "max_threshold": str(self.max_threshold), + }, + ) + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + def push(self, meta_data): """ Send data check info and metadata to an external database. @@ -1318,8 +1745,40 @@ def execute(self, context: Context): f"Unexpected query return result '{query_result}' type '{type(query_result)}'" ) + # Save check results before raising exception, to be used by listeners + self.check_results = self._build_check_results(query_result) + self.skip_all_except(context["ti"], self.follow_branch) + def _build_check_results(self, query_result: Any) -> list[SQLCheckResult]: + try: + return [ + SQLCheckResult( + name=self.task_id, + check_type="expression_is_true", + success=self.follow_branch == self.follow_task_ids_if_true, + severity="info", + expected="truthy", + actual=str(query_result), + content=self.sql, + description="SQL result is evaluated as boolean to determine the execution branch", + params={ + k: v + for k, v in { + "follow_task_ids_if_true": self.follow_task_ids_if_true, + "follow_task_ids_if_false": self.follow_task_ids_if_false, + "follow_branch": self.follow_branch, + "parameters": self.parameters, + }.items() + if v is not None + } + or None, + ) + ] + except Exception as err: + self.log.debug("Failed to build check results %s", err) + return [] + class SQLInsertRowsOperator(BaseSQLOperator): """ diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py index 08becbc12e15b..14380b717363c 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_sql.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_sql.py @@ -35,6 +35,7 @@ BaseSQLOperator, BranchSQLOperator, SQLCheckOperator, + SQLCheckResult, SQLColumnCheckOperator, SQLExecuteQueryOperator, SQLInsertRowsOperator, @@ -1622,6 +1623,1268 @@ def test_new_style_subclass(self, mock_get_connection, operator_class): mock_get_connection.assert_called_once_with("test_conn") +class TestSQLColumnCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(column_mapping, **kwargs): + return SQLColumnCheckOperator( + task_id="test_task", table="test_table", column_mapping=column_mapping, **kwargs + ) + + def test_not_null_check(self): + op = self._make_operator({"col": {"null_check": {"equal_to": 0}}}) + op.column_mapping["col"]["null_check"]["result"] = 0 + op.column_mapping["col"]["null_check"]["success"] = True + results = op._build_check_results() + assert len(results) == 1 + r = results[0] + assert r.name == "col.null_check" + assert r.check_type == "not_null" + assert r.success is True + assert r.column == "col" + assert r.table == "test_table" + assert r.expected == "0" + assert r.actual == "0" + assert r.content == "SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)" + assert r.description == "Column-level statistical check against the configured threshold" + assert r.severity == "error" + assert r.params == {"equal_to": 0, "accept_none": True} + + def test_unique_check(self): + op = self._make_operator({"col": {"unique_check": {"equal_to": 0}}}) + op.column_mapping["col"]["unique_check"]["result"] = 0 + op.column_mapping["col"]["unique_check"]["success"] = True + results = op._build_check_results() + assert len(results) == 1 + r = results[0] + assert r.check_type == "unique" + assert r.content == "COUNT(col) - COUNT(DISTINCT(col))" + assert r.params == {"equal_to": 0, "accept_none": True} + + def test_accepted_range_with_tolerance(self): + op = self._make_operator({"col": {"distinct_check": {"equal_to": 10, "tolerance": 0.1}}}) + op.column_mapping["col"]["distinct_check"]["result"] = 10 + op.column_mapping["col"]["distinct_check"]["success"] = True + results = op._build_check_results() + assert len(results) == 1 + r = results[0] + assert r.check_type == "accepted_range" + assert r.expected == ">= 9.0, <= 11.0" + assert r.actual == "10" + assert r.content == "COUNT(DISTINCT(col))" + assert r.params == {"equal_to": 10, "tolerance": 0.1, "accept_none": True} + + def test_accepted_range_geq_leq(self): + op = self._make_operator({"col": {"min": {"geq_to": 1, "leq_to": 100}}}) + op.column_mapping["col"]["min"]["result"] = 50 + op.column_mapping["col"]["min"]["success"] = True + results = op._build_check_results() + assert len(results) == 1 + r = results[0] + assert r.check_type == "accepted_range" + assert r.expected == ">=1, <=100" + assert r.params == {"geq_to": 1, "leq_to": 100, "accept_none": True} + + def test_multiple_checks_correct_names_and_order(self): + op = self._make_operator( + { + "col_a": { + "null_check": {"equal_to": 0, "result": 0, "success": True}, + "min": {"geq_to": 1, "result": 5, "success": True}, + }, + "col_b": { + "max": {"less_than": 100, "result": 50, "success": True}, + }, + } + ) + results = op._build_check_results() + assert len(results) == 3 + assert results[0].name == "col_a.null_check" + assert results[1].name == "col_a.min" + assert results[2].name == "col_b.max" + + def test_failing_check_recorded_correctly(self): + op = self._make_operator({"col": {"null_check": {"equal_to": 0}}}) + op.column_mapping["col"]["null_check"]["result"] = 5 + op.column_mapping["col"]["null_check"]["success"] = False + results = op._build_check_results() + assert len(results) == 1 + assert results[0].success is False + assert results[0].actual == "5" + + def test_partition_clause_in_params(self): + op = self._make_operator({"col": {"null_check": {"equal_to": 0, "partition_clause": "year=2024"}}}) + op.partition_clause = "region=us" + op.column_mapping["col"]["null_check"]["result"] = 0 + op.column_mapping["col"]["null_check"]["success"] = True + results = op._build_check_results() + assert len(results) == 1 + assert results[0].params == { + "equal_to": 0, + "accept_none": True, + "partition_clause": "region=us", + "check_partition_clause": "year=2024", + } + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator({"col": {"null_check": {"equal_to": 0, "result": 0, "success": True}}}) + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results() + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_records.return_value = [("col", "null_check", 0)] + op = self._make_operator({"col": {"null_check": {"equal_to": 0}}}) + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLColumnCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + # Records returned by the DB: (column, check_name, result) + mock_hook.return_value.get_records.return_value = [("col", "null_check", 0)] + op = self._make_operator({"col": {"null_check": {"equal_to": 0}}}) + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "col.null_check" + assert r.check_type == "not_null" + assert r.success is True + assert r.severity == "error" + assert r.column == "col" + assert r.table == "test_table" + assert r.expected == "0" + assert r.actual == "0" + assert r.content == "SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)" + assert r.description == "Column-level statistical check against the configured threshold" + assert r.params == {"equal_to": 0, "accept_none": True} + + @pytest.mark.parametrize("operator", ["greater_than", "geq_to", "less_than", "leq_to", "equal_to"]) + def test_known_comparison_operators_produce_expected_string(self, operator): + """Regression guard for existing operators. Does NOT detect new operators being added — + when adding a new comparison operator, add a case to this parametrize list manually.""" + result = SQLColumnCheckOperator._format_column_expected({operator: 5}) + assert result != "" + + @pytest.mark.parametrize( + ("check", "check_values", "expected_type"), + [ + # Range operators always win + ("null_check", {"greater_than": 0}, "accepted_range"), + ("col_count", {"geq_to": 10}, "accepted_range"), + ("max", {"less_than": 100}, "accepted_range"), + ("min", {"leq_to": 5}, "accepted_range"), + # equal_to + non-zero tolerance → range (tolerance expands to bounds) + ("distinct_check", {"equal_to": 10, "tolerance": 0.1}, "accepted_range"), + # equal_to without tolerance → semantic names for standard zero-value checks + ("null_check", {"equal_to": 0}, "not_null"), + ("unique_check", {"equal_to": 0}, "unique"), + # equal_to=0 with tolerance: condition requires equal_to != 0, so falls through + ("null_check", {"equal_to": 0, "tolerance": 0.1}, "not_null"), + ("distinct_check", {"equal_to": 0, "tolerance": 0.1}, "accepted_values"), + # equal_to with non-zero value on standard check names + ("null_check", {"equal_to": 5}, "accepted_values"), + ("unique_check", {"equal_to": 5}, "accepted_values"), + # generic check name with no comparison operators → returned unchanged + ("custom_metric", {}, "custom_metric"), + # equal_to on a non-null/unique check name + ("distinct_check", {"equal_to": 10}, "accepted_values"), + ], + ) + def test_get_column_check_type(self, check, check_values, expected_type): + assert SQLColumnCheckOperator._get_column_check_type(check, check_values) == expected_type + + @pytest.mark.parametrize( + ("check_values", "expected"), + [ + ({"equal_to": 0}, "0"), + ({"equal_to": 5}, "5"), + ({"greater_than": 5}, ">5"), + ({"geq_to": 5}, ">=5"), + ({"less_than": 10}, "<10"), + ({"leq_to": 10}, "<=10"), + ({"greater_than": 0, "less_than": 100}, ">0, <100"), + ({"geq_to": 5, "leq_to": 10}, ">=5, <=10"), + ], + ) + def test_format_column_expected_no_tolerance(self, check_values, expected): + assert SQLColumnCheckOperator._format_column_expected(check_values) == expected + + @pytest.mark.parametrize( + ("check_values", "expected"), + [ + ({"equal_to": 5, "tolerance": 0.1}, ">= 4.5, <= 5.5"), + ({"equal_to": 10, "tolerance": 0.5}, ">= 5.0, <= 15.0"), + ({"greater_than": 10, "tolerance": 0.1}, "> 9.0"), + ({"geq_to": 10, "tolerance": 0.1}, ">= 9.0"), + ({"less_than": 10, "tolerance": 0.1}, "< 11.0"), + ({"leq_to": 10, "tolerance": 0.1}, "<= 11.0"), + ({"geq_to": 10, "leq_to": 20, "tolerance": 0.1}, ">= 9.0, <= 22.0"), + ], + ) + def test_format_column_expected_with_tolerance(self, check_values, expected): + assert SQLColumnCheckOperator._format_column_expected(check_values) == expected + + +class TestSQLTableCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(checks, **kwargs): + return SQLTableCheckOperator(task_id="test_task", table="test_table", checks=checks, **kwargs) + + def test_passing_check(self): + checks = {"row_count_check": {"check_statement": "COUNT(*) >= 3"}} + op = self._make_operator(checks) + op.checks["row_count_check"]["result"] = "1" + op.checks["row_count_check"]["success"] = True + records = [("row_count_check", 1)] + results = op._build_check_results(records) + assert len(results) == 1 + r = results[0] + assert r.name == "row_count_check" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.table == "test_table" + assert r.content == "COUNT(*) >= 3" + assert r.expected == "all truthy" + assert r.actual == "1" + assert r.severity == "error" + assert r.params == {"accept_none": False} + + def test_failing_check(self): + checks = {"row_count_check": {"check_statement": "COUNT(*) >= 3"}} + op = self._make_operator(checks) + op.checks["row_count_check"]["result"] = "0" + op.checks["row_count_check"]["success"] = False + results = op._build_check_results([("row_count_check", 0)]) + assert len(results) == 1 + assert results[0].success is False + assert results[0].actual == "0" + + def test_accept_none_severity_warn_when_empty_records(self): + checks = {"row_count_check": {"check_statement": "COUNT(*) >= 3"}} + op = self._make_operator(checks, accept_none=True) + op.checks["row_count_check"]["success"] = False + results = op._build_check_results([]) + assert len(results) == 1 + assert results[0].severity == "warn" + + def test_severity_error_when_records_present(self): + checks = {"row_count_check": {"check_statement": "COUNT(*) >= 3"}} + op = self._make_operator(checks, accept_none=True) + op.checks["row_count_check"]["result"] = "1" + op.checks["row_count_check"]["success"] = True + results = op._build_check_results([("row_count_check", 1)]) + assert results[0].severity == "error" + + def test_multiple_checks(self): + checks = { + "count_check": {"check_statement": "COUNT(*) > 0"}, + "sum_check": {"check_statement": "SUM(val) > 0"}, + } + op = self._make_operator(checks) + op.checks["count_check"].update({"result": "1", "success": True}) + op.checks["sum_check"].update({"result": "1", "success": True}) + results = op._build_check_results([("count_check", 1), ("sum_check", 1)]) + assert len(results) == 2 + assert results[0].name == "count_check" + assert results[1].name == "sum_check" + + def test_build_check_results_failure_returns_empty_list(self): + checks = {"row_count_check": {"check_statement": "COUNT(*) >= 3"}} + op = self._make_operator(checks) + op.checks["row_count_check"].update({"result": "1", "success": True}) + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results([("row_count_check", 1)]) + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_records.return_value = [("row_count_check", 1)] + op = self._make_operator({"row_count_check": {"check_statement": "COUNT(*) >= 1"}}) + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLTableCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + mock_hook.return_value.get_records.return_value = [("row_count_check", "1")] + op = self._make_operator({"row_count_check": {"check_statement": "COUNT(*) >= 1"}}) + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "row_count_check" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table == "test_table" + assert r.expected == "all truthy" + assert r.actual == "1" + assert r.content == "COUNT(*) >= 1" + assert r.description == "User-defined SQL expression must evaluate to true" + assert r.params == {"accept_none": False} + + +class TestSQLCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(**kwargs): + return SQLCheckOperator(task_id="test_task", sql="SELECT 1", **kwargs) + + def test_all_truthy_records(self): + op = self._make_operator() + results = op._build_check_results([1, 1, 1]) + assert len(results) == 1 + r = results[0] + assert r.name == "test_task" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.expected == "all truthy" + assert r.actual == "[1, 1, 1]" + assert r.content == "SELECT 1" + assert r.params is None + + def test_falsy_value_in_records(self): + op = self._make_operator() + results = op._build_check_results([1, 0, 1]) + assert len(results) == 1 + assert results[0].success is False + + def test_empty_records(self): + op = self._make_operator() + results = op._build_check_results([]) + assert len(results) == 1 + r = results[0] + assert r.success is False + assert r.actual is None + + def test_dict_records_all_true(self): + op = self._make_operator() + results = op._build_check_results({"A": True, "B": True}) + assert len(results) == 1 + assert results[0].success is True + + def test_dict_records_not_all_true(self): + op = self._make_operator() + results = op._build_check_results({"A": True, "B": False}) + assert len(results) == 1 + assert results[0].success is False + + def test_parameters_in_params(self): + op = self._make_operator(parameters="my_params") + results = op._build_check_results([1]) + assert len(results) == 1 + assert results[0].params == {"parameters": "my_params"} + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator() + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results([1]) + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_first.return_value = [1] + op = self._make_operator() + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + mock_hook.return_value.get_first.return_value = [1, 2, 3] + op = self._make_operator() + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "test_task" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table is None + assert r.expected == "all truthy" + assert r.actual == "[1, 2, 3]" + assert r.content == "SELECT 1" + assert r.description == "All values in the first returned row must evaluate to true" + assert r.params is None + + +class TestSQLValueCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(pass_value, tolerance=None): + return SQLValueCheckOperator( + task_id="test_task", sql="SELECT val FROM t", pass_value=pass_value, tolerance=tolerance + ) + + def test_exact_match_no_tolerance(self): + op = self._make_operator(pass_value="5") + results = op._build_check_results([5]) + assert len(results) == 1 + r = results[0] + assert r.name == "test_task" + assert r.check_type == "accepted_values" + assert r.success is True + assert r.expected == "5" + assert r.actual == "[5]" + assert r.content == "SELECT val FROM t" + assert r.params == {"pass_value": "5"} + + def test_numeric_tolerance_produces_accepted_range(self): + op = self._make_operator(pass_value=5, tolerance=0.1) + results = op._build_check_results([5]) + assert len(results) == 1 + r = results[0] + assert r.check_type == "accepted_range" + assert r.expected == ">= 4.5, <= 5.5" + assert r.params == {"pass_value": "5", "tolerance": 0.1} + + def test_non_numeric_pass_value_is_accepted_values(self): + op = self._make_operator(pass_value="hello") + results = op._build_check_results(["hello"]) + assert len(results) == 1 + assert results[0].check_type == "accepted_values" + assert results[0].expected == "hello" + + def test_failing_check(self): + op = self._make_operator(pass_value="10") + results = op._build_check_results([99]) + assert len(results) == 1 + assert results[0].success is False + + def test_parameters_in_params(self): + op = self._make_operator(pass_value="5") + op.parameters = {"key": "val"} + results = op._build_check_results([5]) + assert len(results) == 1 + assert results[0].params == {"pass_value": "5", "parameters": {"key": "val"}} + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator(pass_value="5") + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results([5]) + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLValueCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_first.return_value = [5] + op = self._make_operator(pass_value="5") + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLValueCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + mock_hook.return_value.get_first.return_value = [5] + op = self._make_operator(pass_value="5") + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "test_task" + assert r.check_type == "accepted_values" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table is None + assert r.expected == "5" + assert r.actual == "[5]" + assert r.content == "SELECT val FROM t" + assert r.description == "All values in the first returned row must match the expected value" + assert r.params == {"pass_value": "5"} + + @mock.patch.object(SQLValueCheckOperator, "get_db_hook") + def test_execute_populates_check_results_with_tolerance(self, mock_hook): + mock_hook.return_value.get_first.return_value = [5] + op = self._make_operator(pass_value=5, tolerance=0.1) + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "test_task" + assert r.check_type == "accepted_range" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table is None + assert r.expected == ">= 4.5, <= 5.5" + assert r.actual == "[5]" + assert r.content == "SELECT val FROM t" + assert r.description == "All values in the first returned row must match the expected value" + assert r.params == {"pass_value": "5", "tolerance": 0.1} + + +class TestSQLIntervalCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(metrics_thresholds, **kwargs): + return SQLIntervalCheckOperator( + task_id="test_task", + table="test_table", + metrics_thresholds=metrics_thresholds, + ratio_formula="max_over_min", + ignore_zero=True, + **kwargs, + ) + + def _make_all_tests_results(self, metric, current, past, threshold, ratio, success): + return { + metric: { + "metric": metric, + "current_metric": current, + "past_metric": past, + "threshold": threshold, + "ignore_zero": True, + "ratio": ratio, + "success": success, + } + } + + def test_passing_metric(self): + op = self._make_operator({"f1": 1.5}) + all_tests_results = self._make_all_tests_results("f1", 10, 9, 1.5, 1.1, True) + results = op._build_check_results(all_tests_results) + assert len(results) == 1 + r = results[0] + assert r.name == "interval_f1" + assert r.check_type == "accepted_range" + assert r.success is True + assert r.table == "test_table" + assert r.expected == "< 1.5" + assert r.actual == "1.1" + assert r.content == "max(10, 9) / min(10, 9)" + assert r.description == "Ratio of current metric to historical baseline must be below the threshold" + assert r.params == { + "threshold": 1.5, + "days_back": -7, + "ratio_formula_name": "max_over_min", + "ratio_formula": "max({current}, {past}) / min({current}, {past})", + "date_filter_column": "ds", + "ignore_zero": True, + "current_metric": 10, + "past_metric": 9, + } + + def test_failing_metric(self): + op = self._make_operator({"f1": 1.5}) + all_tests_results = self._make_all_tests_results("f1", 10, 2, 1.5, 5.0, False) + results = op._build_check_results(all_tests_results) + assert len(results) == 1 + assert results[0].success is False + assert results[0].actual == "5.0" + + def test_zero_ratio_none(self): + op = self._make_operator({"f1": 1.5}) + all_tests_results = self._make_all_tests_results("f1", 0, 10, 1.5, None, True) + results = op._build_check_results(all_tests_results) + assert len(results) == 1 + assert results[0].actual == "0" + assert results[0].success is True + + def test_multiple_metrics(self): + op = self._make_operator({"f1": 1.5, "f2": 2.0}) + all_tests_results = { + "f1": { + "metric": "f1", + "current_metric": 10, + "past_metric": 9, + "threshold": 1.5, + "ignore_zero": True, + "ratio": 1.1, + "success": True, + }, + "f2": { + "metric": "f2", + "current_metric": 10, + "past_metric": 2, + "threshold": 2.0, + "ignore_zero": True, + "ratio": 5.0, + "success": False, + }, + } + results = op._build_check_results(all_tests_results) + assert len(results) == 2 + assert results[0].name == "interval_f1" + assert results[0].success is True + assert results[1].name == "interval_f2" + assert results[1].success is False + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator({"f1": 1.5}) + all_tests_results = self._make_all_tests_results("f1", 10, 9, 1.5, 1.1, True) + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results(all_tests_results) + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_first.side_effect = [[10], [10]] + op = self._make_operator({"f1": 1.5}) + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLIntervalCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + # execute() fetches sql2 (past/reference) first, then sql1 (current) + mock_hook.return_value.get_first.side_effect = [[9], [10]] + op = self._make_operator({"f1": 1.5}) + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "interval_f1" + assert r.check_type == "accepted_range" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table == "test_table" + assert r.expected == "< 1.5" + # ratio = max(10, 9) / min(10, 9) = 10/9 ≈ 1.111, content substitutes actual values + assert r.content == "max(10, 9) / min(10, 9)" + assert r.actual == str(10 / 9) + assert r.description == "Ratio of current metric to historical baseline must be below the threshold" + assert r.params == { + "threshold": 1.5, + "days_back": -7, + "ratio_formula_name": "max_over_min", + "ratio_formula": "max({current}, {past}) / min({current}, {past})", + "date_filter_column": "ds", + "ignore_zero": True, + "current_metric": 10, + "past_metric": 9, + } + + def test_all_ratio_formulas_have_content_expressions(self): + """Fails if a new ratio formula is added to ratio_formulas without a matching expression entry.""" + op = self._make_operator({"f1": 1.5}) + assert set(op.ratio_formulas.keys()) == set(op.ratio_formula_expressions.keys()) + + +class TestSQLThresholdCheckOperatorBuildCheckResults: + @staticmethod + def _make_operator(min_threshold=1, max_threshold=100): + return SQLThresholdCheckOperator( + task_id="test_task", + sql="SELECT val FROM t", + min_threshold=min_threshold, + max_threshold=max_threshold, + ) + + def test_within_threshold(self): + op = self._make_operator(min_threshold=1, max_threshold=100) + meta_data = {"within_threshold": True, "min_threshold": 1.0, "max_threshold": 100.0} + results = op._build_check_results(50, meta_data) + assert len(results) == 1 + r = results[0] + assert r.name == "test_task" + assert r.check_type == "accepted_range" + assert r.success is True + assert r.expected == ">= 1.0, <= 100.0" + assert r.actual == "50" + assert r.content == "SELECT val FROM t" + assert r.description == "SQL result must fall within the configured bounds" + assert r.params == {"min_threshold": "1", "max_threshold": "100"} + + def test_outside_threshold(self): + op = self._make_operator(min_threshold=20, max_threshold=100) + meta_data = {"within_threshold": False, "min_threshold": 20.0, "max_threshold": 100.0} + results = op._build_check_results(10, meta_data) + assert len(results) == 1 + assert results[0].success is False + assert results[0].actual == "10" + assert results[0].expected == ">= 20.0, <= 100.0" + + def test_sql_threshold_raw_strings_preserved_in_params(self): + op = self._make_operator(min_threshold="SELECT MIN(val) FROM ref", max_threshold=100) + meta_data = {"within_threshold": True, "min_threshold": 5.0, "max_threshold": 100.0} + results = op._build_check_results(50, meta_data) + assert len(results) == 1 + assert results[0].params == { + "min_threshold": "SELECT MIN(val) FROM ref", + "max_threshold": "100", + } + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator() + meta_data = {"within_threshold": True, "min_threshold": 1.0, "max_threshold": 100.0} + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results(50, meta_data) + assert results == [] + + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _): + mock_hook.return_value.get_first.return_value = (50,) + op = self._make_operator(min_threshold=1, max_threshold=100) + op.execute(MagicMock()) + assert op.check_results == [] + + @mock.patch.object(SQLThresholdCheckOperator, "get_db_hook") + def test_execute_populates_check_results(self, mock_hook): + mock_hook.return_value.get_first.return_value = (50,) + op = self._make_operator(min_threshold=1, max_threshold=100) + op.execute(MagicMock()) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "test_task" + assert r.check_type == "accepted_range" + assert r.success is True + assert r.severity == "error" + assert r.column is None + assert r.table is None + assert r.expected == ">= 1.0, <= 100.0" + assert r.actual == "50" + assert r.content == "SELECT val FROM t" + assert r.description == "SQL result must fall within the configured bounds" + assert r.params == {"min_threshold": "1", "max_threshold": "100"} + + +class TestBranchSQLOperatorBuildCheckResults: + @staticmethod + def _make_operator(**kwargs): + return BranchSQLOperator( + task_id="test_task", + sql="SELECT 1", + follow_task_ids_if_true=["branch_true"], + follow_task_ids_if_false=["branch_false"], + **kwargs, + ) + + def test_true_branch(self): + op = self._make_operator() + op.follow_branch = ["branch_true"] + results = op._build_check_results(1) + assert len(results) == 1 + r = results[0] + assert r.name == "test_task" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.severity == "info" + assert r.expected == "truthy" + assert r.actual == "1" + assert r.content == "SELECT 1" + assert r.description == "SQL result is evaluated as boolean to determine the execution branch" + assert r.params == { + "follow_task_ids_if_true": ["branch_true"], + "follow_task_ids_if_false": ["branch_false"], + "follow_branch": ["branch_true"], + } + + def test_false_branch(self): + op = self._make_operator() + op.follow_branch = ["branch_false"] + results = op._build_check_results(0) + assert len(results) == 1 + r = results[0] + assert r.success is False + assert r.actual == "0" + assert r.params == { + "follow_task_ids_if_true": ["branch_true"], + "follow_task_ids_if_false": ["branch_false"], + "follow_branch": ["branch_false"], + } + + def test_parameters_included_when_set(self): + op = self._make_operator(parameters={"key": "val"}) + op.follow_branch = ["branch_true"] + results = op._build_check_results(1) + assert len(results) == 1 + assert results[0].params == { + "follow_task_ids_if_true": ["branch_true"], + "follow_task_ids_if_false": ["branch_false"], + "follow_branch": ["branch_true"], + "parameters": {"key": "val"}, + } + + def test_build_check_results_failure_returns_empty_list(self): + op = self._make_operator() + op.follow_branch = ["branch_true"] + with mock.patch( + "airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom") + ): + results = op._build_check_results(1) + assert results == [] + + @mock.patch.object(BranchSQLOperator, "skip_all_except") + @mock.patch("airflow.providers.common.sql.operators.sql.SQLCheckResult", side_effect=RuntimeError("boom")) + @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") + def test_execute_unaffected_when_build_check_results_raises(self, mock_hook, _, mock_skip): + mock_hook.return_value.get_first.return_value = 1 + op = self._make_operator() + op.execute({"ti": MagicMock()}) + assert op.check_results == [] + mock_skip.assert_called_once() + + @mock.patch.object(BranchSQLOperator, "skip_all_except") + @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") + def test_execute_populates_check_results(self, mock_hook, mock_skip): + mock_hook.return_value.get_first.return_value = 1 + op = self._make_operator() + op.execute({"ti": MagicMock()}) + assert len(op.check_results) == 1 + r = op.check_results[0] + assert r.name == "test_task" + assert r.check_type == "expression_is_true" + assert r.success is True + assert r.severity == "info" + assert r.column is None + assert r.table is None + assert r.expected == "truthy" + assert r.actual == "1" + assert r.content == "SELECT 1" + assert r.description == "SQL result is evaluated as boolean to determine the execution branch" + assert r.params == { + "follow_task_ids_if_true": ["branch_true"], + "follow_task_ids_if_false": ["branch_false"], + "follow_branch": ["branch_true"], + } + mock_skip.assert_called_once() + + +class TestSqlBaseOperatorAttachCheckFacets: + """Tests for BaseSQLOperator._attach_check_facets.""" + + @staticmethod + def _make_operator(): + return SQLCheckOperator(task_id="test_task", sql="SELECT 1") + + @staticmethod + def _dataset(name): + from openlineage.client.event_v2 import Dataset + + return Dataset(namespace="default", name=name) + + def test_empty_check_results_returns_lineage_unchanged(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + result = op._attach_check_facets(OperatorLineage()) + assert result == OperatorLineage() + assert result.run_facets == {} + + def test_old_openlineage_client_raises_optional_feature_exception(self): + try: + from airflow.providers.openlineage.extractors import OperatorLineage + except ImportError: + pytest.skip("openlineage provider not installed") + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + op = self._make_operator() + op.check_results = [SQLCheckResult(name="t", check_type="expr", success=True)] + with mock.patch( + "airflow.providers.common.compat.openlineage.check.metadata.version", return_value="1.46.0" + ): + with pytest.raises(AirflowOptionalProviderFeatureException): + op._attach_check_facets(OperatorLineage()) + + def test_run_facet_all_fields_populated(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="my_check", + check_type="expression_is_true", + success=True, + severity="warn", + column="col_a", + table=None, # no table → run facet + expected="all truthy", + actual="[1, 1]", + content="SELECT col_a FROM t", + description="All values must be truthy", + params={"custom_key": "custom_val"}, + ) + ] + result = op._attach_check_facets(OperatorLineage()) + + facet = result.run_facets.get("test") + assert facet is not None + assert len(facet.tests) == 1 + t = facet.tests[0] + assert t.name == "my_check" + assert t.status == "pass" + assert t.severity == "warn" + assert t.type == "expression_is_true" + assert t.description == "All values must be truthy" + assert t.expected == "all truthy" + assert t.actual == "[1, 1]" + assert t.content == "SELECT col_a FROM t" + assert t.contentType == "sql" + # params merges tested_column / tested_table with the SQLCheckResult params + assert t.params == { + "tested_column": "col_a", + "tested_table": None, + "custom_key": "custom_val", + } + + def test_run_facet_failing_check_has_fail_status(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [SQLCheckResult(name="bad_check", check_type="not_null", success=False)] + result = op._attach_check_facets(OperatorLineage()) + + facet = result.run_facets.get("test") + assert facet is not None + assert facet.tests[0].status == "fail" + + def test_dataset_facet_all_fields_populated(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col.null_check", + check_type="not_null", + success=True, + severity="error", + column="col", + table="test_table", + expected="0 nulls", + actual="0", + content="SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)", + description="Column must contain no nulls", + params={"accept_none": False}, + ) + ] + lineage = OperatorLineage(inputs=[self._dataset("myschema.test_table")]) + result = op._attach_check_facets(lineage) + + assert result.run_facets == {} + dq_facet = result.inputs[0].facets.get("dataQualityAssertions") + assert dq_facet is not None + assert len(dq_facet.assertions) == 1 + a = dq_facet.assertions[0] + assert a.assertion == "not_null" + assert a.success is True + assert a.severity == "error" + assert a.column == "col" + assert a.name == "col.null_check" + assert a.description == "Column must contain no nulls" + assert a.expected == "0 nulls" + assert a.actual == "0" + assert a.content == "SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)" + assert a.contentType == "sql" + assert a.params == {"accept_none": False} + + def test_dataset_facet_multiple_assertions_same_table(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col_a.null_check", + check_type="not_null", + success=True, + column="col_a", + table="test_table", + expected="0 nulls", + actual="0", + ), + SQLCheckResult( + name="col_b.unique_check", + check_type="unique", + success=False, + column="col_b", + table="test_table", + expected="0 duplicates", + actual="3", + ), + SQLCheckResult( + name="row_count_check", + check_type="expression_is_true", + success=True, + column=None, + table="test_table", + expected="all truthy", + actual="1", + content="COUNT(*) > 0", + ), + ] + lineage = OperatorLineage(inputs=[self._dataset("myschema.test_table")]) + result = op._attach_check_facets(lineage) + + assert result.run_facets == {} + dq_facet = result.inputs[0].facets.get("dataQualityAssertions") + assert dq_facet is not None + assert len(dq_facet.assertions) == 3 + + a0 = dq_facet.assertions[0] + assert a0.name == "col_a.null_check" + assert a0.assertion == "not_null" + assert a0.success is True + assert a0.column == "col_a" + assert a0.expected == "0 nulls" + assert a0.actual == "0" + + a1 = dq_facet.assertions[1] + assert a1.name == "col_b.unique_check" + assert a1.assertion == "unique" + assert a1.success is False + assert a1.column == "col_b" + assert a1.expected == "0 duplicates" + assert a1.actual == "3" + + a2 = dq_facet.assertions[2] + assert a2.name == "row_count_check" + assert a2.assertion == "expression_is_true" + assert a2.success is True + assert a2.column is None + assert a2.content == "COUNT(*) > 0" + + def test_dataset_facet_attached_to_output_all_fields_populated(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col.null_check", + check_type="not_null", + success=True, + severity="error", + column="col", + table="target_table", + expected="0 nulls", + actual="0", + content="SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)", + description="Column must contain no nulls", + params={"accept_none": False}, + ) + ] + # table appears only in outputs (e.g. a post-write quality check) + lineage = OperatorLineage(outputs=[self._dataset("myschema.target_table")]) + result = op._attach_check_facets(lineage) + + assert result.run_facets == {} + dq_facet = result.outputs[0].facets.get("dataQualityAssertions") + assert dq_facet is not None + assert len(dq_facet.assertions) == 1 + a = dq_facet.assertions[0] + assert a.assertion == "not_null" + assert a.success is True + assert a.severity == "error" + assert a.column == "col" + assert a.name == "col.null_check" + assert a.description == "Column must contain no nulls" + assert a.expected == "0 nulls" + assert a.actual == "0" + assert a.content == "SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)" + assert a.contentType == "sql" + assert a.params == {"accept_none": False} + + def test_unmatched_table_falls_back_to_run_facet_with_table_context_in_params(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col.null_check", + check_type="not_null", + success=False, + severity="error", + column="col", + table="other_table", + expected="0 nulls", + actual="5", + content="SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)", + description="Column must contain no nulls", + params={"accept_none": False}, + ) + ] + lineage = OperatorLineage(inputs=[self._dataset("myschema.test_table")]) + result = op._attach_check_facets(lineage) + + assert result.inputs[0].facets == {} + facet = result.run_facets.get("test") + assert facet is not None + assert len(facet.tests) == 1 + t = facet.tests[0] + assert t.name == "col.null_check" + assert t.status == "fail" + assert t.severity == "error" + assert t.type == "not_null" + assert t.description == "Column must contain no nulls" + assert t.expected == "0 nulls" + assert t.actual == "5" + assert t.content == "SUM(CASE WHEN col IS NULL THEN 1 ELSE 0 END)" + assert t.contentType == "sql" + # tested_column / tested_table are injected alongside the SQLCheckResult params + assert t.params == { + "tested_column": "col", + "tested_table": "other_table", + "accept_none": False, + } + + def test_exact_match_preferred_over_suffix_match(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col.null_check", + check_type="not_null", + success=True, + column="col", + table="orders", + ) + ] + suffix_match = self._dataset("schema.orders") # endswith match — listed first + exact_match = self._dataset("orders") # exact match — listed second + lineage = OperatorLineage(inputs=[suffix_match, exact_match]) + result = op._attach_check_facets(lineage) + + assert result.run_facets == {} + # confirm ordering is preserved so index assertions below are unambiguous + assert result.inputs[0].name == "schema.orders" + assert result.inputs[1].name == "orders" + # suffix-match dataset must not receive the facet + assert result.inputs[0].facets == {} + # exact-match dataset must receive it + dq_facet = result.inputs[1].facets.get("dataQualityAssertions") + assert dq_facet is not None + assert dq_facet.assertions[0].name == "col.null_check" + + def test_same_table_in_inputs_and_outputs_both_receive_assertions(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + SQLCheckResult( + name="col.null_check", + check_type="not_null", + success=True, + column="col", + table="orders", + expected="0 nulls", + actual="0", + ) + ] + input_ds = self._dataset("schema.orders") + output_ds = self._dataset("schema.orders") + lineage = OperatorLineage(inputs=[input_ds], outputs=[output_ds]) + result = op._attach_check_facets(lineage) + + assert result.run_facets == {} + assert len(result.inputs) == 1 + assert len(result.outputs) == 1 + + for ds in (result.inputs[0], result.outputs[0]): + dq_facet = ds.facets.get("dataQualityAssertions") + assert dq_facet is not None, f"missing facet on {ds.name}" + assert len(dq_facet.assertions) == 1 + a = dq_facet.assertions[0] + assert a.name == "col.null_check" + assert a.assertion == "not_null" + assert a.success is True + assert a.column == "col" + + def test_mixed_results_dataset_facet_and_run_facet_populated_correctly(self): + pytest.importorskip( + "openlineage.client", minversion="1.47.0", reason="openlineage-python >= 1.47.0 required" + ) + from airflow.providers.openlineage.extractors import OperatorLineage + + op = self._make_operator() + op.check_results = [ + # matched input dataset → DataQualityAssertionsDatasetFacet + SQLCheckResult( + name="col_a.null_check", + check_type="not_null", + success=True, + column="col_a", + table="known_table", + expected="0 nulls", + actual="0", + ), + # table not in any dataset → falls back to TestRunFacet + SQLCheckResult( + name="col_b.null_check", + check_type="not_null", + success=False, + column="col_b", + table="unknown_table", + expected="0 nulls", + actual="2", + ), + # no table at all → directly to TestRunFacet + SQLCheckResult( + name="row_count_check", + check_type="expression_is_true", + success=True, + table=None, + expected="all truthy", + actual="1", + content="COUNT(*) > 0", + ), + ] + lineage = OperatorLineage(inputs=[self._dataset("myschema.known_table")]) + result = op._attach_check_facets(lineage) + + # matched check lands on the input dataset facet + dq_facet = result.inputs[0].facets.get("dataQualityAssertions") + assert dq_facet is not None + assert len(dq_facet.assertions) == 1 + a = dq_facet.assertions[0] + assert a.name == "col_a.null_check" + assert a.assertion == "not_null" + assert a.success is True + assert a.column == "col_a" + + # unmatched-table check and no-table check both land in the run facet + run_facet = result.run_facets.get("test") + assert run_facet is not None + assert len(run_facet.tests) == 2 + + t0 = run_facet.tests[0] + assert t0.name == "row_count_check" + assert t0.status == "pass" + assert t0.params["tested_column"] is None + assert t0.params["tested_table"] is None + + t1 = run_facet.tests[1] + assert t1.name == "col_b.null_check" + assert t1.status == "fail" + assert t1.params["tested_column"] == "col_b" + assert t1.params["tested_table"] == "unknown_table" + + class TestSQLInsertRowsOperator: @mock.patch.object(SQLInsertRowsOperator, "get_db_hook") def test_insert_rows_operator_with_preoperator(self, mock_get_db_hook): diff --git a/providers/openlineage/docs/index.rst b/providers/openlineage/docs/index.rst index 2a748dd326911..acf88d949a487 100644 --- a/providers/openlineage/docs/index.rst +++ b/providers/openlineage/docs/index.rst @@ -110,8 +110,8 @@ PIP package Version required ``apache-airflow-providers-common-sql`` ``>=1.32.0`` ``apache-airflow-providers-common-compat`` ``>=1.14.3`` ``attrs`` ``>=22.2`` -``openlineage-integration-common`` ``>=1.46.0`` -``openlineage-python`` ``>=1.46.0`` +``openlineage-integration-common`` ``>=1.47.0`` +``openlineage-python`` ``>=1.47.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/openlineage/pyproject.toml b/providers/openlineage/pyproject.toml index bfd1eaee1e5d6..7d5d799889c11 100644 --- a/providers/openlineage/pyproject.toml +++ b/providers/openlineage/pyproject.toml @@ -63,8 +63,8 @@ dependencies = [ "apache-airflow-providers-common-sql>=1.32.0", "apache-airflow-providers-common-compat>=1.14.3", # use next version "attrs>=22.2", - "openlineage-integration-common>=1.46.0", - "openlineage-python>=1.46.0", + "openlineage-integration-common>=1.47.0", + "openlineage-python>=1.47.0", ] # The optional dependencies should be modified in place in the generated file diff --git a/uv.lock b/uv.lock index 368ee79022af5..8201711de7f9b 100644 --- a/uv.lock +++ b/uv.lock @@ -6392,8 +6392,8 @@ requires-dist = [ { name = "apache-airflow-providers-common-compat", editable = "providers/common/compat" }, { name = "apache-airflow-providers-common-sql", editable = "providers/common/sql" }, { name = "attrs", specifier = ">=22.2" }, - { name = "openlineage-integration-common", specifier = ">=1.46.0" }, - { name = "openlineage-python", specifier = ">=1.46.0" }, + { name = "openlineage-integration-common", specifier = ">=1.47.0" }, + { name = "openlineage-python", specifier = ">=1.47.0" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=1.4.54" }, ] provides-extras = ["sqlalchemy"] @@ -16380,7 +16380,7 @@ wheels = [ [[package]] name = "openlineage-integration-common" -version = "1.46.0" +version = "1.47.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, @@ -16389,12 +16389,12 @@ dependencies = [ { name = "pyyaml" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/52/d81eca43a980058b661868fdeb9f5d4af876d1dab982635c859756907fee/openlineage_integration_common-1.46.0-py3-none-any.whl", hash = "sha256:bdc50613b94346dd9f3441d079260282e7cc85cc16ea1469630a9c76912d5d92", size = 59006, upload-time = "2026-04-08T13:20:18.972Z" }, + { url = "https://files.pythonhosted.org/packages/10/ee/ddb3e68fd3b796edd899d9bafb1d22cf9bcbe49c05dbbb30c8bdf9c7712c/openlineage_integration_common-1.47.0-py3-none-any.whl", hash = "sha256:a7457dc826e85212cfc199e7f7eeb95fe916799339e0e7805be26a8ebb023088", size = 60218, upload-time = "2026-05-08T16:09:46.485Z" }, ] [[package]] name = "openlineage-python" -version = "1.46.0" +version = "1.47.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, @@ -16405,7 +16405,7 @@ dependencies = [ { name = "requests" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/92/97/f5614980b5d43884ce7e71568239555b4e959bdf7098d9734e22c32ef5c6/openlineage_python-1.46.0-py3-none-any.whl", hash = "sha256:f6228a01d34990e76ede5b55b3f99169e54e2e624814c4493f064b9cb1bfba37", size = 112669, upload-time = "2026-04-08T13:20:20.024Z" }, + { url = "https://files.pythonhosted.org/packages/6f/cd/d376aca3344b30e6b5160cf3031d134ab4d31d596c1d89c3609099acd556/openlineage_python-1.47.0-py3-none-any.whl", hash = "sha256:fcdee97c9466dfb3c87b9f798b9db441f6f8fcb8aa8d65ff9eaffeeb3901747a", size = 113172, upload-time = "2026-05-08T16:09:47.901Z" }, ] [[package]]