Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions datajunction-server/datajunction_server/api/djsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,11 @@ async def get_sql_for_djsql(
metrics=metrics,
dimensions=dimensions,
filters=filters,
orderby=orderby if orderby else None,
limit=limit,
dialect=dialect_enum,
)

# TODO: Apply ORDER BY and LIMIT to the generated SQL if needed
# The v3 builder doesn't currently support these directly

return TranslatedDJSQL(
sql=result.sql,
columns=[
Expand Down
81 changes: 81 additions & 0 deletions datajunction-server/tests/api/djql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from httpx import AsyncClient

from tests.construction.build_v3 import assert_sql_equal
from tests.sql.utils import assert_query_strings_equal, compare_query_strings


Expand Down Expand Up @@ -393,6 +394,86 @@ async def test_get_djsql_illegal_limit(
)


@pytest.mark.asyncio
async def test_get_djsql_with_orderby_and_limit(
module__client_with_roads: AsyncClient,
) -> None:
"""
Test that /djsql/ correctly applies ORDER BY and LIMIT to generated SQL.
"""
query = """
SELECT
default.avg_repair_price,
default.hard_hat.country
FROM metrics
GROUP BY default.hard_hat.country
ORDER BY default.hard_hat.country DESC
LIMIT 5
"""

response = await module__client_with_roads.get(
"/djsql/",
params={"query": query},
)
assert response.status_code == 200

data = response.json()
generated_sql = data["sql"]

# Verify SQL structure using assert_sql_equal
assert_sql_equal(
generated_sql,
"""
WITH
default_hard_hat AS (
SELECT hard_hat_id, country
FROM default.roads.hard_hats
),
default_repair_orders_fact AS (
SELECT
repair_orders.hard_hat_id,
repair_order_details.price
FROM default.roads.repair_orders repair_orders
JOIN default.roads.repair_order_details repair_order_details
ON repair_orders.repair_order_id = repair_order_details.repair_order_id
),
repair_orders_fact_0 AS (
SELECT
t2.country,
COUNT(t1.price) price_count_HASH,
SUM(t1.price) price_sum_HASH
FROM default_repair_orders_fact t1
INNER JOIN default_hard_hat t2 ON t1.hard_hat_id = t2.hard_hat_id
GROUP BY t2.country
)
SELECT
repair_orders_fact_0.country AS country,
SUM(repair_orders_fact_0.price_sum_HASH) / SUM(repair_orders_fact_0.price_count_HASH) AS avg_repair_price
FROM repair_orders_fact_0
GROUP BY repair_orders_fact_0.country
ORDER BY country DESC
LIMIT 5
""",
normalize_aliases=True,
)

# Verify columns are returned
assert data["columns"] == [
{
"name": "country",
"type": "string",
"semantic_name": "default.hard_hat.country",
"semantic_type": "dimension",
},
{
"name": "avg_repair_price",
"type": "double",
"semantic_name": "default.avg_repair_price",
"semantic_type": "metric",
},
]


@pytest.mark.asyncio
async def test_get_djsql_no_nodes(
module__client_with_roads: AsyncClient,
Expand Down
Loading