From 74d436cc2c670e5aab6b34e83bcc6a14978147c0 Mon Sep 17 00:00:00 2001 From: Jess Hester Date: Wed, 8 Apr 2026 13:54:14 -0500 Subject: [PATCH] Add a cube parameter to the python client plan API. --- .../python/datajunction/client.py | 5 ++ .../python/datajunction/mcp/tools.py | 5 ++ .../python/tests/mcp/test_tools.py | 81 +++++++++++++++++++ .../python/tests/test_client.py | 14 ++++ 4 files changed, 105 insertions(+) diff --git a/datajunction-clients/python/datajunction/client.py b/datajunction-clients/python/datajunction/client.py index a550cffcc..4897f589f 100644 --- a/datajunction-clients/python/datajunction/client.py +++ b/datajunction-clients/python/datajunction/client.py @@ -201,6 +201,7 @@ def plan( metrics: List[str], dimensions: Optional[List[str]] = None, filters: Optional[List[str]] = None, + cube: Optional[str] = None, dialect: Optional[str] = None, use_materialized: bool = True, include_temporal_filters: bool = False, @@ -221,6 +222,8 @@ def plan( metrics: List of metric names to include dimensions: List of dimensions to group by filters: List of filter expressions + cube: Optional cube node name. When provided, the cube's stored + filters are automatically prepended to the query filters. dialect: SQL dialect (e.g., 'spark', 'trino'). Defaults to engine dialect. use_materialized: Whether to use materialized tables when available include_temporal_filters: Whether to include temporal partition filters. @@ -236,6 +239,8 @@ def plan( "use_materialized": use_materialized, "include_temporal_filters": include_temporal_filters, } + if cube is not None: + params["cube"] = cube if lookback_window is not None: params["lookback_window"] = lookback_window effective_dialect = dialect or self.engine_name diff --git a/datajunction-clients/python/datajunction/mcp/tools.py b/datajunction-clients/python/datajunction/mcp/tools.py index dafe5f8af..0b78e9e23 100644 --- a/datajunction-clients/python/datajunction/mcp/tools.py +++ b/datajunction-clients/python/datajunction/mcp/tools.py @@ -678,6 +678,7 @@ async def get_query_plan( metrics: List[str], dimensions: Optional[List[str]] = None, filters: Optional[List[str]] = None, + cube: Optional[str] = None, dialect: Optional[str] = None, use_materialized: bool = True, include_temporal_filters: bool = False, @@ -694,6 +695,8 @@ async def get_query_plan( metrics: List of metric node names to analyze dimensions: Optional list of dimensions to group by filters: Optional list of SQL filter conditions + cube: Optional cube node name. When provided, the cube's stored filters + are automatically prepended to the query filters. dialect: Optional SQL dialect (e.g., 'spark', 'trino', 'postgres') use_materialized: Whether to use materialized tables when available (default: True) include_temporal_filters: Whether to include temporal partition filters (default: False) @@ -713,6 +716,8 @@ async def get_query_plan( "use_materialized": use_materialized, "include_temporal_filters": include_temporal_filters, } + if cube: + params["cube"] = cube if dialect: params["dialect"] = dialect if lookback_window: diff --git a/datajunction-clients/python/tests/mcp/test_tools.py b/datajunction-clients/python/tests/mcp/test_tools.py index cd601d975..7895e6b28 100644 --- a/datajunction-clients/python/tests/mcp/test_tools.py +++ b/datajunction-clients/python/tests/mcp/test_tools.py @@ -2848,6 +2848,87 @@ async def test_get_query_plan_with_dialect_and_lookback(): assert params.get("include_temporal_filters") is True +@pytest.mark.asyncio +async def test_get_query_plan_with_cube(): + """Test get_query_plan passes cube param when provided""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + await tools.get_query_plan( + metrics=["finance.revenue"], + cube="default.my_cube", + ) + + call_kwargs = mock_http_client.get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert params.get("cube") == "default.my_cube" + + +@pytest.mark.asyncio +async def test_get_query_plan_no_cube_not_in_params(): + """Test get_query_plan omits cube param when not provided""" + mock_response_json = { + "dialect": "spark", + "requested_dimensions": [], + "grain_groups": [], + "metric_formulas": [], + } + + mock_http_response = MagicMock() + mock_http_response.status_code = 200 + mock_http_response.json.return_value = mock_response_json + mock_http_response.raise_for_status = MagicMock() + + with ( + patch.object(tools, "get_client") as mock_get_client, + patch("httpx.AsyncClient") as mock_client_class, + ): + mock_client = AsyncMock() + mock_client._ensure_token = AsyncMock() + mock_client.settings = MagicMock( + dj_api_url="http://localhost:8000", + request_timeout=30.0, + ) + mock_client._get_headers = MagicMock(return_value={}) + mock_get_client.return_value = mock_client + + mock_http_client = AsyncMock() + mock_http_client.get.return_value = mock_http_response + mock_client_class.return_value.__aenter__.return_value = mock_http_client + + await tools.get_query_plan(metrics=["finance.revenue"]) + + call_kwargs = mock_http_client.get.call_args + params = call_kwargs.kwargs.get("params") or call_kwargs[1].get("params", {}) + assert "cube" not in params + + @pytest.mark.asyncio async def test_get_query_plan_no_dialect_or_lookback_not_in_params(): """Test get_query_plan omits dialect and lookback_window when not provided""" diff --git a/datajunction-clients/python/tests/test_client.py b/datajunction-clients/python/tests/test_client.py index 97f22cfb0..1e5849f64 100644 --- a/datajunction-clients/python/tests/test_client.py +++ b/datajunction-clients/python/tests/test_client.py @@ -552,6 +552,20 @@ def test_plan(self, client): ) assert "message" in result or "detail" in result + def test_plan_with_cube(self, client): + """ + Test query execution plan retrieval with cube parameter + """ + result = client.plan( + metrics=["default.num_repair_orders"], + dimensions=["default.municipality_dim.local_region"], + cube="default.cube_two", + ) + assert isinstance(result, dict) + assert "grain_groups" in result + assert "metric_formulas" in result + assert "requested_dimensions" in result + def test_plan_with_temporal_filters(self, client): """ Test query execution plan retrieval with temporal filter parameters