From 9184c48c69025c4d72de650983054434aa4ccc97 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 14:05:16 -0700 Subject: [PATCH 1/6] Add NULLIF checks to ratio metrics --- .../construction/build_v3/cte.py | 8 ++ .../construction/build_v3/metrics.py | 9 +- .../datajunction_server/sql/decompose.py | 49 ++++++++++ .../construction/build_v3/metrics_sql_test.py | 4 +- .../tests/sql/decompose_test.py | 96 +++++++++++++++++-- 5 files changed, 155 insertions(+), 11 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build_v3/cte.py b/datajunction-server/datajunction_server/construction/build_v3/cte.py index 3e4c8cc0a..b05a82bf0 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cte.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cte.py @@ -1593,4 +1593,12 @@ def process_metric_combiner_expression( partition_cte_alias=cte_alias, ) + # Auto-wrap every Divide's RHS in NULLIF(_, 0) so user-authored + # ratio metrics don't produce NaN/Infinity/error on zero denominators. + # Idempotent — denominators that already have NULLIF or a literal + # constant are left alone. 
+ from datajunction_server.sql.decompose import wrap_divisions_in_nullif + + wrap_divisions_in_nullif(expr_ast) + return expr_ast diff --git a/datajunction-server/datajunction_server/construction/build_v3/metrics.py b/datajunction-server/datajunction_server/construction/build_v3/metrics.py index 9f8a8975a..3a1276089 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/metrics.py +++ b/datajunction-server/datajunction_server/construction/build_v3/metrics.py @@ -10,7 +10,7 @@ import logging from copy import deepcopy from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, cast from datajunction_server.construction.build_v3.cte import ( build_alias_to_dimension_node, @@ -686,6 +686,13 @@ def build_intermediate_metric_expr( # The dependency hasn't been built, so defer this metric return None # pragma: no cover + # Auto-wrap every Divide's RHS in NULLIF(_, 0). Intermediate + # derived metrics like ``avg_order_value = total_revenue / + # order_count`` inline raw aggregations on both sides; without + # NULLIF the result is NaN/Infinity/error when the denominator is 0. + from datajunction_server.sql.decompose import wrap_divisions_in_nullif + + wrap_divisions_in_nullif(cast(ast.Expression, expr_ast)) return expr_ast # type: ignore diff --git a/datajunction-server/datajunction_server/sql/decompose.py b/datajunction-server/datajunction_server/sql/decompose.py index 49b21b217..19075269a 100644 --- a/datajunction-server/datajunction_server/sql/decompose.py +++ b/datajunction-server/datajunction_server/sql/decompose.py @@ -34,6 +34,49 @@ def make_func(name: str, *args: ast.Expression | str) -> ast.Function: ) +def safe_denominator(expr: ast.Expression) -> ast.Expression: + """Wrap ``expr`` in ``NULLIF(expr, 0)`` to make it safe as a divisor. + + Idempotent — if ``expr`` is already ``NULLIF(..., 0)`` (or a numeric + literal), it's returned unchanged. 
Caller's responsibility to pass + only the ``right`` side of a Divide. + """ + # Already a literal — division by literal 0 is the author's intent + # and we should preserve it as-is (most likely the literal isn't 0). + if isinstance(expr, ast.Number): + return expr + # Already wrapped in NULLIF(x, 0) + if ( + isinstance(expr, ast.Function) + and expr.name.name.upper() == "NULLIF" + and len(expr.args) == 2 + and isinstance(expr.args[1], ast.Number) + and expr.args[1].value == 0 + ): + return expr + return ast.Function( + ast.Name("NULLIF"), + args=[expr, ast.Number(value=0)], + ) + + +def wrap_divisions_in_nullif(expr: ast.Expression) -> ast.Expression: + """Walk ``expr`` and wrap the right-hand side of every Divide ``BinaryOp`` + in ``NULLIF(..., 0)`` so division-by-zero produces NULL instead of + NaN/Infinity/error. + + Returns ``expr`` (mutated in place where possible). Idempotent via + :func:`safe_denominator`. + """ + for node in expr.find_all(ast.BinaryOp): + if node.op != ast.BinaryOpKind.Divide: + continue + wrapped = safe_denominator(node.right) + if wrapped is not node.right: + node.right = wrapped + return expr + + # ============================================================================= # Decomposition Framework # ============================================================================= @@ -1080,6 +1123,12 @@ def _decompose( ) else: combiner_ast = decomposition.combine(components) + # Auto-wrap every Divide's RHS in NULLIF(_, 0). Decomposed AVG / + # variance / stddev / covariance all construct ``SUM(...) / + # SUM(count)`` patterns where the denominator can legitimately + # be 0; without NULLIF the result is NaN/Infinity/error + # depending on dialect. Idempotent. 
+ combiner_ast = wrap_divisions_in_nullif(combiner_ast) return DecompositionResult(components, combiner_ast) diff --git a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py index 0e932d7cb..b2c46f7a0 100644 --- a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py +++ b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py @@ -178,7 +178,7 @@ async def test_multi_component_metric(self, client_with_build_v3): GROUP BY t1.status ) SELECT order_details_0.status AS status, - SUM(order_details_0.unit_price_sum_55cff00f) / SUM(order_details_0.unit_price_count_55cff00f) AS avg_unit_price + SUM(order_details_0.unit_price_sum_55cff00f) / NULLIF(SUM(order_details_0.unit_price_count_55cff00f), 0) AS avg_unit_price FROM order_details_0 GROUP BY order_details_0.status """, @@ -1307,7 +1307,7 @@ async def test_all_additional_metrics_metrics_sql(self, client_with_build_v3): SUM(order_details_0.status_line_total_sum_43004dae) AS completed_order_revenue, SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue, MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f) AS price_spread, - (MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f)) / NULLIF(SUM(order_details_0.unit_price_sum_55cff00f) / SUM(order_details_0.unit_price_count_55cff00f), 0) * 100 AS price_spread_pct + (MAX(order_details_0.unit_price_max_55cff00f) - MIN(order_details_0.unit_price_min_55cff00f)) / NULLIF(SUM(order_details_0.unit_price_sum_55cff00f) / NULLIF(SUM(order_details_0.unit_price_count_55cff00f), 0), 0) * 100 AS price_spread_pct FROM order_details_0 GROUP BY order_details_0.status, order_details_0.category """, diff --git a/datajunction-server/tests/sql/decompose_test.py b/datajunction-server/tests/sql/decompose_test.py index 1ecfe7e21..8d65ce973 100644 --- a/datajunction-server/tests/sql/decompose_test.py +++ 
b/datajunction-server/tests/sql/decompose_test.py @@ -344,7 +344,7 @@ async def test_average(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT SUM(sales_amount_sum_b5a3cefe) / " - "SUM(sales_amount_count_b5a3cefe) FROM parent_node", + "NULLIF(SUM(sales_amount_count_b5a3cefe), 0) FROM parent_node", ) @@ -862,7 +862,7 @@ async def test_metric_query_with_aliases(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT SUM(time_to_dispatch_sum_3bc9baed) / " - "SUM(time_to_dispatch_count_3bc9baed) FROM default.repair_orders_fact", + "NULLIF(SUM(time_to_dispatch_count_3bc9baed), 0) FROM default.repair_orders_fact", ) @@ -1254,8 +1254,8 @@ async def test_var_pop(session: AsyncSession, create_metric): derived_str, """ SELECT - SUM(price_sum_sq_726db899) / SUM(price_count_726db899) - - POWER(SUM(price_sum_726db899) / SUM(price_count_726db899), 2) + SUM(price_sum_sq_726db899) / NULLIF(SUM(price_count_726db899), 0) - + POWER(SUM(price_sum_726db899) / NULLIF(SUM(price_count_726db899), 0), 2) FROM parent_node""", ) @@ -1279,7 +1279,7 @@ async def test_var_samp(session: AsyncSession, create_metric): str(derived_sql), """ SELECT - SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / SUM(price_count_726db899) * SUM(price_count_726db899) - 1 + SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / NULLIF(SUM(price_count_726db899) * SUM(price_count_726db899) - 1, 0) FROM parent_node""", ) @@ -1306,8 +1306,8 @@ async def test_stddev_pop(session: AsyncSession, create_metric): derived_str, """SELECT SQRT( - SUM(price_sum_sq_726db899) / SUM(price_count_726db899) - - POWER(SUM(price_sum_726db899) / SUM(price_count_726db899), 2) + SUM(price_sum_sq_726db899) / NULLIF(SUM(price_count_726db899), 0) - + POWER(SUM(price_sum_726db899) / NULLIF(SUM(price_count_726db899), 0), 2) ) FROM parent_node""", ) @@ -1338,7 +1338,7 @@ async def test_stddev_samp(session: 
AsyncSession, create_metric): SQRT( SUM(price_count_726db899) * SUM(price_sum_sq_726db899) - POWER(SUM(price_sum_726db899), 2) / - SUM(price_count_726db899) * SUM(price_count_726db899) - 1 + NULLIF(SUM(price_count_726db899) * SUM(price_count_726db899) - 1, 0) ) FROM parent_node""", ) @@ -2155,3 +2155,83 @@ async def test_extract_nested_derived_with_avg( # Derived SQL should contain the AVG combiner (SUM/COUNT pattern) assert "SUM(" in derived_sql assert "/" in derived_sql # Division from AVG decomposition + + +# ============================================================================= +# Division-safety NULLIF auto-wrapping +# ============================================================================= + + +def test_safe_denominator_idempotent(): + """``safe_denominator`` wraps once and only once.""" + from datajunction_server.sql.decompose import safe_denominator + from datajunction_server.sql.parsing.backends.antlr4 import ast + + inner = ast.Function(ast.Name("SUM"), args=[ast.Column(ast.Name("n"))]) + wrapped = safe_denominator(inner) + assert_sql_equal(str(wrapped), "NULLIF(SUM(n), 0)") + # Re-wrapping returns the same expression unchanged. + re_wrapped = safe_denominator(wrapped) + assert re_wrapped is wrapped + + +def test_safe_denominator_preserves_literal(): + """Numeric literals (e.g. 
``x / 100``) don't need NULLIF.""" + from datajunction_server.sql.decompose import safe_denominator + from datajunction_server.sql.parsing.backends.antlr4 import ast + + lit = ast.Number(value=100) + assert safe_denominator(lit) is lit + + +def test_wrap_divisions_in_nullif_walks_nested(): + """Every nested Divide in the expression tree gets its RHS wrapped.""" + from datajunction_server.sql.decompose import wrap_divisions_in_nullif + from datajunction_server.sql.parsing.backends.antlr4 import parse + + expr = parse( + "SELECT (SUM(a) / SUM(b)) - (SUM(c) / SUM(d)) FROM t", + ).select.projection[0] + wrap_divisions_in_nullif(expr) + assert_sql_equal( + str(expr), + "(SUM(a) / NULLIF(SUM(b), 0)) - (SUM(c) / NULLIF(SUM(d), 0))", + ) + + +@pytest.mark.asyncio +async def test_avg_decomposition_wraps_denominator( + session: AsyncSession, + create_metric, +): + """The AVG combiner is auto-wrapped so 0/0 produces NULL instead of + NaN/Infinity/error. Together with the unit-level ``test_average`` + above, locks in the auto-wrap behaviour end-to-end via decomposition. + """ + metric_rev = await create_metric("SELECT AVG(amount) FROM parent_node") + extractor = MetricComponentExtractor(metric_rev.id) + _, derived_sql = await extractor.extract(session) + assert_sql_equal( + str(derived_sql), + "SELECT SUM(amount_sum_67a0b14a) / NULLIF(SUM(amount_count_67a0b14a), 0) " + "FROM parent_node", + ) + + +@pytest.mark.asyncio +async def test_user_authored_division_not_double_wrapped( + session: AsyncSession, + create_metric, +): + """If the author already wrote ``NULLIF(denominator, 0)``, the + auto-wrap is idempotent — no double-wrap. 
+ """ + metric_rev = await create_metric( + "SELECT SUM(x) / NULLIF(SUM(y), 0) FROM parent_node", + ) + extractor = MetricComponentExtractor(metric_rev.id) + _, derived_sql = await extractor.extract(session) + assert_sql_equal( + str(derived_sql), + "SELECT SUM(x_sum_22e2f0d6) / NULLIF(SUM(y_sum_22e2f0d6), 0) FROM parent_node", + ) From 6e05962848cee156899aa3dc0a52ca5368681ea3 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 14:21:54 -0700 Subject: [PATCH 2/6] Fix tests --- .../construction/build_v3/cte.py | 3 +- .../construction/build_v3/metrics.py | 3 +- .../build_v3/cube_matcher_test.py | 10 +++---- .../tests/sql/decompose_test.py | 29 +++++++++---------- 4 files changed, 20 insertions(+), 25 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build_v3/cte.py b/datajunction-server/datajunction_server/construction/build_v3/cte.py index b05a82bf0..c5ff9f28c 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cte.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cte.py @@ -23,6 +23,7 @@ from datajunction_server.construction.build_v3.utils import get_cte_name from datajunction_server.database.node import Node from datajunction_server.models.node_type import NodeType +from datajunction_server.sql.decompose import wrap_divisions_in_nullif from datajunction_server.sql.parsing import ast from datajunction_server.utils import SEPARATOR @@ -1597,8 +1598,6 @@ def process_metric_combiner_expression( # ratio metrics don't produce NaN/Infinity/error on zero denominators. # Idempotent — denominators that already have NULLIF or a literal # constant are left alone. 
- from datajunction_server.sql.decompose import wrap_divisions_in_nullif - wrap_divisions_in_nullif(expr_ast) return expr_ast diff --git a/datajunction-server/datajunction_server/construction/build_v3/metrics.py b/datajunction-server/datajunction_server/construction/build_v3/metrics.py index 3a1276089..e4649ec00 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/metrics.py +++ b/datajunction-server/datajunction_server/construction/build_v3/metrics.py @@ -50,6 +50,7 @@ from datajunction_server.errors import DJInvalidInputException from datajunction_server.models.decompose import Aggregability from datajunction_server.models.node_type import NodeType +from datajunction_server.sql.decompose import wrap_divisions_in_nullif from datajunction_server.sql.parsing import ast logger = logging.getLogger(__name__) @@ -690,8 +691,6 @@ def build_intermediate_metric_expr( # derived metrics like ``avg_order_value = total_revenue / # order_count`` inline raw aggregations on both sides; without # NULLIF the result is NaN/Infinity/error when the denominator is 0. 
- from datajunction_server.sql.decompose import wrap_divisions_in_nullif - wrap_divisions_in_nullif(cast(ast.Expression, expr_ast)) return expr_ast # type: ignore diff --git a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py index 85d99888a..04c309eb1 100644 --- a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py +++ b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py @@ -844,13 +844,13 @@ async def test_builds_sql_from_cube_with_all_v3_order_details_metrics( SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) AS total_quantity, COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab) AS order_count, hll_sketch_estimate(hll_union_agg(test_cube_all_order_metrics_0.customer_id_hll_23002251)) AS customer_count, - SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f) AS avg_unit_price, + SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0) AS avg_unit_price, MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) AS max_unit_price, MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f) AS min_unit_price, SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696) / NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0) AS avg_order_value, SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) / NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0) AS avg_items_per_order, SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696) / NULLIF(hll_sketch_estimate(hll_union_agg(test_cube_all_order_metrics_0.customer_id_hll_23002251)), 0) AS revenue_per_customer, - (MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)) / 
NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0) * 100 AS price_spread_pct + (MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f) / NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0), 0) * 100 AS price_spread_pct FROM test_cube_all_order_metrics_0 GROUP BY test_cube_all_order_metrics_0.category @@ -890,13 +890,13 @@ async def test_builds_sql_from_cube_with_all_v3_order_details_metrics( SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e) AS total_quantity, COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab) AS order_count, hll_sketch_estimate(ds_hll(test_cube_all_order_metrics_0.customer_id_hll_23002251)) AS customer_count, - SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f)) AS avg_unit_price, + SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0)) AS avg_unit_price, MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) AS max_unit_price, MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f) AS min_unit_price, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696), NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0)) AS avg_order_value, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.quantity_sum_06b64d2e), NULLIF(COUNT( DISTINCT test_cube_all_order_metrics_0.order_id_distinct_f93d50ab), 0)) AS avg_items_per_order, SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.line_total_sum_e1f61696), NULLIF(hll_sketch_estimate(ds_hll(test_cube_all_order_metrics_0.customer_id_hll_23002251)), 0)) AS revenue_per_customer, - 
SAFE_DIVIDE((MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)), NULLIF(SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f)), 0)) * 100 AS price_spread_pct + SAFE_DIVIDE((MAX(test_cube_all_order_metrics_0.unit_price_max_55cff00f) - MIN(test_cube_all_order_metrics_0.unit_price_min_55cff00f)), NULLIF(SAFE_DIVIDE(SUM(test_cube_all_order_metrics_0.unit_price_sum_55cff00f), NULLIF(SUM(test_cube_all_order_metrics_0.unit_price_count_55cff00f), 0)), 0)) * 100 AS price_spread_pct FROM test_cube_all_order_metrics_0 GROUP BY test_cube_all_order_metrics_0.category @@ -2295,7 +2295,7 @@ async def test_build_metrics_sql_cube_with_multi_component_metric( SELECT cube_avg_metric_0.category AS category, SAFE_DIVIDE(SUM(cube_avg_metric_0.unit_price_sum_55cff00f), - SUM(cube_avg_metric_0.unit_price_count_55cff00f)) AS avg_unit_price + NULLIF(SUM(cube_avg_metric_0.unit_price_count_55cff00f), 0)) AS avg_unit_price FROM cube_avg_metric_0 GROUP BY cube_avg_metric_0.category """, diff --git a/datajunction-server/tests/sql/decompose_test.py b/datajunction-server/tests/sql/decompose_test.py index 8d65ce973..ce8dd8fb7 100644 --- a/datajunction-server/tests/sql/decompose_test.py +++ b/datajunction-server/tests/sql/decompose_test.py @@ -14,7 +14,13 @@ MetricComponent, ) from datajunction_server.models.node_type import NodeType -from datajunction_server.sql.decompose import MetricComponentExtractor +from datajunction_server.sql.decompose import ( + MetricComponentExtractor, + safe_denominator, + wrap_divisions_in_nullif, +) +from datajunction_server.sql.parsing import ast +from datajunction_server.sql.parsing.backends.antlr4 import parse from datajunction_server.sql.parsing.backends.exceptions import DJParseException from datajunction_server.models.engine import Dialect from datajunction_server.sql.parsing.ast import to_sql @@ -2164,12 +2170,9 @@ 
async def test_extract_nested_derived_with_avg( def test_safe_denominator_idempotent(): """``safe_denominator`` wraps once and only once.""" - from datajunction_server.sql.decompose import safe_denominator - from datajunction_server.sql.parsing.backends.antlr4 import ast - inner = ast.Function(ast.Name("SUM"), args=[ast.Column(ast.Name("n"))]) wrapped = safe_denominator(inner) - assert_sql_equal(str(wrapped), "NULLIF(SUM(n), 0)") + assert_sql_equal(f"SELECT {wrapped}", "SELECT NULLIF(SUM(n), 0)") # Re-wrapping returns the same expression unchanged. re_wrapped = safe_denominator(wrapped) assert re_wrapped is wrapped @@ -2177,25 +2180,19 @@ def test_safe_denominator_idempotent(): def test_safe_denominator_preserves_literal(): """Numeric literals (e.g. ``x / 100``) don't need NULLIF.""" - from datajunction_server.sql.decompose import safe_denominator - from datajunction_server.sql.parsing.backends.antlr4 import ast - lit = ast.Number(value=100) assert safe_denominator(lit) is lit def test_wrap_divisions_in_nullif_walks_nested(): """Every nested Divide in the expression tree gets its RHS wrapped.""" - from datajunction_server.sql.decompose import wrap_divisions_in_nullif - from datajunction_server.sql.parsing.backends.antlr4 import parse - expr = parse( - "SELECT (SUM(a) / SUM(b)) - (SUM(c) / SUM(d)) FROM t", + "SELECT SUM(a) / SUM(b) - SUM(c) / SUM(d) FROM t", ).select.projection[0] wrap_divisions_in_nullif(expr) assert_sql_equal( - str(expr), - "(SUM(a) / NULLIF(SUM(b), 0)) - (SUM(c) / NULLIF(SUM(d), 0))", + f"SELECT {expr}", + "SELECT SUM(a) / NULLIF(SUM(b), 0) - SUM(c) / NULLIF(SUM(d), 0)", ) @@ -2213,7 +2210,7 @@ async def test_avg_decomposition_wraps_denominator( _, derived_sql = await extractor.extract(session) assert_sql_equal( str(derived_sql), - "SELECT SUM(amount_sum_67a0b14a) / NULLIF(SUM(amount_count_67a0b14a), 0) " + "SELECT SUM(amount_sum_9e341235) / NULLIF(SUM(amount_count_9e341235), 0) " "FROM parent_node", ) @@ -2233,5 +2230,5 @@ async def 
test_user_authored_division_not_double_wrapped( _, derived_sql = await extractor.extract(session) assert_sql_equal( str(derived_sql), - "SELECT SUM(x_sum_22e2f0d6) / NULLIF(SUM(y_sum_22e2f0d6), 0) FROM parent_node", + "SELECT SUM(x_sum_b5c12ce5) / NULLIF(SUM(y_sum_898a9389), 0) FROM parent_node", ) From cb6064a1021bffaf336d5ce83e86fb71f2672fdd Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 14:28:45 -0700 Subject: [PATCH 3/6] Fix all tests --- .../construction/build_v3/cte.py | 6 ++-- .../construction/build_v3/metrics.py | 7 ++--- .../datajunction_server/sql/decompose.py | 29 ++++++++----------- .../tests/sql/decompose_test.py | 12 ++++---- 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build_v3/cte.py b/datajunction-server/datajunction_server/construction/build_v3/cte.py index c5ff9f28c..3cf034bd7 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cte.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cte.py @@ -1594,10 +1594,8 @@ def process_metric_combiner_expression( partition_cte_alias=cte_alias, ) - # Auto-wrap every Divide's RHS in NULLIF(_, 0) so user-authored - # ratio metrics don't produce NaN/Infinity/error on zero denominators. - # Idempotent — denominators that already have NULLIF or a literal - # constant are left alone. + # Wrap denominators so user-authored ratio metrics don't blow up on + # zero (NaN/Infinity/error → NULL). Idempotent. 
wrap_divisions_in_nullif(expr_ast) return expr_ast diff --git a/datajunction-server/datajunction_server/construction/build_v3/metrics.py b/datajunction-server/datajunction_server/construction/build_v3/metrics.py index e4649ec00..01c12d2a2 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/metrics.py +++ b/datajunction-server/datajunction_server/construction/build_v3/metrics.py @@ -687,10 +687,9 @@ def build_intermediate_metric_expr( # The dependency hasn't been built, so defer this metric return None # pragma: no cover - # Auto-wrap every Divide's RHS in NULLIF(_, 0). Intermediate - # derived metrics like ``avg_order_value = total_revenue / - # order_count`` inline raw aggregations on both sides; without - # NULLIF the result is NaN/Infinity/error when the denominator is 0. + # Intermediate derived metrics like avg_order_value = + # total_revenue / order_count inline raw aggregations on both + # sides — wrap denominators to avoid NaN/Infinity/error on 0. wrap_divisions_in_nullif(cast(ast.Expression, expr_ast)) return expr_ast # type: ignore diff --git a/datajunction-server/datajunction_server/sql/decompose.py b/datajunction-server/datajunction_server/sql/decompose.py index 19075269a..61297adaf 100644 --- a/datajunction-server/datajunction_server/sql/decompose.py +++ b/datajunction-server/datajunction_server/sql/decompose.py @@ -35,17 +35,15 @@ def make_func(name: str, *args: ast.Expression | str) -> ast.Function: def safe_denominator(expr: ast.Expression) -> ast.Expression: - """Wrap ``expr`` in ``NULLIF(expr, 0)`` to make it safe as a divisor. + """Wrap expr in NULLIF(expr, 0) to make it safe as a divisor. - Idempotent — if ``expr`` is already ``NULLIF(..., 0)`` (or a numeric - literal), it's returned unchanged. Caller's responsibility to pass - only the ``right`` side of a Divide. + Idempotent: if expr is already NULLIF(_, 0) or a numeric literal, + returns it unchanged. Caller passes the RHS of a Divide. 
""" - # Already a literal — division by literal 0 is the author's intent - # and we should preserve it as-is (most likely the literal isn't 0). + # Numeric literals: x / 100 doesn't need wrapping. if isinstance(expr, ast.Number): return expr - # Already wrapped in NULLIF(x, 0) + # Already wrapped. if ( isinstance(expr, ast.Function) and expr.name.name.upper() == "NULLIF" @@ -61,12 +59,10 @@ def safe_denominator(expr: ast.Expression) -> ast.Expression: def wrap_divisions_in_nullif(expr: ast.Expression) -> ast.Expression: - """Walk ``expr`` and wrap the right-hand side of every Divide ``BinaryOp`` - in ``NULLIF(..., 0)`` so division-by-zero produces NULL instead of - NaN/Infinity/error. + """Walk expr and wrap the RHS of every Divide BinaryOp in NULLIF(_, 0) + so division-by-zero produces NULL instead of NaN/Infinity/error. - Returns ``expr`` (mutated in place where possible). Idempotent via - :func:`safe_denominator`. + Mutates and returns expr. Idempotent via :func:`safe_denominator`. """ for node in expr.find_all(ast.BinaryOp): if node.op != ast.BinaryOpKind.Divide: @@ -1123,11 +1119,10 @@ def _decompose( ) else: combiner_ast = decomposition.combine(components) - # Auto-wrap every Divide's RHS in NULLIF(_, 0). Decomposed AVG / - # variance / stddev / covariance all construct ``SUM(...) / - # SUM(count)`` patterns where the denominator can legitimately - # be 0; without NULLIF the result is NaN/Infinity/error - # depending on dialect. Idempotent. + # Decomposed AVG / variance / stddev / covariance all build + # SUM(...) / SUM(count)-style combiners where the denominator + # can legitimately be 0. Wrap to produce NULL rather than + # NaN/Infinity/error. 
combiner_ast = wrap_divisions_in_nullif(combiner_ast) return DecompositionResult(components, combiner_ast) diff --git a/datajunction-server/tests/sql/decompose_test.py b/datajunction-server/tests/sql/decompose_test.py index ce8dd8fb7..a12dfcc82 100644 --- a/datajunction-server/tests/sql/decompose_test.py +++ b/datajunction-server/tests/sql/decompose_test.py @@ -2169,7 +2169,7 @@ async def test_extract_nested_derived_with_avg( def test_safe_denominator_idempotent(): - """``safe_denominator`` wraps once and only once.""" + """safe_denominator wraps once and only once.""" inner = ast.Function(ast.Name("SUM"), args=[ast.Column(ast.Name("n"))]) wrapped = safe_denominator(inner) assert_sql_equal(f"SELECT {wrapped}", "SELECT NULLIF(SUM(n), 0)") @@ -2179,7 +2179,7 @@ def test_safe_denominator_idempotent(): def test_safe_denominator_preserves_literal(): - """Numeric literals (e.g. ``x / 100``) don't need NULLIF.""" + """Numeric literals (e.g. x / 100) don't need NULLIF.""" lit = ast.Number(value=100) assert safe_denominator(lit) is lit @@ -2202,8 +2202,8 @@ async def test_avg_decomposition_wraps_denominator( create_metric, ): """The AVG combiner is auto-wrapped so 0/0 produces NULL instead of - NaN/Infinity/error. Together with the unit-level ``test_average`` - above, locks in the auto-wrap behaviour end-to-end via decomposition. + NaN/Infinity/error. Together with the unit-level test_average above, + locks in the auto-wrap behaviour end-to-end via decomposition. """ metric_rev = await create_metric("SELECT AVG(amount) FROM parent_node") extractor = MetricComponentExtractor(metric_rev.id) @@ -2220,8 +2220,8 @@ async def test_user_authored_division_not_double_wrapped( session: AsyncSession, create_metric, ): - """If the author already wrote ``NULLIF(denominator, 0)``, the - auto-wrap is idempotent — no double-wrap. + """If the author already wrote NULLIF(denominator, 0), the auto-wrap + is idempotent — no double-wrap. 
""" metric_rev = await create_metric( "SELECT SUM(x) / NULLIF(SUM(y), 0) FROM parent_node", From 3ee3845c937a17ef40cb456dfcc088443a204c16 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 17:07:10 -0700 Subject: [PATCH 4/6] Fix SQL gen tests after NULLIF changes --- datajunction-server/tests/api/cubes_test.py | 24 +++++++++---------- .../tests/api/deployments_test.py | 6 ++--- datajunction-server/tests/api/djql_test.py | 2 +- datajunction-server/tests/api/metrics_test.py | 4 ++-- datajunction-server/tests/api/nodes_test.py | 21 ++++++++++------ datajunction-server/tests/api/sql_test.py | 22 ++++++++--------- .../tests/construction/build_v3/cte_test.py | 2 +- .../construction/build_v3/metrics_sql_test.py | 2 +- .../build_v3/preagg_substitution_test.py | 2 +- .../tests/construction/build_v3/types_test.py | 2 +- 10 files changed, 47 insertions(+), 40 deletions(-) diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 03d37c84a..941e88925 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -757,7 +757,7 @@ async def test_create_cube( default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name, default_DOT_municipality_dim.local_region default_DOT_municipality_dim_DOT_local_region, default_DOT_hard_hat_to_delete.hire_date default_DOT_hard_hat_to_delete_DOT_hire_date, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) + CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / NULLIF(count(*), 0) AS default_DOT_discounted_orders_rate, count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders, avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price, @@ -947,9 +947,9 @@ async def test_create_cube( COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, 
COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -1082,9 +1082,9 @@ async def test_cube_filters_merged_with_request_filters( COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / 
NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -1163,7 +1163,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 WHERE repair_orders_fact_0.state = 'AZ' @@ -1247,7 +1247,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.state, repair_orders_fact_0.company_name @@ -1296,7 +1296,7 @@ async def test_cube_filters_applied_in_v3_sql_via_cube_param( SELECT repair_orders_fact_0.state AS state, repair_orders_fact_0.company_name AS company_name, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) 
/ SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost FROM repair_orders_fact_0 WHERE repair_orders_fact_0.state = 'AZ' AND repair_orders_fact_0.company_name = 'Potts LLC' @@ -1398,9 +1398,9 @@ async def test_cube_only_no_metrics_no_dims(client_with_repairs_cube: AsyncClien COALESCE(repair_order_details_0.company_name, repair_orders_fact_0.company_name) AS company_name, COALESCE(repair_order_details_0.local_region, repair_orders_fact_0.local_region) AS local_region, COALESCE(repair_order_details_0.hire_date, repair_orders_fact_0.hire_date) AS hire_date, - CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / SUM(repair_orders_fact_0.count_c8e42e74) AS discounted_orders_rate, + CAST(SUM(repair_orders_fact_0.discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(repair_orders_fact_0.count_c8e42e74), 0) AS discounted_orders_rate, SUM(repair_orders_fact_0.repair_order_id_count_bd241964) AS num_repair_orders, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price, + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price, SUM(repair_orders_fact_0.total_repair_cost_sum_67874507) AS total_repair_cost, SUM(repair_orders_fact_0.price_discount_sum_e4ba5456) AS total_repair_order_discounts, SUM(repair_order_details_0.price_sum_252381cf) + SUM(repair_order_details_0.price_sum_252381cf) AS double_total_repair_cost @@ -3049,9 +3049,9 @@ async def test_cube_materialization_metadata( ], }, { - "derived_expression": "SELECT SUM(price_sum_935e7117) / SUM(price_count_935e7117) FROM " + "derived_expression": "SELECT SUM(price_sum_935e7117) / NULLIF(SUM(price_count_935e7117), 0) FROM " "default.repair_orders_fact", - 
"metric_expression": "SUM(price_sum_935e7117) / SUM(price_count_935e7117)", + "metric_expression": "SUM(price_sum_935e7117) / NULLIF(SUM(price_count_935e7117), 0)", "metric": { "name": "default.avg_repair_price", "version": mock.ANY, diff --git a/datajunction-server/tests/api/deployments_test.py b/datajunction-server/tests/api/deployments_test.py index bcdf68668..5fec6a5ec 100644 --- a/datajunction-server/tests/api/deployments_test.py +++ b/datajunction-server/tests/api/deployments_test.py @@ -1064,8 +1064,8 @@ def default_regional_repair_efficiency(): "Total Repair Amount in Region" is the total amount spent on repairs in a given region. "Total Repair Amount Nationwide" is the total amount spent on all repairs nationwide.""", query="""SELECT - (SUM(rm.completed_repairs) * 1.0 / SUM(rm.total_repairs_dispatched)) * - (SUM(rm.total_amount_in_region) * 1.0 / SUM(na.total_amount_nationwide)) * 100 + (SUM(rm.completed_repairs) * 1.0 / NULLIF(SUM(rm.total_repairs_dispatched), 0)) * + (SUM(rm.total_amount_in_region) * 1.0 / NULLIF(SUM(na.total_amount_nationwide), 0)) * 100 FROM ${prefix}default.regional_level_agg rm CROSS JOIN @@ -1127,7 +1127,7 @@ def default_discounted_orders_rate(): description="""Proportion of Discounted Orders""", query=""" SELECT - cast(sum(if(discount > 0.0, 1, 0)) as double) / count(*) + cast(sum(if(discount > 0.0, 1, 0)) as double) / NULLIF(count(*), 0) AS default_DOT_discounted_orders_rate FROM ${prefix}default.repair_orders_fact """, diff --git a/datajunction-server/tests/api/djql_test.py b/datajunction-server/tests/api/djql_test.py index e04eaf686..69fdc5453 100644 --- a/datajunction-server/tests/api/djql_test.py +++ b/datajunction-server/tests/api/djql_test.py @@ -483,7 +483,7 @@ async def test_get_djsql_with_orderby_and_limit( ) SELECT repair_orders_fact_0.country AS country, - SUM(repair_orders_fact_0.price_sum_HASH) / SUM(repair_orders_fact_0.price_count_HASH) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_HASH) / 
NULLIF(SUM(repair_orders_fact_0.price_count_HASH), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.country ORDER BY country DESC diff --git a/datajunction-server/tests/api/metrics_test.py b/datajunction-server/tests/api/metrics_test.py index d0961f189..e38ef243c 100644 --- a/datajunction-server/tests/api/metrics_test.py +++ b/datajunction-server/tests/api/metrics_test.py @@ -482,11 +482,11 @@ async def test_read_metrics(module__client_with_roads: AsyncClient) -> None: }, ] assert data["derived_query"] == ( - "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / SUM(count_c8e42e74) AS " + "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(count_c8e42e74), 0) AS " "default_DOT_discounted_orders_rate \n FROM default.repair_orders_fact" ) assert data["derived_expression"] == ( - "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / SUM(count_c8e42e74) " + "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / NULLIF(SUM(count_c8e42e74), 0) " "AS default_DOT_discounted_orders_rate" ) assert data["custom_metadata"] is None diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index 1b19fe29a..9ed22756c 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -4964,7 +4964,7 @@ async def test_node_column_lineage(self, client_with_roads: AsyncClient): "query": ( """ SELECT - cast(sum(if(discount > 0.0, 1, 0)) as double) / count(repair_order_id) + cast(sum(if(discount > 0.0, 1, 0)) as double) / NULLIF(count(repair_order_id), 0) FROM default.repair_order_details """ ), @@ -5750,7 +5750,10 @@ def test_decompose_expression(): res = decompose_expression( ast.Function(ast.Name("avg"), args=[ast.Column(ast.Name("orders"))]), ) - assert str(res[0]) == "sum(orders3845127662_sum) / count(orders3845127662_count)" + assert ( + str(res[0]) + == "sum(orders3845127662_sum) / NULLIF(count(orders3845127662_count), 0)" + ) assert [measure.alias_or_name.name for 
measure in res[1]] == [ "orders3845127662_sum", "orders3845127662_count", @@ -5765,7 +5768,8 @@ def test_decompose_expression(): ), ) assert ( - str(res[0]) == "sum(orders3845127662_sum) / count(orders3845127662_count) + 5.5" + str(res[0]) + == "sum(orders3845127662_sum) / NULLIF(count(orders3845127662_count), 0) + 5.5" ) assert [measure.alias_or_name.name for measure in res[1]] == [ "orders3845127662_sum", @@ -5793,8 +5797,8 @@ def test_decompose_expression(): ) assert ( str(res[0]) - == "max(sum(orders_a1170126662_sum) / count(orders_a1170126662_count) " - "+ sum(orders_b3703039740_sum) / count(orders_b3703039740_count))" + == "max(sum(orders_a1170126662_sum) / NULLIF(count(orders_a1170126662_count), 0) " + "+ sum(orders_b3703039740_sum) / NULLIF(count(orders_b3703039740_count), 0))" ) # Decompose `sum(max(orders))` @@ -5834,7 +5838,7 @@ def test_decompose_expression(): ) assert ( str(res[0]) - == "max(orders3845127662_max) + min(validations2970758927_min) / sum(total3257917790_sum)" + == "max(orders3845127662_max) + min(validations2970758927_min) / NULLIF(sum(total3257917790_sum), 0)" ) assert [measure.alias_or_name.name for measure in res[1]] == [ "orders3845127662_max", @@ -5864,7 +5868,10 @@ def test_decompose_expression(): op=ast.BinaryOpKind.Divide, ), ) - assert str(res[0]) == "sum(has_ordered2766370626_sum) / sum(total3257917790_sum)" + assert ( + str(res[0]) + == "sum(has_ordered2766370626_sum) / NULLIF(sum(total3257917790_sum), 0)" + ) assert [measure.alias_or_name.name for measure in res[1]] == [ "has_ordered2766370626_sum", "total3257917790_sum", diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 2814892bb..41a1fd543 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -1782,7 +1782,7 @@ async def test_metric_with_second_order_dimensions( ) SELECT repair_orders_fact_0.city AS city, - SUM(repair_orders_fact_0.price_sum_935e7117) / 
SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.city @@ -1861,7 +1861,7 @@ async def test_metric_with_nth_order_dimensions( SELECT repair_orders_fact_0.city AS city, repair_orders_fact_0.company_name AS company_name, - SUM(repair_orders_fact_0.price_sum_935e7117) / SUM(repair_orders_fact_0.price_count_935e7117) AS avg_repair_price + SUM(repair_orders_fact_0.price_sum_935e7117) / NULLIF(SUM(repair_orders_fact_0.price_count_935e7117), 0) AS avg_repair_price FROM repair_orders_fact_0 GROUP BY repair_orders_fact_0.city, repair_orders_fact_0.company_name @@ -2698,7 +2698,7 @@ async def test_get_sql_for_metrics2(client_with_examples: AsyncClient): default_DOT_hard_hat.state default_DOT_hard_hat_DOT_state, default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name, default_DOT_municipality_dim.local_region default_DOT_municipality_dim_DOT_local_region, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate, + CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / NULLIF(count(*), 0) AS default_DOT_discounted_orders_rate, count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat @@ -3346,7 +3346,7 @@ async def test_sql_structs(module__client_with_examples: AsyncClient): default_DOT_simple_agg.order_year default_DOT_simple_agg_DOT_order_year, default_DOT_simple_agg.order_month default_DOT_simple_agg_DOT_order_month, default_DOT_simple_agg.order_day default_DOT_simple_agg_DOT_order_day, - SUM(default_DOT_simple_agg.dispatch_delay_sum) / SUM(default_DOT_simple_agg.repair_orders_cnt) default_DOT_average_dispatch_delay + SUM(default_DOT_simple_agg.dispatch_delay_sum) / 
NULLIF(SUM(default_DOT_simple_agg.repair_orders_cnt), 0) default_DOT_average_dispatch_delay FROM default_DOT_simple_agg WHERE default_DOT_simple_agg.order_year = 2020 GROUP BY default_DOT_simple_agg.order_year, default_DOT_simple_agg.order_month, default_DOT_simple_agg.order_day @@ -4243,7 +4243,7 @@ async def test_role_path_dimensions_in_filters_single_hop( ) SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country = 'United States' GROUP BY user_dim_0.name_user_birth_country @@ -4312,7 +4312,7 @@ async def test_role_path_dimensions_in_filters_multi_hop_geographic( ) SELECT user_dim_0.continent_name_region_continent AS continent_name_region_continent, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.continent_name_region_continent = 'North America' GROUP BY user_dim_0.continent_name_region_continent @@ -4388,7 +4388,7 @@ async def test_role_path_dimensions_in_filters_multi_hop_temporal( ) SELECT user_dim_0.year_number_month_year AS year_number_month_year, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.year_number_month_year = 2024 GROUP BY user_dim_0.year_number_month_year @@ -4450,7 +4450,7 @@ async def test_role_path_dimensions_mixed_paths( SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, user_dim_0.name_user_residence_country AS name_user_residence_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + 
SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country = 'Canada' AND user_dim_0.name_user_residence_country = 'United States' GROUP BY user_dim_0.name_user_birth_country, user_dim_0.name_user_residence_country @@ -4548,7 +4548,7 @@ async def test_role_path_dimensions_mixed_hierarchies( SELECT user_dim_0.continent_name_region_continent AS continent_name_region_continent, user_dim_0.month_name_week_month AS month_name_week_month, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 GROUP BY user_dim_0.continent_name_region_continent, user_dim_0.month_name_week_month @@ -4744,7 +4744,7 @@ async def test_multiple_filters_same_role_path( GROUP BY t2.name ) SELECT user_dim_0.name_user_birth_country AS name_user_birth_country, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country IS NOT NULL GROUP BY user_dim_0.name_user_birth_country @@ -4849,7 +4849,7 @@ async def test_role_path_dimensions_performance_complex_query( user_dim_0.region_name_country_region AS region_name_country_region, user_dim_0.month_name_week_month AS month_name_week_month, user_dim_0.year_number_month_year AS year_number_month_year, - SUM(user_dim_0.age_sum_4ebaaaaa) / SUM(user_dim_0.age_count_4ebaaaaa) AS avg_user_age + SUM(user_dim_0.age_sum_4ebaaaaa) / NULLIF(SUM(user_dim_0.age_count_4ebaaaaa), 0) AS avg_user_age FROM user_dim_0 WHERE user_dim_0.name_user_birth_country IN ('Canada', 'United States', 'Mexico') AND user_dim_0.region_name_country_region = 'North America' AND user_dim_0.month_name_week_month IN ('January', 'February', 'March') AND 
user_dim_0.year_number_month_year >= 2020 GROUP BY user_dim_0.name_user_birth_country, user_dim_0.name_user_residence_country, user_dim_0.region_name_country_region, user_dim_0.month_name_week_month, user_dim_0.year_number_month_year diff --git a/datajunction-server/tests/construction/build_v3/cte_test.py b/datajunction-server/tests/construction/build_v3/cte_test.py index cc542ca47..71ebc11f8 100644 --- a/datajunction-server/tests/construction/build_v3/cte_test.py +++ b/datajunction-server/tests/construction/build_v3/cte_test.py @@ -264,7 +264,7 @@ def test_aggregate_window_not_modified(self): def test_weighted_cpm_pattern(self): """Test weighted CPM pattern with grand total weight.""" - # Weighted CPM = (revenue / impressions) * (impressions / SUM(impressions) OVER ()) + # Weighted CPM = (revenue / impressions) * (impressions / NULLIF(SUM(impressions) OVER (), 0)) query = parse( "SELECT " "(revenue / NULLIF(impressions / 1000.0, 0)) " diff --git a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py index b2c46f7a0..dfae474a6 100644 --- a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py +++ b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py @@ -146,7 +146,7 @@ async def test_multi_component_metric(self, client_with_build_v3): Test metrics SQL for a multi-component metric (AVG). AVG decomposes into SUM and COUNT, and the combiner expression - should be applied: SUM(x) / COUNT(x). + should be applied: SUM(x) / NULLIF(COUNT(x), 0).
""" response = await client_with_build_v3.get( "/sql/metrics/v3/", diff --git a/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py b/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py index 981e9d1e2..1849dcf8a 100644 --- a/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py +++ b/datajunction-server/tests/construction/build_v3/preagg_substitution_test.py @@ -235,7 +235,7 @@ async def test_derived_metric_no_preagg(self, client_with_build_v3): assert metrics_response.status_code == 200 metrics_data = metrics_response.json() - # avg_order_value = SUM(total_revenue) / COUNT(DISTINCT order_id) + # avg_order_value = SUM(total_revenue) / NULLIF(COUNT(DISTINCT order_id), 0) assert_sql_equal( metrics_data["sql"], """ diff --git a/datajunction-server/tests/construction/build_v3/types_test.py b/datajunction-server/tests/construction/build_v3/types_test.py index 5a32c2fbd..ab4cb8e8c 100644 --- a/datajunction-server/tests/construction/build_v3/types_test.py +++ b/datajunction-server/tests/construction/build_v3/types_test.py @@ -61,7 +61,7 @@ def test_is_fully_decomposable_with_limited(self): metric_node=metric_node, components=[component1, component2], aggregability=Aggregability.LIMITED, - combiner="SUM(x) / COUNT(DISTINCT y)", + combiner="SUM(x) / NULLIF(COUNT(DISTINCT y), 0)", derived_ast=derived_ast, ) From c18fefc363dedde7070056f61353cf4d3aa68cd0 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 18:47:13 -0700 Subject: [PATCH 5/6] Fix tests --- .../internal/materializations.py | 18 ++++++++++++-- .../datajunction_server/sql/decompose.py | 9 +++++++ datajunction-server/tests/api/cubes_test.py | 6 ++--- .../tests/api/deployments_test.py | 2 +- .../tests/api/graphql/find_nodes_test.py | 8 +++---- datajunction-server/tests/api/sql_test.py | 4 ++-- .../tests/sql/decompose_test.py | 24 +++++++++---------- 7 files changed, 47 insertions(+), 24 deletions(-) diff --git 
a/datajunction-server/datajunction_server/internal/materializations.py b/datajunction-server/datajunction_server/internal/materializations.py index 2650ae2e8..13b60c16c 100644 --- a/datajunction-server/datajunction_server/internal/materializations.py +++ b/datajunction-server/datajunction_server/internal/materializations.py @@ -428,8 +428,14 @@ def decompose_expression( args=[ast.Column(name=numerator_measure_name)], ), right=ast.Function( - ast.Name("count"), - args=[ast.Column(name=denominator_measure_name)], + ast.Name("NULLIF"), + args=[ + ast.Function( + ast.Name("count"), + args=[ast.Column(name=denominator_measure_name)], + ), + ast.Number(value=0), + ], ), op=ast.BinaryOpKind.Divide, ) @@ -455,6 +461,14 @@ def decompose_expression( if expr.op in acceptable_binary_ops: # pragma: no cover measures_combiner_left, measures_left = decompose_expression(expr.left) measures_combiner_right, measures_right = decompose_expression(expr.right) + if expr.op == ast.BinaryOpKind.Divide and not ( + isinstance(measures_combiner_right, ast.Function) + and measures_combiner_right.alias_or_name.name.lower() == "nullif" + ): + measures_combiner_right = ast.Function( + ast.Name("NULLIF"), + args=[measures_combiner_right, ast.Number(value=0)], + ) combiner = ast.BinaryOp( left=measures_combiner_left, right=measures_combiner_right, diff --git a/datajunction-server/datajunction_server/sql/decompose.py b/datajunction-server/datajunction_server/sql/decompose.py index 61297adaf..2a47e15e0 100644 --- a/datajunction-server/datajunction_server/sql/decompose.py +++ b/datajunction-server/datajunction_server/sql/decompose.py @@ -1038,6 +1038,15 @@ def _extract_base( components_tracker.add(comp.name) components.append(comp) + # Wrap user-authored divisions in the outer expression too. + # _decompose only wraps the small combiners it builds for AVG / + # variance / stddev / covariance; the user's top-level + # expression (e.g. 
CAST(SUM(...)) / COUNT(*)) is unchanged + # after sub-aggregation replacements and would otherwise emit + # bare divisions. + for proj in query_ast.select.projection: + wrap_divisions_in_nullif(cast(ast.Expression, proj)) + return components, query_ast def _substitute_metric_references( diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 941e88925..a1fee6a22 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -757,7 +757,7 @@ async def test_create_cube( default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name, default_DOT_municipality_dim.local_region default_DOT_municipality_dim_DOT_local_region, default_DOT_hard_hat_to_delete.hire_date default_DOT_hard_hat_to_delete_DOT_hire_date, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / NULLIF(count(*), 0) + CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate, count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders, avg(default_DOT_repair_orders_fact.price) default_DOT_avg_repair_price, @@ -3000,10 +3000,10 @@ async def test_cube_materialization_metadata( assert results["metrics"] == [ { "derived_expression": "SELECT CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / " - "SUM(count_c8e42e74) AS default_DOT_discounted_orders_rate FROM " + "NULLIF(SUM(count_c8e42e74), 0) AS default_DOT_discounted_orders_rate FROM " "default.repair_orders_fact", "metric_expression": "CAST(SUM(discount_sum_30b84e6c) AS DOUBLE) / " - "SUM(count_c8e42e74)", + "NULLIF(SUM(count_c8e42e74), 0)", "metric": { "name": "default.discounted_orders_rate", "version": mock.ANY, diff --git a/datajunction-server/tests/api/deployments_test.py b/datajunction-server/tests/api/deployments_test.py index 5fec6a5ec..65e0dafa2 100644 --- a/datajunction-server/tests/api/deployments_test.py +++ 
b/datajunction-server/tests/api/deployments_test.py @@ -1127,7 +1127,7 @@ def default_discounted_orders_rate(): description="""Proportion of Discounted Orders""", query=""" SELECT - cast(sum(if(discount > 0.0, 1, 0)) as double) / NULLIF(count(*), 0) + cast(sum(if(discount > 0.0, 1, 0)) as double) / count(*) AS default_DOT_discounted_orders_rate FROM ${prefix}default.repair_orders_fact """, diff --git a/datajunction-server/tests/api/graphql/find_nodes_test.py b/datajunction-server/tests/api/graphql/find_nodes_test.py index 7ac4f67d6..733381522 100644 --- a/datajunction-server/tests/api/graphql/find_nodes_test.py +++ b/datajunction-server/tests/api/graphql/find_nodes_test.py @@ -716,16 +716,16 @@ async def test_find_metric( }, ], "derivedQuery": "SELECT (SUM(completed_repairs_sum_8b112bf1) * 1.0 / " - "SUM(total_repairs_dispatched_sum_601dc4f1)) * " + "NULLIF(SUM(total_repairs_dispatched_sum_601dc4f1), 0)) * " "(SUM(total_amount_in_region_sum_3426ede4) * 1.0 / " - "SUM(na_DOT_total_amount_nationwide_sum_4ecb2318)) * 100 \n" + "NULLIF(SUM(na_DOT_total_amount_nationwide_sum_4ecb2318), 0)) * 100 \n" " FROM default.regional_level_agg CROSS JOIN " "default.national_level_agg na\n" "\n", "derivedExpression": "(SUM(completed_repairs_sum_8b112bf1) * 1.0 / " - "SUM(total_repairs_dispatched_sum_601dc4f1)) * " + "NULLIF(SUM(total_repairs_dispatched_sum_601dc4f1), 0)) * " "(SUM(total_amount_in_region_sum_3426ede4) * 1.0 / " - "SUM(na_DOT_total_amount_nationwide_sum_4ecb2318)) * 100", + "NULLIF(SUM(na_DOT_total_amount_nationwide_sum_4ecb2318), 0)) * 100", }, }, "name": "default.regional_repair_efficiency", diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 41a1fd543..6bc3d721b 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -2698,7 +2698,7 @@ async def test_get_sql_for_metrics2(client_with_examples: AsyncClient): default_DOT_hard_hat.state 
default_DOT_hard_hat_DOT_state, default_DOT_dispatcher.company_name default_DOT_dispatcher_DOT_company_name, default_DOT_municipality_dim.local_region default_DOT_municipality_dim_DOT_local_region, - CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / NULLIF(count(*), 0) AS default_DOT_discounted_orders_rate, + CAST(sum(if(default_DOT_repair_orders_fact.discount > 0.0, 1, 0)) AS DOUBLE) / count(*) AS default_DOT_discounted_orders_rate, count(default_DOT_repair_orders_fact.repair_order_id) default_DOT_num_repair_orders FROM default_DOT_repair_orders_fact INNER JOIN default_DOT_hard_hat @@ -3346,7 +3346,7 @@ async def test_sql_structs(module__client_with_examples: AsyncClient): default_DOT_simple_agg.order_year default_DOT_simple_agg_DOT_order_year, default_DOT_simple_agg.order_month default_DOT_simple_agg_DOT_order_month, default_DOT_simple_agg.order_day default_DOT_simple_agg_DOT_order_day, - SUM(default_DOT_simple_agg.dispatch_delay_sum) / NULLIF(SUM(default_DOT_simple_agg.repair_orders_cnt), 0) default_DOT_average_dispatch_delay + SUM(default_DOT_simple_agg.dispatch_delay_sum) / SUM(default_DOT_simple_agg.repair_orders_cnt) default_DOT_average_dispatch_delay FROM default_DOT_simple_agg WHERE default_DOT_simple_agg.order_year = 2020 GROUP BY default_DOT_simple_agg.order_year, default_DOT_simple_agg.order_month, default_DOT_simple_agg.order_day diff --git a/datajunction-server/tests/sql/decompose_test.py b/datajunction-server/tests/sql/decompose_test.py index a12dfcc82..e125d0bf5 100644 --- a/datajunction-server/tests/sql/decompose_test.py +++ b/datajunction-server/tests/sql/decompose_test.py @@ -383,7 +383,7 @@ async def test_rate(session: AsyncSession, create_metric): assert measures == expected_measures0 assert_sql_equal( str(derived_sql), - "SELECT SUM(clicks_sum_c45fd8cf) / SUM(impressions_sum_3be0a0e7) FROM parent_node", + "SELECT SUM(clicks_sum_c45fd8cf) / NULLIF(SUM(impressions_sum_3be0a0e7), 0) FROM parent_node", ) metric_rev2 = 
await create_metric( @@ -424,7 +424,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT CAST(CAST(SUM(clicks_sum_c45fd8cf) AS INT) AS DOUBLE) / " - "CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE) FROM parent_node", + "NULLIF(CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE), 0) FROM parent_node", ) metric_rev4 = await create_metric( @@ -452,7 +452,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT COALESCE(SUM(clicks_sum_c45fd8cf) / " - "SUM(impressions_sum_3be0a0e7), 0) FROM parent_node", + "NULLIF(SUM(impressions_sum_3be0a0e7), 0), 0) FROM parent_node", ) metric_rev5 = await create_metric( @@ -481,7 +481,7 @@ async def test_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT IF(SUM(clicks_sum_c45fd8cf) > 0, CAST(SUM(impressions_sum_3be0a0e7) AS DOUBLE)" - " / CAST(SUM(clicks_sum_c45fd8cf) AS DOUBLE), NULL) FROM parent_node", + " / NULLIF(CAST(SUM(clicks_sum_c45fd8cf) AS DOUBLE), 0), NULL) FROM parent_node", ) metric_rev6 = await create_metric( @@ -508,7 +508,7 @@ async def test_rate(session: AsyncSession, create_metric): assert measures == expected_measures assert_sql_equal( str(derived_sql), - "SELECT ln(SUM(clicks_sum_c45fd8cf) + 1) / SUM(views_sum_d8e39817) FROM parent_node", + "SELECT ln(SUM(clicks_sum_c45fd8cf) + 1) / NULLIF(SUM(views_sum_d8e39817), 0) FROM parent_node", ) @@ -570,7 +570,7 @@ async def test_fraction_with_if(session: AsyncSession, create_metric): str(derived_sql), "SELECT IF(SUM(action_sum_c9802ccb) > 0, " "CAST(SUM(action_two_sum_05d921a8) AS DOUBLE) / " - "CAST(SUM(action_sum_c9802ccb) AS DOUBLE), NULL) FROM parent_node", + "NULLIF(CAST(SUM(action_sum_c9802ccb) AS DOUBLE), 0), NULL) FROM parent_node", ) @@ -638,7 +638,7 @@ async def test_count_distinct_rate(session: AsyncSession, create_metric): assert_sql_equal( str(derived_sql), "SELECT COUNT( DISTINCT user_id_distinct_7f092f23) / " - 
"SUM(action_count_50d753fd) FROM parent_node", + "NULLIF(SUM(action_count_50d753fd), 0) FROM parent_node", ) @@ -831,7 +831,7 @@ async def test_count_if(session: AsyncSession, create_metric): assert measures == expected_measures assert_sql_equal( str(derived_sql), - "SELECT CAST(SUM(field_a_count_if_3979ffbd) AS FLOAT) / SUM(count_58ac32c5) " + "SELECT CAST(SUM(field_a_count_if_3979ffbd) AS FLOAT) / NULLIF(SUM(count_58ac32c5), 0) " "FROM parent_node", ) @@ -1129,7 +1129,7 @@ async def test_approx_count_distinct_rate(session: AsyncSession, create_metric): derived_str = str(derived_sql) assert_sql_equal( derived_str, - "SELECT CAST(hll_sketch_estimate(hll_union_agg(clicked_user_id_hll_f3824813)) AS DOUBLE) / CAST(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS DOUBLE) FROM parent_node", + "SELECT CAST(hll_sketch_estimate(hll_union_agg(clicked_user_id_hll_f3824813)) AS DOUBLE) / NULLIF(CAST(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS DOUBLE), 0) FROM parent_node", ) @@ -1196,21 +1196,21 @@ async def test_approx_count_distinct_combined_metrics_dialect_translation( # Verify Spark SQL structure - contains both SUM and HLL assert_sql_equal( spark_sql, - "SELECT SUM(revenue_sum_60e4d31f) / hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)) AS revenue_per_user FROM parent_node", + "SELECT SUM(revenue_sum_60e4d31f) / NULLIF(hll_sketch_estimate(hll_union_agg(user_id_hll_7f092f23)), 0) AS revenue_per_user FROM parent_node", ) # Translate to Druid - should preserve SUM but translate HLL druid_sql = to_sql(derived_sql, Dialect.DRUID) assert_sql_equal( druid_sql, - "SELECT SAFE_DIVIDE(SUM(revenue_sum_60e4d31f), hll_sketch_estimate(ds_hll(user_id_hll_7f092f23))) AS revenue_per_user FROM parent_node", + "SELECT SAFE_DIVIDE(SUM(revenue_sum_60e4d31f), NULLIF(hll_sketch_estimate(ds_hll(user_id_hll_7f092f23)), 0)) AS revenue_per_user FROM parent_node", ) # Translate to Trino trino_sql = to_sql(derived_sql, Dialect.TRINO) assert_sql_equal( trino_sql, - 
"SELECT SUM(revenue_sum_60e4d31f) / cardinality(merge(user_id_hll_7f092f23)) AS revenue_per_user FROM parent_node", + "SELECT SUM(revenue_sum_60e4d31f) / NULLIF(cardinality(merge(user_id_hll_7f092f23)), 0) AS revenue_per_user FROM parent_node", ) From 5f68c3f78931d5a3a0565d9a3679202f5aa327cd Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 18 May 2026 19:05:16 -0700 Subject: [PATCH 6/6] Fix client tests --- datajunction-clients/python/tests/examples.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datajunction-clients/python/tests/examples.py b/datajunction-clients/python/tests/examples.py index 682e59221..87197543f 100644 --- a/datajunction-clients/python/tests/examples.py +++ b/datajunction-clients/python/tests/examples.py @@ -1256,7 +1256,7 @@ "JOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_idLEFTOUTERJOIN" "default_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)SELECT" "repair_order_details_0.stateASstate,\tSUM(repair_order_details_0.price_sum_252381cf)" - "/SUM(repair_order_details_0.price_count_252381cf)ASavg_repair_priceFROM" + "/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)ASavg_repair_priceFROM" "repair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( **{ "id": "bd98d6be-e2d2-413e-94c7-96d9411ddee2", @@ -1399,7 +1399,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTCOALESCE(repair_order_details_0.city)AScity,\tSUM(repair_order_details_0." 
- "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city", submitted_query="...", @@ -1437,7 +1437,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTrepair_order_details_0.cityAScity,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city", submitted_query="...", @@ -1465,7 +1465,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTCOALESCE(repair_order_details_0.state)ASstate,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data", submitted_query="...", @@ -1481,7 +1481,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTrepair_order_details_0.stateASstate,\tSUM(repair_order_details_0." 
- "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data-no-coalesce", submitted_query="...", @@ -1498,7 +1498,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.city)" "SELECTrepair_order_details_0.cityAScity,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.city": QueryWithResults( id="v3-avg-repair-price-city-with-join-key", submitted_query="...", @@ -1526,7 +1526,7 @@ "LEFTOUTERJOINdefault_repair_ordert2ONt1.repair_order_id=t2.repair_order_id" "LEFTOUTERJOINdefault_hard_hatt3ONt2.hard_hat_id=t3.hard_hat_idGROUPBYt3.state)" "SELECTrepair_order_details_0.stateASstate,\tSUM(repair_order_details_0." - "price_sum_252381cf)/SUM(repair_order_details_0.price_count_252381cf)AS" + "price_sum_252381cf)/NULLIF(SUM(repair_order_details_0.price_count_252381cf),0)AS" "avg_repair_priceFROMrepair_order_details_0GROUPBYrepair_order_details_0.state": QueryWithResults( id="v3-avg-repair-price-state-no-data-with-join-key", submitted_query="...",