diff --git a/datajunction-clients/python/tests/examples.py b/datajunction-clients/python/tests/examples.py index 682e59221..87197543f 100644 --- a/datajunction-clients/python/tests/examples.py +++ b/datajunction-clients/python/tests/examples.py @@ -1256,7 +1256,7 @@ "JOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_idLEFTOUTERJOIN" "default_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)SELECT" "repair_order_details_0.stateASstate,\tSUM(repair_order_details_0.price_sum_252381cf)" - "/SUM(repair_order_details_0.price_count_252381cf)ASavg_repair_priceFROM" + "/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)ASavg_repair_priceFROM" "repair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( **{ "id": "bd98d6be-e2d2-413e-94c7-96d9411ddee2", @@ -1399,7 +1399,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTCOALESCE(repair_order_details_0.city)AScity,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city", submitted_query="...", @@ -1437,7 +1437,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTrepair_order_details_0.cityAScity,\tSUM(repair_order_details_0." 
- "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city", submitted_query="...", @@ -1465,7 +1465,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTCOALESCE(repair_order_details_0.state)ASstate,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data", submitted_query="...", @@ -1481,7 +1481,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTrepair_order_details_0.stateASstate,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data-no-coalesce", submitted_query="...", @@ -1498,7 +1498,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTrepair_order_details_0.cityAScity,\tSUM(repair_order_details_0." 
- "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city-with-join-key", submitted_query="...", @@ -1526,7 +1526,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTrepair_order_details_0.stateASstate,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data-with-join-key", submitted_query="...", diff --git a/datajunction-server/datajunction_server/construction/build_v3/cte.py b/datajunction-server/datajunction_server/construction/build_v3/cte.py index 3e4c8cc0a..3cf034bd7 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cte.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cte.py @@ -23,6 +23,7 @@ from datajunction_server.construction.build_v3.utils import get_cte_name from datajunction_server.database.node import Node from datajunction_server.models.node_type import NodeType +from datajunction_server.sql.decompose import wrap_divisions_in_nullif from datajunction_server.sql.parsing import ast from datajunction_server.utils import SEPARATOR @@ -1593,4 +1594,8 @@ def process_metric_combiner_expression( partition_cte_alias=cte_alias, ) + # Wrap denominators so user-authored ratio metrics don't blow up on + # zero (NaN/Infinity/error → NULL). Idempotent. 
+ wrap_divisions_in_nullif(expr_ast) + return expr_ast diff --git a/datajunction-server/datajunction_server/construction/build_v3/metrics.py b/datajunction-server/datajunction_server/construction/build_v3/metrics.py index 9f8a8975a..01c12d2a2 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/metrics.py +++ b/datajunction-server/datajunction_server/construction/build_v3/metrics.py @@ -10,7 +10,7 @@ import logging from copy import deepcopy from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, cast from datajunction_server.construction.build_v3.cte import ( build_alias_to_dimension_node, @@ -50,6 +50,7 @@ from datajunction_server.errors import DJInvalidInputException from datajunction_server.models.decompose import Aggregability from datajunction_server.models.node_type import NodeType +from datajunction_server.sql.decompose import wrap_divisions_in_nullif from datajunction_server.sql.parsing import ast logger = logging.getLogger(__name__) @@ -686,6 +687,10 @@ def build_intermediate_metric_expr( # The dependency hasn't been built, so defer this metric return None # pragma: no cover + # Intermediate derived metrics like avg_order_value = + # total_revenue / order_count inline raw aggregations on both + # sides — wrap denominators to avoid NaN/Infinity/error on 0. 
+ wrap_divisions_in_nullif(cast(ast.Expression, expr_ast)) return expr_ast # type: ignore diff --git a/datajunction-server/datajunction_server/internal/materializations.py b/datajunction-server/datajunction_server/internal/materializations.py index 2650ae2e8..13b60c16c 100644 --- a/datajunction-server/datajunction_server/internal/materializations.py +++ b/datajunction-server/datajunction_server/internal/materializations.py @@ -428,8 +428,14 @@ def decompose_expression( args=[ast.Column(name=numerator_measure_name)], ), right=ast.Function( - ast.Name("count"), - args=[ast.Column(name=denominator_measure_name)], + ast.Name("NULLIF"), + args=[ + ast.Function( + ast.Name("count"), + args=[ast.Column(name=denominator_measure_name)], + ), + ast.Number(value=0), + ], ), op=ast.BinaryOpKind.Divide, ) @@ -455,6 +461,14 @@ def decompose_expression( if expr.op in acceptable_binary_ops: # pragma: no cover measures_combiner_left, measures_left = decompose_expression(expr.left) measures_combiner_right, measures_right = decompose_expression(expr.right) + if expr.op == ast.BinaryOpKind.Divide and not ( + isinstance(measures_combiner_right, ast.Function) + and measures_combiner_right.alias_or_name.name.lower() == "nullif" + ): + measures_combiner_right = ast.Function( + ast.Name("NULLIF"), + args=[measures_combiner_right, ast.Number(value=0)], + ) combiner = ast.BinaryOp( left=measures_combiner_left, right=measures_combiner_right, diff --git a/datajunction-server/datajunction_server/sql/decompose.py b/datajunction-server/datajunction_server/sql/decompose.py index 49b21b217..2a47e15e0 100644 --- a/datajunction-server/datajunction_server/sql/decompose.py +++ b/datajunction-server/datajunction_server/sql/decompose.py @@ -34,6 +34,45 @@ def make_func(name: str, *args: ast.Expression | str) -> ast.Function: ) +def safe_denominator(expr: ast.Expression) -> ast.Expression: + """Wrap expr in NULLIF(expr, 0) to make it safe as a divisor. 
+ + Idempotent: if expr is already NULLIF(_, 0) or a numeric literal, + returns it unchanged. Caller passes the RHS of a Divide. + """ + # Numeric literals: x / 100 doesn't need wrapping. + if isinstance(expr, ast.Number): + return expr + # Already wrapped. + if ( + isinstance(expr, ast.Function) + and expr.name.name.upper() == "NULLIF" + and len(expr.args) == 2 + and isinstance(expr.args[1], ast.Number) + and expr.args[1].value == 0 + ): + return expr + return ast.Function( + ast.Name("NULLIF"), + args=[expr, ast.Number(value=0)], + ) + + +def wrap_divisions_in_nullif(expr: ast.Expression) -> ast.Expression: + """Walk expr and wrap the RHS of every Divide BinaryOp in NULLIF(_, 0) + so division-by-zero produces NULL instead of NaN/Infinity/error. + + Mutates and returns expr. Idempotent via :func:`safe_denominator`. + """ + for node in expr.find_all(ast.BinaryOp): + if node.op != ast.BinaryOpKind.Divide: + continue + wrapped = safe_denominator(node.right) + if wrapped is not node.right: + node.right = wrapped + return expr + + # ============================================================================= # Decomposition Framework # ============================================================================= @@ -999,6 +1038,15 @@ def _extract_base( components_tracker.add(comp.name) components.append(comp) + # Wrap user-authored divisions in the outer expression too. + # _decompose only wraps the small combiners it builds for AVG / + # variance / stddev / covariance; the user's top-level + # expression (e.g. CAST(SUM(...)) / COUNT(*)) is unchanged + # after sub-aggregation replacements and would otherwise emit + # bare divisions. + for proj in query_ast.select.projection: + wrap_divisions_in_nullif(cast(ast.Expression, proj)) + return components, query_ast def _substitute_metric_references( @@ -1080,6 +1128,11 @@ def _decompose( ) else: combiner_ast = decomposition.combine(components) + # Decomposed AVG / variance / stddev / covariance all build + # SUM(...) 
/ SUM(count)-style combiners where the denominator + # can legitimately be 0. Wrap to produce NULL rather than + # NaN/Infinity/error. + combiner_ast = wrap_divisions_in_nullif(combiner_ast) return DecompositionResult(components, combiner_ast) diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 03d37c84a..a1fee6a22 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -947,9 +947,9 @@ async def test_create_cube( COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -1082,9 +1082,9 @@ async def test_cube_filters_merged_with_request_filters( COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, 
COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -1163,7 +1163,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 WHERE repair_orders_fact_0.state = 'AZ' @@ -1247,7 +1247,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS 
avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.state, repair_orders_fact_0.company_name @@ -1296,7 +1296,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 WHERE repair_orders_fact_0.state = 'AZ' AND repair_orders_fact_0.company_name = 'Potts LLC' @@ -1398,9 +1398,9 @@ async def test_cube_only_no_metrics_no_dims(client_with_repairs_cube: AsyncClien COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS 
avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -3000,10 +3000,10 @@ async def test_cube_materialization_metadata( assert results["metrics"] == [ { "derived_expression": "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / " - "SUM(count_c8e42e74) AS default_DOT_discounted_orders_rate FROM " + "NULLIF(SUM(count_c8e42e74), 0) AS default_DOT_discounted_orders_rate FROM " "default.repair_orders_fact", "metric_expression": "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / " - "SUM(count_c8e42e74)", + "NULLIF(SUM(count_c8e42e74), 0)", "metric": { "name": "default.discounted_orders_rate", "version": mock.ANY, @@ -3049,9 +3049,9 @@ async def test_cube_materialization_metadata( ], }, { - "derived_expression": "SELECT SUM(price_sum_935e7117) / SUM(price_count_935e7117) FROM " + "derived_expression": "SELECT SUM(price_sum_935e7117) / NULLIF(SUM(price_count_935e7117), 0) FROM " "default.repair_orders_fact", - "metric_expression": "SUM(price_sum_935e7117) / SUM(price_count_935e7117)", + "metric_expression": "SUM(price_sum_935e7117) / NULLIF(SUM(price_count_935e7117), 0)", "metric": { "name": "default.avg_repair_price", "version": mock.ANY, diff --git a/datajunction-server/tests/api/deployments_test.py b/datajunction-server/tests/api/deployments_test.py index bcdf68668..65e0dafa2 100644 --- a/datajunction-server/tests/api/deployments_test.py +++ b/datajunction-server/tests/api/deployments_test.py @@ -1064,8 +1064,8 @@ def default_regional_repair_efficiency(): "Total Repair Amount in Region" is the total amount spent on repairs in a given region. 
"Total Repair Amount Nationwide" is the total amount spent on all repairs nationwide.""", query="""SELECT - (SUM(rm.completed_repairs) * 1.0 / SUM(rm.total_repairs_dispatched)) * - (SUM(rm.total_amount_in_region) * 1.0 / SUM(na.total_amount_nationwide)) * 100 + (SUM(rm.completed_repairs) * 1.0 / NULLIF(SUM(rm.total_repairs_dispatched), 0)) * + (SUM(rm.total_amount_in_region) * 1.0 / NULLIF(SUM(na.total_amount_nationwide), 0)) * 100 FROM ${prefix}default.regional_level_agg rm CROSS JOIN diff --git a/datajunction-server/tests/api/djql_test.py b/datajunction-server/tests/api/djql_test.py index e04eaf686..69fdc5453 100644 --- a/datajunction-server/tests/api/djql_test.py +++ b/datajunction-server/tests/api/djql_test.py @@ -483,7 +483,7 @@ async def test_get_djsql_with_orderby_and_limit( ) SELECT repair_orders_fact_0.country AS country, - SUM(repair_orders_fact_0.price_sum_HASH) / SUM(repair_orders_fact_0.price_count_HASH) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_HASH) / NULLIF(SUM(repair_orders_fact_0.price_count_HASH), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.country ORDER BY country DESC diff --git a/datajunction-server/tests/api/graphql/find_nodes_test.py b/datajunction-server/tests/api/graphql/find_nodes_test.py index 7ac4f67d6..733381522 100644 --- a/datajunction-server/tests/api/graphql/find_nodes_test.py +++ b/datajunction-server/tests/api/graphql/find_nodes_test.py @@ -716,16 +716,16 @@ async def test_find_metric( }, ], "derivedQuery": "SELECT (SUM(completed_repairs_sum_8b112bf1) * 1.0 / " - "SUM(total_repairs_dispatched_sum_601dc4f1)) * " + "NULLIF(SUM(total_repairs_dispatched_sum_601dc4f1), 0)) * " "(SUM(total_amount_in_region_sum_3426ede4) * 1.0 / " - "SUM(na_DOT_total_amount_nationwide_sum_4ecb2318)) * 100 \n" + "NULLIF(SUM(na_DOT_total_amount_nationwide_sum_4ecb2318), 0)) * 100 \n" " FROM default.regional_level_agg CROSS JOIN " "default.national_level_agg na\n" "\n", "derivedExpression": 
"(SUM(completed_repairs_sum_8b112bf1) * 1.0 / " - "SUM(total_repairs_dispatched_sum_601dc4f1)) * " + "NULLIF(SUM(total_repairs_dispatched_sum_601dc4f1), 0)) * " "(SUM(total_amount_in_region_sum_3426ede4) * 1.0 / " - "SUM(na_DOT_total_amount_nationwide_sum_4ecb2318)) * 100", + "NULLIF(SUM(na_DOT_total_amount_nationwide_sum_4ecb2318), 0)) * 100", }, }, "name": "default.regional_repair_efficiency", diff --git a/datajunction-server/tests/api/metrics_test.py b/datajunction-server/tests/api/metrics_test.py index d0961f189..e38ef243c 100644 --- a/datajunction-server/tests/api/metrics_test.py +++ b/datajunction-server/tests/api/metrics_test.py @@ -482,11 +482,11 @@ async def test_read_metrics(module__client_with_roads: AsyncClient) -> None: }, ] assert data["derived_query"] == ( - "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / SUM(count_c8e42e74) AS " + "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(count_c8e42e74), 0) AS " "default_DOT_discounted_orders_rate \n FROM default.repair_orders_fact" ) assert data["derived_expression"] == ( - "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / SUM(count_c8e42e74) " + "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(count_c8e42e74), 0) " "AS default_DOT_discounted_orders_rate" ) assert data["custom_metadata"] is None diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index 1b19fe29a..9ed22756c 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -4964,7 +4964,7 @@ async def test_node_column_lineage(self, client_with_roads: AsyncClient): "query": ( """ SELECT - cast(sum(if(discount > 0.0, 1, 0)) as double) / count(repair_order_id) + cast(sum(if(discount > 0.0, 1, 0)) as double) / NULLIF(count(repair_order_id), 0) FROM default.repair_order_details """ ), @@ -5750,7 +5750,10 @@ def test_decompose_expression(): res = decompose_expression( ast.Function(ast.Name("avg"), 
args=[ast.Column(ast.Name("orders"))]), ) - assert str(res[0]) == "sum(orders3845127662_sum) / count(orders3845127662_count)" + assert ( + str(res[0]) + == "sum(orders3845127662_sum) / NULLIF(count(orders3845127662_count), 0)" + ) assert [measure.alias_or_name.name for measure in res[1]] == [ "orders3845127662_sum", "orders3845127662_count", @@ -5765,7 +5768,8 @@ def test_decompose_expression(): ), ) assert ( - str(res[0]) == "sum(orders3845127662_sum) / count(orders3845127662_count) + 5.5" + str(res[0]) + == "sum(orders3845127662_sum) / NULLIF(count(orders3845127662_count), 0) + 5.5" ) assert [measure.alias_or_name.name for measure in res[1]] == [ "orders3845127662_sum", @@ -5793,8 +5797,8 @@ def test_decompose_expression(): ) assert ( str(res[0]) - == "max(sum(orders_a1170126662_sum) / count(orders_a1170126662_count) " - "+ sum(orders_b3703039740_sum) / count(orders_b3703039740_count))" + == "max(sum(orders_a1170126662_sum) / NULLIF(count(orders_a1170126662_count), 0) " + "+ sum(orders_b3703039740_sum) / NULLIF(count(orders_b3703039740_count), 0))" ) # Decompose `sum(max(orders))` @@ -5834,7 +5838,7 @@ def test_decompose_expression(): ) assert ( str(res[0]) - == "max(orders3845127662_max) + min(validations2970758927_min) / sum(total3257917790_sum)" + == "max(orders3845127662_max) + min(validations2970758927_min) / NULLIF(sum(total3257917790_sum), 0)" ) assert [measure.alias_or_name.name for measure in res[1]] == [ "orders3845127662_max", @@ -5864,7 +5868,10 @@ def test_decompose_expression(): op=ast.BinaryOpKind.Divide, ), ) - assert str(res[0]) == "sum(has_ordered2766370626_sum) / sum(total3257917790_sum)" + assert ( + str(res[0]) + == "sum(has_ordered2766370626_sum) / NULLIF(sum(total3257917790_sum), 0)" + ) assert [measure.alias_or_name.name for measure in res[1]] == [ "has_ordered2766370626_sum", "total3257917790_sum", diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 2814892bb..6bc3d721b 100644 --- 
a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -1782,7 +1782,7 @@ async def test_metric_with_second_order_dimensions( ) SELECT repair_orders_fact_0.city AS city, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.city @@ -1861,7 +1861,7 @@ async def test_metric_with_nth_order_dimensions( SELECT repair_orders_fact_0.city AS city, repair_orders_fact_0.company_name AS company_name, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.city, repair_orders_fact_0.company_name @@ -4243,7 +4243,7 @@ async def test_role_path_dimensions_in_filters_single_hop( ) SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country = 'United States' GROUP BY user_dim_0.name_user_birth_country @@ -4312,7 +4312,7 @@ async def test_role_path_dimensions_in_filters_multi_hop_geographic( ) SELECT user_dim_0.continent_name_region_continent AS continent_name_region_continent, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.continent_name_region_continent = 'North America' GROUP BY user_dim_0.continent_name_region_continent @@ -4388,7 +4388,7 @@ 
async def test_role_path_dimensions_in_filters_multi_hop_temporal( ) SELECT user_dim_0.year_number_month_year AS year_number_month_year, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.year_number_month_year = 2024 GROUP BY user_dim_0.year_number_month_year @@ -4450,7 +4450,7 @@ async def test_role_path_dimensions_mixed_paths( SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, user_dim_0.name_user_residence_country AS name_user_residence_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country = 'Canada' AND user_dim_0.name_user_residence_country = 'United States' GROUP BY user_dim_0.name_user_birth_country, user_dim_0.name_user_residence_country @@ -4548,7 +4548,7 @@ async def test_role_path_dimensions_mixed_hierarchies( SELECT user_dim_0.continent_name_region_continent AS continent_name_region_continent, user_dim_0.month_name_week_month AS month_name_week_month, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 GROUP BY user_dim_0.continent_name_region_continent, user_dim_0.month_name_week_month @@ -4744,7 +4744,7 @@ async def test_multiple_filters_same_role_path( GROUP BY t2.name ) SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country IS NOT NULL GROUP BY 
user_dim_0.name_user_birth_country @@ -4849,7 +4849,7 @@ async def test_role_path_dimensions_performance_complex_query( user_dim_0.region_name_country_region AS region_name_country_region, user_dim_0.month_name_week_month AS month_name_week_month, user_dim_0.year_number_month_year AS year_number_month_year, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country IN ('Canada', 'United States', 'Mexico') AND user_dim_0.region_name_country_region = 'North America' AND user_dim_0.month_name_week_month IN ('January', 'February', 'March') AND user_dim_0.year_number_month_year >= 2020 GROUP BY user_dim_0.name_user_birth_country, user_dim_0.name_user_residence_country, user_dim_0.region_name_country_region, user_dim_0.month_name_week_month, user_dim_0.year_number_month_year diff --git a/datajunction-server/tests/construction/build_v3/cte_test.py b/datajunction-server/tests/construction/build_v3/cte_test.py index cc542ca47..71ebc11f8 100644 --- a/datajunction-server/tests/construction/build_v3/cte_test.py +++ b/datajunction-server/tests/construction/build_v3/cte_test.py @@ -264,7 +264,7 @@ def test_aggregate_window_not_modified(self): def test_weighted_cpm_pattern(self): """Test weighted CPM pattern with grand total weight.""" - # Weighted CPM = (revenue / impressions) * (impressions / SUM(impressions) OVER ()) + # Weighted CPM = (revenue / impressions) * (impressions / NULLIF(SUM(impressions) OVER (), 0)) query = parse( "SELECT " "(revenue / NULLIF(impressions / 1000.0, 0)) " diff --git a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py index 85d99888a..04c309eb1 100644 --- a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py +++ 
b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py @@ -844,13 +844,13 @@ async def test_builds_sql_from_cube_with_all_v3_order_details_metrics( SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) AS total_quantity, COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab) AS order_count, hll_sketch_estimate(hll_union_agg(test_cube_all_order_metrics_0.customer_id_hll_23002251)) AS customer_count, - SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f) AS avg_unit_price, + SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0) AS avg_unit_price, MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) AS max_unit_price, MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f) AS min_unit_price, SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696) / NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0) AS avg_order_value, SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) / NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0) AS avg_items_per_order, SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696) / NULLIF(hll_sketch_estimate(hll_union_agg(test_cube_all_order_metrics_0.customer_id_hll_23002251)), 0) AS revenue_per_customer, - (MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0) * 100 AS price_spread_pct + (MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0), 0) * 100 AS price_spread_pct 
FROM test_cube_all_order_metrics_0 GROUP BY test_cube_all_order_metrics_0.category @@ -890,13 +890,13 @@ async def test_builds_sql_from_cube_with_all_v3_order_details_metrics( SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) AS total_quantity, COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab) AS order_count, hll_sketch_estimate(ds_hll(test_cube_all_order_metrics_0.customer_id_hll_23002251)) AS customer_count, - SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f)) AS avg_unit_price, + SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0)) AS avg_unit_price, MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) AS max_unit_price, MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f) AS min_unit_price, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696), NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0)) AS avg_order_value, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e), NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0)) AS avg_items_per_order, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696), NULLIF(hll_sketch_estimate(ds_hll(test_cube_all_order_metrics_0.customer_id_hll_23002251)), 0)) AS revenue_per_customer, - SAFE_DIVIDE((MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)), NULLIF(SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f)), 0)) * 100 AS price_spread_pct + SAFE_DIVIDE((MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)), NULLIF(SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), 
NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0)), 0)) * 100 AS price_spread_pct FROM test_cube_all_order_metrics_0 GROUP BY test_cube_all_order_metrics_0.category @@ -2295,7 +2295,7 @@ async def test_build_metrics_sql_cube_with_multi_component_metric( SELECT cube_avg_metric_0.category AS category, SAFE_DIVIDE(SUM(cube_avg_metric_0.unit_price_sum_55cff00f), - SUM(cube_avg_metric_0.unit_price_count_55cff00f)) AS avg_unit_price + NULLIF(SUM(cube_avg_metric_0.unit_price_count_55cff00f), 0)) AS avg_unit_price FROM cube_avg_metric_0 GROUP BY cube_avg_metric_0.category """, diff --git a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py index 0e932d7cb..dfae474a6 100644 --- a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py +++ b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py @@ -146,7 +146,7 @@ async def test_multi_component_metric(self, client_with_build_v3): Test metrics SQL for a multi-component metric (AVG). AVG decomposes into SUM and COUNT, and the combiner expression - should be applied: SUM(x) / COUNT(x). + should be applied: SUM(x) / NULLIF(COUNT(x), 0). 
""" response = await client_with_build_v3.get( "/sql/metrics/v3/", @@ -178,7 +178,7 @@ async def test_multi_component_metric(self, client_with_build_v3): GROUP BY t1.status ) SELECT order_details_0.status AS status, - SUM(order_details_0.unit_price_sum_55cff00f) / SUM(order_details_0.unit_price_count_55cff00f) AS avg_unit_price + SUM(order_details_0.unit_price_sum_55cff00f) / NULLIF(SUM(order_details_0.unit_price_count_55cff00f), 0) AS avg_unit_price FROM order_details_0 GROUP BY order_details_0.status """, @@ -1307,7 +1307,7 @@ async def test_all_additional_metrics_metrics_sql(self, client_with_build_v3): SUM(order_details_0.status_line_total_sum_43004dae) AS completed_order_revenue, SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue, MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f) AS price_spread, - (MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f)) / NULLIF(SUM(order_details_0.unit_price_sum_55cff00f) / SUM(order_details_0.unit_price_count_55cff00f), 0) * 100 AS price_spread_pct + (MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f)) / NULLIF(SUM(order_details_0.unit_price_sum_55cff00f) / NULLIF(SUM(order_details_0.unit_price_count_55cff00f), 0), 0) * 100 AS price_spread_pct FROM order_details_0 GROUP BY order_details_0.status, order_details_0.category """, diff --git a/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py b/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py index 981e9d1e2..1849dcf8a 100644 --- a/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py +++ b/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py @@ -235,7 +235,7 @@ async def test_derived_metric_no_preagg(self, client_with_build_v3): assert metrics_response.status_code == 200 metrics_data = metrics_response.json() - # avg_order_value = SUM(total_revenue) / 
COUNT(DISTINCT order_id) + # avg_order_value = SUM(total_revenue) / NULLIF(COUNT(DISTINCT order_id), 0) assert_sql_equal( metrics_data["sql"], """ diff --git a/datajunction-server/tests/construction/build_v3/types_test.py b/datajunction-server/tests/construction/build_v3/types_test.py index 5a32c2fbd..ab4cb8e8c 100644 --- a/datajunction-server/tests/construction/build_v3/types_test.py +++ b/datajunction-server/tests/construction/build_v3/types_test.py @@ -61,7 +61,7 @@ def test_is_fully_decomposable_with_limited(self): metric_node=metric_node, components=[component1, component2], aggregability=Aggregability.LIMITED, - combiner="SUM(x) / COUNT(DISTINCT y)", + combiner="SUM(x) / NULLIF(COUNT(DISTINCT y), 0)", derived_ast=derived_ast, ) diff --git a/datajunction-server/tests/sql/decompose_test.py b/datajunction-server/tests/sql/decompose_test.py index 1ecfe7e21..e125d0bf5 100644 --- a/datajunction-server/tests/sql/decompose_test.py +++ b/datajunction-server/tests/sql/decompose_test.py @@ -14,7 +14,13 @@ MetricComponent, ) from datajunction_server.models.node_type import NodeType -from datajunction_server.sql.decompose import MetricComponentExtractor +from datajunction_server.sql.decompose import ( + MetricComponentExtractor, + safe_denominator, + wrap_divisions_in_nullif, +) +from datajunction_server.sql.parsing import ast +from datajunction_server.sql.parsing.backends.antlr4 import parse from datajunction_server.sql.parsing.backends.exceptions import DJParseException from datajunction_server.models.engine import Dialect from datajunction_server.sql.parsing.ast import to_sql @@ -344,7 +350,7 @@ async def test_average(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT SUM(sales_amount_sum_b5a3cefe) / " - "SUM(sales_amount_count_b5a3cefe) FROM parent_node", + "NULLIF(SUM(sales_amount_count_b5a3cefe), 0) FROM parent_node", ) @@ -377,7 +383,7 @@ async def test_rate(session: AsyncSession, create_metric): assert measures == expected_measures0 
assert_sql_equal( str(derived_sql), - "SELECT SUM(clicks_sum_c45fd8cf) / SUM(impressions_sum_3be0a0e7) FROM parent_node", + "SELECT SUM(clicks_sum_c45fd8cf) / NULLIF(SUM(impressions_sum_3be0a0e7), 0) FROM parent_node", ) metric_rev2 = await create_metric( @@ -418,7 +424,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT CAST(CAST(SUM(clicks_sum_c45fd8cf) AS INT) AS DOUBLE) / " - "CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE) FROM parent_node", + "NULLIF(CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE), 0) FROM parent_node", ) metric_rev4 = await create_metric( @@ -446,7 +452,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT COALESCE(SUM(clicks_sum_c45fd8cf) / " - "SUM(impressions_sum_3be0a0e7), 0) FROM parent_node", + "NULLIF(SUM(impressions_sum_3be0a0e7), 0), 0) FROM parent_node", ) metric_rev5 = await create_metric( @@ -475,7 +481,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT IF(SUM(clicks_sum_c45fd8cf) > 0, CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE)" - " / CAST(SUM(clicks_sum_c45fd8cf) AS DOUBLE), NULL) FROM parent_node", + " / NULLIF(CAST(SUM(clicks_sum_c45fd8cf) AS DOUBLE), 0), NULL) FROM parent_node", ) metric_rev6 = await create_metric( @@ -502,7 +508,7 @@ async def test_rate(session: AsyncSession, create_metric): assert measures == expected_measures assert_sql_equal( str(derived_sql), - "SELECT ln(SUM(clicks_sum_c45fd8cf) + 1) / SUM(views_sum_d8e39817) FROM parent_node", + "SELECT ln(SUM(clicks_sum_c45fd8cf) + 1) / NULLIF(SUM(views_sum_d8e39817), 0) FROM parent_node", ) @@ -564,7 +570,7 @@ async def test_fraction_with_if(session: AsyncSession, create_metric): str(derived_sql), "SELECT IF(SUM(action_sum_c9802ccb) > 0, " "CAST(SUM(action_two_sum_05d921a8) AS DOUBLE) / " - "CAST(SUM(action_sum_c9802ccb) AS DOUBLE), NULL) FROM parent_node", + "NULLIF(CAST(SUM(action_sum_c9802ccb) AS 
DOUBLE), 0), NULL) FROM parent_node", ) @@ -632,7 +638,7 @@ async def test_count_distinct_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT COUNT( DISTINCT user_id_distinct_7f092f23) / " - "SUM(action_count_50d753fd) FROM parent_node", + "NULLIF(SUM(action_count_50d753fd), 0) FROM parent_node", ) @@ -825,7 +831,7 @@ async def test_count_if(session: AsyncSession, create_metric): assert measures == expected_measures assert_sql_equal( str(derived_sql), - "SELECT CAST(SUM(field_a_count_if_3979ffbd) AS FLOAT) / SUM(count_58ac32c5) " + "SELECT CAST(SUM(field_a_count_if_3979ffbd) AS FLOAT) / NULLIF(SUM(count_58ac32c5), 0) " "FROM parent_node", ) @@ -862,7 +868,7 @@ async def test_metric_query_with_aliases(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT SUM(time_to_dispatch_sum_3bc9baed) / " - "SUM(time_to_dispatch_count_3bc9baed) FROM default.repair_orders_fact", + "NULLIF(SUM(time_to_dispatch_count_3bc9baed), 0) FROM default.repair_orders_fact", ) @@ -1123,7 +1129,7 @@ async def test_approx_count_distinct_rate(session: AsyncSession, create_metric): derived_str = str(derived_sql) assert_sql_equal( derived_str, - "SELECT CAST(hll_sketch_estimate(hll_union_agg(clicked_user_id_hll_f3824813)) AS DOUBLE) / CAST(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS DOUBLE) FROM parent_node", + "SELECT CAST(hll_sketch_estimate(hll_union_agg(clicked_user_id_hll_f3824813)) AS DOUBLE) / NULLIF(CAST(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS DOUBLE), 0) FROM parent_node", ) @@ -1190,21 +1196,21 @@ async def test_approx_count_distinct_combined_metrics_dialect_translation( # Verify Spark SQL structure - contains both SUM and HLL assert_sql_equal( spark_sql, - "SELECT SUM(revenue_sum_60e4d31f) / hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS revenue_per_user FROM parent_node", + "SELECT SUM(revenue_sum_60e4d31f) / NULLIF(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)), 
0) AS revenue_per_user FROM parent_node", ) # Translate to Druid - should preserve SUM but translate HLL druid_sql = to_sql(derived_sql, Dialect.DRUID) assert_sql_equal( druid_sql, - "SELECT SAFE_DIVIDE(SUM(revenue_sum_60e4d31f), hll_sketch_estimate(ds_hll(user_id_hll_7f092f23))) AS revenue_per_user FROM parent_node", + "SELECT SAFE_DIVIDE(SUM(revenue_sum_60e4d31f), NULLIF(hll_sketch_estimate(ds_hll(user_id_hll_7f092f23)), 0)) AS revenue_per_user FROM parent_node", ) # Translate to Trino trino_sql = to_sql(derived_sql, Dialect.TRINO) assert_sql_equal( trino_sql, - "SELECT SUM(revenue_sum_60e4d31f) / cardinality(merge(user_id_hll_7f092f23)) AS revenue_per_user FROM parent_node", + "SELECT SUM(revenue_sum_60e4d31f) / NULLIF(cardinality(merge(user_id_hll_7f092f23)), 0) AS revenue_per_user FROM parent_node", ) @@ -1254,8 +1260,8 @@ async def test_var_pop(session: AsyncSession, create_metric): derived_str, """ SELECT - SUM(price_sum_sq_726db899) / SUM(price_count_726db899) - - POWER(SUM(price_sum_726db899) / SUM(price_count_726db899), 2) + SUM(price_sum_sq_726db899) / NULLIF(SUM(price_count_726db899), 0) - + POWER(SUM(price_sum_726db899) / NULLIF(SUM(price_count_726db899), 0), 2) FROM parent_node""", ) @@ -1279,7 +1285,7 @@ async def test_var_samp(session: AsyncSession, create_metric): str(derived_sql), """ SELECT - SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / SUM(price_count_726db899) * SUM(price_count_726db899) - 1 + SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / NULLIF(SUM(price_count_726db899) * SUM(price_count_726db899) - 1, 0) FROM parent_node""", ) @@ -1306,8 +1312,8 @@ async def test_stddev_pop(session: AsyncSession, create_metric): derived_str, """SELECT SQRT( - SUM(price_sum_sq_726db899) / SUM(price_count_726db899) - - POWER(SUM(price_sum_726db899) / SUM(price_count_726db899), 2) + SUM(price_sum_sq_726db899) / NULLIF(SUM(price_count_726db899), 0) - + 
POWER(SUM(price_sum_726db899) / NULLIF(SUM(price_count_726db899), 0), 2) ) FROM parent_node""", ) @@ -1338,7 +1344,7 @@ async def test_stddev_samp(session: AsyncSession, create_metric): SQRT( SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / - SUM(price_count_726db899) * SUM(price_count_726db899) - 1 + NULLIF(SUM(price_count_726db899) * SUM(price_count_726db899) - 1, 0) ) FROM parent_node""", ) @@ -2155,3 +2161,74 @@ async def test_extract_nested_derived_with_avg( # Derived SQL should contain the AVG combiner (SUM/COUNT pattern) assert "SUM(" in derived_sql assert "/" in derived_sql # Division from AVG decomposition + + +# ============================================================================= +# Division-safety NULLIF auto-wrapping +# ============================================================================= + + +def test_safe_denominator_idempotent(): + """safe_denominator wraps once and only once.""" + inner = ast.Function(ast.Name("SUM"), args=[ast.Column(ast.Name("n"))]) + wrapped = safe_denominator(inner) + assert_sql_equal(f"SELECT {wrapped}", "SELECT NULLIF(SUM(n), 0)") + # Re-wrapping returns the same expression unchanged. + re_wrapped = safe_denominator(wrapped) + assert re_wrapped is wrapped + + +def test_safe_denominator_preserves_literal(): + """Numeric literals (e.g. 
x / 100) don't need NULLIF.""" + lit = ast.Number(value=100) + assert safe_denominator(lit) is lit + + +def test_wrap_divisions_in_nullif_walks_nested(): + """Every nested Divide in the expression tree gets its RHS wrapped.""" + expr = parse( + "SELECT SUM(a) / SUM(b) - SUM(c) / SUM(d) FROM t", + ).select.projection[0] + wrap_divisions_in_nullif(expr) + assert_sql_equal( + f"SELECT {expr}", + "SELECT SUM(a) / NULLIF(SUM(b), 0) - SUM(c) / NULLIF(SUM(d), 0)", + ) + + +@pytest.mark.asyncio +async def test_avg_decomposition_wraps_denominator( + session: AsyncSession, + create_metric, +): + """The AVG combiner is auto-wrapped so 0/0 produces NULL instead of + NaN/Infinity/error. Together with the unit-level test_average above, + locks in the auto-wrap behaviour end-to-end via decomposition. + """ + metric_rev = await create_metric("SELECT AVG(amount) FROM parent_node") + extractor = MetricComponentExtractor(metric_rev.id) + _, derived_sql = await extractor.extract(session) + assert_sql_equal( + str(derived_sql), + "SELECT SUM(amount_sum_9e341235) / NULLIF(SUM(amount_count_9e341235), 0) " + "FROM parent_node", + ) + + +@pytest.mark.asyncio +async def test_user_authored_division_not_double_wrapped( + session: AsyncSession, + create_metric, +): + """If the author already wrote NULLIF(denominator, 0), the auto-wrap + is idempotent — no double-wrap. + """ + metric_rev = await create_metric( + "SELECT SUM(x) / NULLIF(SUM(y), 0) FROM parent_node", + ) + extractor = MetricComponentExtractor(metric_rev.id) + _, derived_sql = await extractor.extract(session) + assert_sql_equal( + str(derived_sql), + "SELECT SUM(x_sum_b5c12ce5) / NULLIF(SUM(y_sum_898a9389), 0) FROM parent_node", + )