Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datajunction-clients/python/datajunction/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def plan(
metrics: List[str],
dimensions: Optional[List[str]] = None,
filters: Optional[List[str]] = None,
cube: Optional[str] = None,
dialect: Optional[str] = None,
use_materialized: bool = True,
include_temporal_filters: bool = False,
Expand All @@ -221,6 +222,8 @@ def plan(
metrics: List of metric names to include
dimensions: List of dimensions to group by
filters: List of filter expressions
cube: Optional cube node name. When provided, the cube's stored
filters are automatically prepended to the query filters.
dialect: SQL dialect (e.g., 'spark', 'trino'). Defaults to engine dialect.
use_materialized: Whether to use materialized tables when available
include_temporal_filters: Whether to include temporal partition filters.
Expand All @@ -236,6 +239,8 @@ def plan(
"use_materialized": use_materialized,
"include_temporal_filters": include_temporal_filters,
}
if cube is not None:
params["cube"] = cube
if lookback_window is not None:
params["lookback_window"] = lookback_window
effective_dialect = dialect or self.engine_name
Expand Down
5 changes: 5 additions & 0 deletions datajunction-clients/python/datajunction/mcp/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ async def get_query_plan(
metrics: List[str],
dimensions: Optional[List[str]] = None,
filters: Optional[List[str]] = None,
cube: Optional[str] = None,
dialect: Optional[str] = None,
use_materialized: bool = True,
include_temporal_filters: bool = False,
Expand All @@ -694,6 +695,8 @@ async def get_query_plan(
metrics: List of metric node names to analyze
dimensions: Optional list of dimensions to group by
filters: Optional list of SQL filter conditions
cube: Optional cube node name. When provided, the cube's stored filters
are automatically prepended to the query filters.
dialect: Optional SQL dialect (e.g., 'spark', 'trino', 'postgres')
use_materialized: Whether to use materialized tables when available (default: True)
include_temporal_filters: Whether to include temporal partition filters (default: False)
Expand All @@ -713,6 +716,8 @@ async def get_query_plan(
"use_materialized": use_materialized,
"include_temporal_filters": include_temporal_filters,
}
if cube:
params["cube"] = cube
if dialect:
params["dialect"] = dialect
if lookback_window:
Expand Down
81 changes: 81 additions & 0 deletions datajunction-clients/python/tests/mcp/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,87 @@ async def test_get_query_plan_with_dialect_and_lookback():
assert params.get("include_temporal_filters") is True


@pytest.mark.asyncio
async def test_get_query_plan_with_cube():
"""Test get_query_plan passes cube param when provided"""
mock_response_json = {
"dialect": "spark",
"requested_dimensions": [],
"grain_groups": [],
"metric_formulas": [],
}

mock_http_response = MagicMock()
mock_http_response.status_code = 200
mock_http_response.json.return_value = mock_response_json
mock_http_response.raise_for_status = MagicMock()

with (
patch.object(tools, "get_client") as mock_get_client,
patch("httpx.AsyncClient") as mock_client_class,
):
mock_client = AsyncMock()
mock_client._ensure_token = AsyncMock()
mock_client.settings = MagicMock(
dj_api_url="http://localhost:8000",
request_timeout=30.0,
)
mock_client._get_headers = MagicMock(return_value={})
mock_get_client.return_value = mock_client

mock_http_client = AsyncMock()
mock_http_client.get.return_value = mock_http_response
mock_client_class.return_value.__aenter__.return_value = mock_http_client

await tools.get_query_plan(
metrics=["finance.revenue"],
cube="default.my_cube",
)

call_kwargs = mock_http_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {})
assert params.get("cube") == "default.my_cube"


@pytest.mark.asyncio
async def test_get_query_plan_no_cube_not_in_params():
"""Test get_query_plan omits cube param when not provided"""
mock_response_json = {
"dialect": "spark",
"requested_dimensions": [],
"grain_groups": [],
"metric_formulas": [],
}

mock_http_response = MagicMock()
mock_http_response.status_code = 200
mock_http_response.json.return_value = mock_response_json
mock_http_response.raise_for_status = MagicMock()

with (
patch.object(tools, "get_client") as mock_get_client,
patch("httpx.AsyncClient") as mock_client_class,
):
mock_client = AsyncMock()
mock_client._ensure_token = AsyncMock()
mock_client.settings = MagicMock(
dj_api_url="http://localhost:8000",
request_timeout=30.0,
)
mock_client._get_headers = MagicMock(return_value={})
mock_get_client.return_value = mock_client

mock_http_client = AsyncMock()
mock_http_client.get.return_value = mock_http_response
mock_client_class.return_value.__aenter__.return_value = mock_http_client

await tools.get_query_plan(metrics=["finance.revenue"])

call_kwargs = mock_http_client.get.call_args
params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {})
assert "cube" not in params


@pytest.mark.asyncio
async def test_get_query_plan_no_dialect_or_lookback_not_in_params():
"""Test get_query_plan omits dialect and lookback_window when not provided"""
Expand Down
14 changes: 14 additions & 0 deletions datajunction-clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,20 @@ def test_plan(self, client):
)
assert "message" in result or "detail" in result

def test_plan_with_cube(self, client):
"""
Test query execution plan retrieval with cube parameter
"""
result = client.plan(
metrics=["default.num_repair_orders"],
dimensions=["default.municipality_dim.local_region"],
cube="default.cube_two",
)
assert isinstance(result, dict)
assert "grain_groups" in result
assert "metric_formulas" in result
assert "requested_dimensions" in result

def test_plan_with_temporal_filters(self, client):
"""
Test query execution plan retrieval with temporal filter parameters
Expand Down
Loading