diff --git a/datajunction-server/datajunction_server/api/djsql.py b/datajunction-server/datajunction_server/api/djsql.py index 8271a8282..fa6d8d161 100644 --- a/datajunction-server/datajunction_server/api/djsql.py +++ b/datajunction-server/datajunction_server/api/djsql.py @@ -177,12 +177,11 @@ async def get_sql_for_djsql( metrics=metrics, dimensions=dimensions, filters=filters, + orderby=orderby if orderby else None, + limit=limit, dialect=dialect_enum, ) - # TODO: Apply ORDER BY and LIMIT to the generated SQL if needed - # The v3 builder doesn't currently support these directly - return TranslatedDJSQL( sql=result.sql, columns=[ diff --git a/datajunction-server/tests/api/djql_test.py b/datajunction-server/tests/api/djql_test.py index 1eb90977e..07b1c89e7 100644 --- a/datajunction-server/tests/api/djql_test.py +++ b/datajunction-server/tests/api/djql_test.py @@ -5,6 +5,7 @@ import pytest from httpx import AsyncClient +from tests.construction.build_v3 import assert_sql_equal from tests.sql.utils import assert_query_strings_equal, compare_query_strings @@ -393,6 +394,86 @@ async def test_get_djsql_illegal_limit( ) +@pytest.mark.asyncio +async def test_get_djsql_with_orderby_and_limit( + module__client_with_roads: AsyncClient, +) -> None: + """ + Test that /djsql/ correctly applies ORDER BY and LIMIT to generated SQL. + """ + query = """ + SELECT + default.avg_repair_price, + default.hard_hat.country + FROM metrics + GROUP BY default.hard_hat.country + ORDER BY default.hard_hat.country DESC + LIMIT 5 + """ + + response = await module__client_with_roads.get( + "/djsql/", + params={"query": query}, + ) + assert response.status_code == 200 + + data = response.json() + generated_sql = data["sql"] + + # Verify SQL structure using assert_sql_equal + assert_sql_equal( + generated_sql, + """ + WITH + default_hard_hat AS ( + SELECT hard_hat_id, country + FROM default.roads.hard_hats + ), + default_repair_orders_fact AS ( + SELECT + repair_orders.hard_hat_id, + repair_order_details.price + FROM default.roads.repair_orders repair_orders + JOIN default.roads.repair_order_details repair_order_details + ON repair_orders.repair_order_id = repair_order_details.repair_order_id + ), + repair_orders_fact_0 AS ( + SELECT + t2.country, + COUNT(t1.price) price_count_HASH, + SUM(t1.price) price_sum_HASH + FROM default_repair_orders_fact t1 + INNER JOIN default_hard_hat t2 ON t1.hard_hat_id = t2.hard_hat_id + GROUP BY t2.country + ) + SELECT + repair_orders_fact_0.country AS country, + SUM(repair_orders_fact_0.price_sum_HASH) / SUM(repair_orders_fact_0.price_count_HASH) AS avg_repair_price + FROM repair_orders_fact_0 + GROUP BY repair_orders_fact_0.country + ORDER BY country DESC + LIMIT 5 + """, + normalize_aliases=True, + ) + + # Verify columns are returned + assert data["columns"] == [ + { + "name": "country", + "type": "string", + "semantic_name": "default.hard_hat.country", + "semantic_type": "dimension", + }, + { + "name": "avg_repair_price", + "type": "double", + "semantic_name": "default.avg_repair_price", + "semantic_type": "metric", + }, + ] + + @pytest.mark.asyncio async def test_get_djsql_no_nodes( module__client_with_roads: AsyncClient,