diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 81c6bd1f0886..7c2534fa449d 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -144,6 +144,10 @@ def get_default_instructions( - execute_sql: Execute SQL queries and get results (requires database_id) - save_sql_query: Save a SQL query to Saved Queries list - open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql +- list_saved_queries: List saved SQL queries with filtering and search (1-based pagination) +- get_saved_query_info: Get saved query details by ID or UUID +- list_queries: List SQL query history with filtering and search (most recent first) +- get_query_info: Get SQL query history details by ID Schema Discovery: - get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters) @@ -636,6 +640,14 @@ def create_mcp_app( from superset.mcp_service.explore.tool import ( # noqa: F401, E402 generate_explore_link, ) +from superset.mcp_service.query.tool import ( # noqa: F401, E402 + get_query_info, + list_queries, +) +from superset.mcp_service.saved_query.tool import ( # noqa: F401, E402 + get_saved_query_info, + list_saved_queries, +) from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 execute_sql, open_sql_lab_with_context, diff --git a/superset/mcp_service/query/__init__.py b/superset/mcp_service/query/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/query/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/query/schemas.py b/superset/mcp_service/query/schemas.py new file mode 100644 index 000000000000..07c4bdcab38d --- /dev/null +++ b/superset/mcp_service/query/schemas.py @@ -0,0 +1,290 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for query history-related responses +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import MAX_PAGE_SIZE +from superset.mcp_service.privacy import filter_user_directory_fields +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_QUERY_COLUMNS = ["id", "sql", "status", "start_time", "database_id", "schema"] +SORTABLE_QUERY_COLUMNS = [ + "id", + "start_time", + "end_time", + "status", + "database_id", +] +ALL_QUERY_COLUMNS = [ + "id", + "sql", + "status", + "start_time", + "end_time", + "rows", + "database_id", + "schema", + "tab_name", + "error_message", + "client_id", + "limit", + "progress", + "changed_on", +] + +DEFAULT_QUERY_PAGE_SIZE = 25 + + +class QueryFilter(ColumnOperator): + """ + Filter object for query history listing. + col: The column to filter on. Must be one of the allowed filter fields. + opr: The operator to use. Must be one of the supported operators. + value: The value to filter by (type depends on col and opr). + """ + + col: Literal["status", "database_id", "schema"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field( + ..., + description="Operator to use.", + ) + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by (type depends on col and opr)" + ) + + +class QueryInfo(BaseModel): + id: int | None = Field(None, description="Query ID") + sql: str | None = Field(None, description="SQL query text") + status: str | None = Field(None, description="Query execution status") + start_time: float | None = Field( + None, description="Query start time (seconds since epoch)" + ) + end_time: float | None = Field( + None, description="Query end time (seconds since epoch)" + ) + rows: int | None = Field(None, description="Number of rows returned or affected") + database_id: int | None = Field(None, description="Database connection ID") + schema: str | None = Field(None, description="Database schema name") + tab_name: str | None = Field(None, description="SQL Lab tab name") + error_message: str | None = Field(None, description="Error message if query failed") + client_id: str | None = Field(None, description="Client-assigned query identifier") + limit: int | None = Field(None, description="Row limit applied to the query") + progress: int | None = Field(None, description="Query execution progress (0-100)") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + data = filter_user_directory_fields(serializer(self)) + + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + + return data + + +class QueryList(BaseModel): + queries: List[QueryInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field( + default_factory=list, + description="Requested columns for the response", + ) + columns_loaded: List[str] = Field( + default_factory=list, + description="Columns that were actually loaded for each query", + ) + columns_available: List[str] = Field( + default_factory=list, + description="All columns available for selection via select_columns parameter", + ) + sortable_columns: List[str] = Field( + default_factory=list, + description="Columns that can be used with order_column parameter", + ) + filters_applied: List[QueryFilter] = Field( + default_factory=list, + description="List of advanced filter dicts applied to the query.", + ) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListQueriesRequest(BaseModel): + """Request schema for list_queries.""" + + filters: Annotated[ + List[QueryFilter], + Field( + default_factory=list, + description="List of filter objects (column, operator, value). Each " + "filter is an object with 'col', 'opr', and 'value' " + "properties. Cannot be used together with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="List of columns to select. Defaults to common columns if not " + "specified.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search string to match against query fields. " + "Cannot be used together with 'filters'.", + ), + ] + order_column: Annotated[ + str | None, + Field(default=None, description="Column to order results by"), + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field( + default="desc", + description="Direction to order results ('asc' or 'desc')", + ), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number for pagination (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_QUERY_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[QueryFilter]: + """Accept both JSON string and list of objects.""" + return parse_json_or_model_list(v, QueryFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + """Accept JSON array, list, or comma-separated string.""" + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListQueriesRequest": + """Prevent using both search and filters simultaneously.""" + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' parameters simultaneously. " + "Use either 'search' for text-based searching across multiple fields, " + "or 'filters' for precise column-based filtering, but not both." + ) + return self + + +class QueryError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "QueryError": + """Create a standardized QueryError with timestamp.""" + from datetime import datetime, timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetQueryInfoRequest(BaseModel): + """Request schema for get_query_info with support for numeric ID only.""" + + identifier: Annotated[ + int, + Field(description="Query ID (numeric)"), + ] + + +def serialize_query_object(query: Any) -> QueryInfo | None: + if not query: + return None + + return QueryInfo( + id=getattr(query, "id", None), + sql=getattr(query, "sql", None), + status=getattr(query, "status", None), + start_time=getattr(query, "start_time", None), + end_time=getattr(query, "end_time", None), + rows=getattr(query, "rows", None), + database_id=getattr(query, "database_id", None), + schema=getattr(query, "schema", None), + tab_name=getattr(query, "tab_name", None), + error_message=getattr(query, "error_message", None), + client_id=getattr(query, "client_id", None), + limit=getattr(query, "limit", None), + progress=getattr(query, "progress", None), + changed_on=getattr(query, "changed_on", None), + ) diff --git a/superset/mcp_service/query/tool/__init__.py b/superset/mcp_service/query/tool/__init__.py new file mode 100644 index 000000000000..3e6edcbda474 --- /dev/null +++ b/superset/mcp_service/query/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_query_info import get_query_info +from .list_queries import list_queries + +__all__ = [ + "list_queries", + "get_query_info", +] diff --git a/superset/mcp_service/query/tool/get_query_info.py b/superset/mcp_service/query/tool/get_query_info.py new file mode 100644 index 000000000000..dc94a947d6cd --- /dev/null +++ b/superset/mcp_service/query/tool/get_query_info.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get query info FastMCP tool + +This module contains the FastMCP tool for getting detailed information +about a specific SQL query from the query history. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.query.schemas import ( + GetQueryInfoRequest, + QueryError, + QueryInfo, + serialize_query_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Query", + annotations=ToolAnnotations( + title="Get query info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_query_info( + request: GetQueryInfoRequest, ctx: Context +) -> QueryInfo | QueryError: + """Get SQL query history details by ID. + + Returns query details including SQL text, execution status, timing, + row count, and any error messages. + + IMPORTANT FOR LLM CLIENTS: + - Use numeric ID (e.g., 123) + - To find a query ID, use the list_queries tool first + + Example usage: + ```json + { + "identifier": 123 + } + ``` + """ + await ctx.info( + "Retrieving query information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.daos.query import QueryDAO + + with event_logger.log_context(action="mcp.get_query_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=QueryDAO, + output_schema=QueryInfo, + error_schema=QueryError, + serializer=serialize_query_object, + supports_slug=False, + logger=logger, + ) + + result = get_tool.run_tool(request.identifier) + + if isinstance(result, QueryInfo): + await ctx.info( + "Query information retrieved successfully: " + "query_id=%s, status=%s, database_id=%s" + % ( + result.id, + result.status, + result.database_id, + ) + ) + else: + await ctx.warning( + "Query retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Query information retrieval failed: identifier=%s, error=%s, " + "error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return QueryError( + error=f"Failed to get query info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/query/tool/list_queries.py b/superset/mcp_service/query/tool/list_queries.py new file mode 100644 index 000000000000..ae621de2cb64 --- /dev/null +++ b/superset/mcp_service/query/tool/list_queries.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List queries FastMCP tool + +This module contains the FastMCP tool for listing SQL query history +with filtering, search, and pagination. +""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.query.schemas import ( + ALL_QUERY_COLUMNS, + DEFAULT_QUERY_COLUMNS, + ListQueriesRequest, + QueryError, + QueryFilter, + QueryInfo, + QueryList, + serialize_query_object, + SORTABLE_QUERY_COLUMNS, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_QUERIES_REQUEST = ListQueriesRequest() + + +@tool( + tags=["core"], + class_permission_name="Query", + annotations=ToolAnnotations( + title="List queries", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_queries( + request: ListQueriesRequest | None = None, + ctx: Context | None = None, +) -> QueryList | QueryError: + """List SQL query history with filtering and search. + + Returns recent queries executed by the current user (or all queries for + admins), including SQL text, status, timing, and database information. + Results are ordered by start_time descending (most recent first) by default. + + Sortable columns for order_column: id, start_time, end_time, status, + database_id + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_queries") + + request = request or _DEFAULT_LIST_QUERIES_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing queries: page=%s, page_size=%s, search=%s" + % ( + request.page, + request.page_size, + request.search, + ) + ) + await ctx.debug( + "Query listing parameters: filters=%s, order_column=%s, " + "order_direction=%s, select_columns=%s" + % ( + request.filters, + request.order_column, + request.order_direction, + request.select_columns, + ) + ) + + try: + from superset.daos.query import QueryDAO + + def _serialize_query(obj: object, cols: list[str] | None) -> QueryInfo | None: + return serialize_query_object(obj) + + list_tool = ModelListCore( + dao_class=QueryDAO, + output_schema=QueryInfo, + item_serializer=_serialize_query, + filter_type=QueryFilter, + default_columns=DEFAULT_QUERY_COLUMNS, + search_columns=["tab_name", "sql"], + list_field_name="queries", + output_list_schema=QueryList, + all_columns=ALL_QUERY_COLUMNS, + sortable_columns=SORTABLE_QUERY_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_queries.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column or "start_time", + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Queries listed successfully: count=%s, total_count=%s, total_pages=%s" + % ( + len(result.queries) if hasattr(result, "queries") else 0, + getattr(result, "total_count", None), + getattr(result, "total_pages", None), + ) + ) + + columns_to_filter = result.columns_requested + await ctx.debug( + "Applying field filtering via serialization context: columns=%s" + % (columns_to_filter,) + ) + with event_logger.log_context(action="mcp.list_queries.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "Query listing failed: page=%s, page_size=%s, error=%s, error_type=%s" + % ( + request.page, + request.page_size, + str(e), + type(e).__name__, + ) + ) + raise diff --git a/superset/mcp_service/saved_query/__init__.py b/superset/mcp_service/saved_query/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/saved_query/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/saved_query/schemas.py b/superset/mcp_service/saved_query/schemas.py new file mode 100644 index 000000000000..c55298637e89 --- /dev/null +++ b/superset/mcp_service/saved_query/schemas.py @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for saved query-related responses +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.privacy import filter_user_directory_fields +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_SAVED_QUERY_COLUMNS = ["id", "label", "db_id", "schema", "uuid"] +SORTABLE_SAVED_QUERY_COLUMNS = [ + "id", + "label", + "db_id", + "schema", + "changed_on", + "created_on", +] +ALL_SAVED_QUERY_COLUMNS = [ + "id", + "label", + "db_id", + "schema", + "uuid", + "sql", + "description", + "changed_on", + "created_on", +] + + +class SavedQueryFilter(ColumnOperator): + """ + Filter object for saved query listing. + col: The column to filter on. Must be one of the allowed filter fields. + opr: The operator to use. Must be one of the supported operators. + value: The value to filter by (type depends on col and opr). + """ + + col: Literal["label", "db_id", "schema"] = Field( + ..., + description="Column to filter on.", + ) + opr: ColumnOperatorEnum = Field( + ..., + description="Operator to use.", + ) + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by (type depends on col and opr)" + ) + + +class SavedQueryInfo(BaseModel): + id: int | None = Field(None, description="Saved query ID") + uuid: str | None = Field(None, description="Saved query UUID") + label: str | None = Field(None, description="Saved query label/name") + sql: str | None = Field(None, description="SQL query text") + db_id: int | None = Field(None, description="Database connection ID") + schema: str | None = Field(None, description="Database schema name") + description: str | None = Field(None, description="User-provided description") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + model_config = ConfigDict( + from_attributes=True, + ser_json_timedelta="iso8601", + populate_by_name=True, + ) + + @model_serializer(mode="wrap") + def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]: + data = filter_user_directory_fields(serializer(self)) + + if info.context and isinstance(info.context, dict): + select_columns = info.context.get("select_columns") + if select_columns: + requested_fields = set(select_columns) + return {k: v for k, v in data.items() if k in requested_fields} + + return data + + +class SavedQueryList(BaseModel): + saved_queries: List[SavedQueryInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field( + default_factory=list, + description="Requested columns for the response", + ) + columns_loaded: List[str] = Field( + default_factory=list, + description="Columns that were actually loaded for each saved query", + ) + columns_available: List[str] = Field( + default_factory=list, + description="All columns available for selection via select_columns parameter", + ) + sortable_columns: List[str] = Field( + default_factory=list, + description="Columns that can be used with order_column parameter", + ) + filters_applied: List[SavedQueryFilter] = Field( + default_factory=list, + description="List of advanced filter dicts applied to the query.", + ) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListSavedQueriesRequest(BaseModel): + """Request schema for list_saved_queries.""" + + filters: Annotated[ + List[SavedQueryFilter], + Field( + default_factory=list, + description="List of filter objects (column, operator, value). Each " + "filter is an object with 'col', 'opr', and 'value' " + "properties. Cannot be used together with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="List of columns to select. Defaults to common columns if not " + "specified.", + ), + ] + search: Annotated[ + str | None, + Field( + default=None, + description="Text search string to match against saved query fields. " + "Cannot be used together with 'filters'.", + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field( + default="desc", description="Direction to order results ('asc' or 'desc')" + ), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number for pagination (1-based)"), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Number of items per page (max {MAX_PAGE_SIZE})", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[SavedQueryFilter]: + """Accept both JSON string and list of objects.""" + return parse_json_or_model_list(v, SavedQueryFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + """Accept JSON array, list, or comma-separated string.""" + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListSavedQueriesRequest": + """Prevent using both search and filters simultaneously.""" + if self.search and self.filters: + raise ValueError( + "Cannot use both 'search' and 'filters' parameters simultaneously. " + "Use either 'search' for text-based searching across multiple fields, " + "or 'filters' for precise column-based filtering, but not both." + ) + return self + + +class SavedQueryError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "SavedQueryError": + """Create a standardized SavedQueryError with timestamp.""" + from datetime import datetime, timezone + + return cls( + error=error, error_type=error_type, timestamp=datetime.now(timezone.utc) + ) + + +class GetSavedQueryInfoRequest(BaseModel): + """Request schema for get_saved_query_info with support for ID or UUID.""" + + identifier: Annotated[ + int | str, + Field(description="Saved query identifier - can be numeric ID or UUID string"), + ] + + +def serialize_saved_query_object(saved_query: Any) -> SavedQueryInfo | None: + if not saved_query: + return None + + return SavedQueryInfo( + id=getattr(saved_query, "id", None), + uuid=str(getattr(saved_query, "uuid", "")) + if getattr(saved_query, "uuid", None) + else None, + label=getattr(saved_query, "label", None), + sql=getattr(saved_query, "sql", None), + db_id=getattr(saved_query, "db_id", None), + schema=getattr(saved_query, "schema", None), + description=getattr(saved_query, "description", None), + changed_on=getattr(saved_query, "changed_on", None), + created_on=getattr(saved_query, "created_on", None), + ) diff --git a/superset/mcp_service/saved_query/tool/__init__.py b/superset/mcp_service/saved_query/tool/__init__.py new file mode 100644 index 000000000000..af366fd53122 --- /dev/null +++ b/superset/mcp_service/saved_query/tool/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_saved_query_info import get_saved_query_info +from .list_saved_queries import list_saved_queries + +__all__ = [ + "list_saved_queries", + "get_saved_query_info", +] diff --git a/superset/mcp_service/saved_query/tool/get_saved_query_info.py b/superset/mcp_service/saved_query/tool/get_saved_query_info.py new file mode 100644 index 000000000000..9b3a1be74b22 --- /dev/null +++ b/superset/mcp_service/saved_query/tool/get_saved_query_info.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Get saved query info FastMCP tool + +This module contains the FastMCP tool for getting detailed information +about a specific saved SQL query. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelGetInfoCore +from superset.mcp_service.saved_query.schemas import ( + GetSavedQueryInfoRequest, + SavedQueryError, + SavedQueryInfo, + serialize_saved_query_object, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="SavedQuery", + annotations=ToolAnnotations( + title="Get saved query info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_saved_query_info( + request: GetSavedQueryInfoRequest, ctx: Context +) -> SavedQueryInfo | SavedQueryError: + """Get saved query details by ID or UUID. + + Returns the full saved query including SQL text, label, database, + schema, and timestamps. + + IMPORTANT FOR LLM CLIENTS: + - Use numeric ID (e.g., 42) or UUID string (e.g., "a1b2c3d4-...") + - To find a saved query ID, use the list_saved_queries tool first + + Example usage: + ```json + { + "identifier": 42 + } + ``` + + Or with UUID: + ```json + { + "identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab" + } + ``` + """ + await ctx.info( + "Retrieving saved query information: identifier=%s" % (request.identifier,) + ) + + try: + from superset.daos.query import SavedQueryDAO + + with event_logger.log_context(action="mcp.get_saved_query_info.lookup"): + get_tool = ModelGetInfoCore( + dao_class=SavedQueryDAO, + output_schema=SavedQueryInfo, + error_schema=SavedQueryError, + serializer=serialize_saved_query_object, + supports_slug=False, + logger=logger, + ) + + result = get_tool.run_tool(request.identifier) + + if isinstance(result, SavedQueryInfo): + await ctx.info( + "Saved query information retrieved successfully: " + "saved_query_id=%s, label=%s, db_id=%s" + % ( + result.id, + result.label, + result.db_id, + ) + ) + else: + await ctx.warning( + "Saved query retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Saved query information retrieval failed: identifier=%s, error=%s, " + "error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return SavedQueryError( + error=f"Failed to get saved query info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/saved_query/tool/list_saved_queries.py b/superset/mcp_service/saved_query/tool/list_saved_queries.py new file mode 100644 index 000000000000..2e26bf2ce18f --- /dev/null +++ b/superset/mcp_service/saved_query/tool/list_saved_queries.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List saved queries FastMCP tool + +This module contains the FastMCP tool for listing saved SQL queries +with filtering, search, and pagination. +""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.mcp_core import ModelListCore +from superset.mcp_service.saved_query.schemas import ( + ALL_SAVED_QUERY_COLUMNS, + DEFAULT_SAVED_QUERY_COLUMNS, + ListSavedQueriesRequest, + SavedQueryError, + SavedQueryFilter, + SavedQueryInfo, + SavedQueryList, + serialize_saved_query_object, + SORTABLE_SAVED_QUERY_COLUMNS, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_LIST_SAVED_QUERIES_REQUEST = ListSavedQueriesRequest() + + +@tool( + tags=["core"], + class_permission_name="SavedQuery", + annotations=ToolAnnotations( + title="List saved queries", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_saved_queries( + request: ListSavedQueriesRequest | None = None, + ctx: Context | None = None, +) -> SavedQueryList | SavedQueryError: + """List saved SQL queries with filtering and search. + + Returns saved queries owned by the current user, including label, SQL, + database ID, and schema. + + Sortable columns for order_column: id, label, db_id, schema, + changed_on, created_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_saved_queries") + + request = request or _DEFAULT_LIST_SAVED_QUERIES_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing saved queries: page=%s, page_size=%s, search=%s" + % ( + request.page, + request.page_size, + request.search, + ) + ) + await ctx.debug( + "Saved query listing parameters: filters=%s, order_column=%s, " + "order_direction=%s, select_columns=%s" + % ( + request.filters, + request.order_column, + request.order_direction, + request.select_columns, + ) + ) + + try: + from superset.daos.query import SavedQueryDAO + + def _serialize_saved_query( + obj: object, cols: list[str] | None + ) -> SavedQueryInfo | None: + return serialize_saved_query_object(obj) + + list_tool = ModelListCore( + dao_class=SavedQueryDAO, + output_schema=SavedQueryInfo, + item_serializer=_serialize_saved_query, + filter_type=SavedQueryFilter, + default_columns=DEFAULT_SAVED_QUERY_COLUMNS, + search_columns=["label", "description", "sql"], + list_field_name="saved_queries", + output_list_schema=SavedQueryList, + all_columns=ALL_SAVED_QUERY_COLUMNS, + sortable_columns=SORTABLE_SAVED_QUERY_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_saved_queries.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Saved queries listed successfully: count=%s, total_count=%s, " + "total_pages=%s" + % ( + len(result.saved_queries) if hasattr(result, "saved_queries") else 0, + getattr(result, "total_count", None), + getattr(result, "total_pages", None), + ) + ) + + columns_to_filter = result.columns_requested + await ctx.debug( + "Applying field filtering via serialization context: columns=%s" + % (columns_to_filter,) + ) + with event_logger.log_context(action="mcp.list_saved_queries.serialization"): + return result.model_dump( + mode="json", + context={"select_columns": columns_to_filter}, + ) + + except Exception as e: + await ctx.error( + "Saved query listing failed: page=%s, page_size=%s, error=%s, " + "error_type=%s" + % ( + request.page, + request.page_size, + str(e), + type(e).__name__, + ) + ) + raise diff --git a/tests/unit_tests/mcp_service/query/__init__.py b/tests/unit_tests/mcp_service/query/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/query/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/query/tool/__init__.py b/tests/unit_tests/mcp_service/query/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/query/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/query/tool/test_query_tools.py b/tests/unit_tests/mcp_service/query/tool/test_query_tools.py new file mode 100644 index 000000000000..8e12d109a540 --- /dev/null +++ b/tests/unit_tests/mcp_service/query/tool/test_query_tools.py @@ -0,0 +1,271 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.query.schemas import ( + ListQueriesRequest, + QueryFilter, +) +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TestQueryFilterSchema: + """Tests for QueryFilter schema — filterable columns.""" + + def test_invalid_filter_column_rejected(self): + """Columns not in the Literal set must be rejected.""" + with pytest.raises(ValidationError): + QueryFilter(col="not_a_real_column", opr="eq", value="test") + + def test_user_id_is_rejected_as_filter_column(self): + """user_id is an internal field and should not be a filter column.""" + with pytest.raises(ValidationError): + QueryFilter(col="user_id", opr="eq", value=1) + + def test_valid_status_filter_accepted(self): + """status is a valid filter column.""" + f = QueryFilter(col="status", opr="eq", value="success") + assert f.col == "status" + + def test_valid_database_id_filter_accepted(self): + """database_id is a valid filter column.""" + f = QueryFilter(col="database_id", opr="eq", value=1) + assert f.col == "database_id" + + def test_valid_schema_filter_accepted(self): + """schema is a valid filter column.""" + f = QueryFilter(col="schema", opr="eq", value="public") + assert f.col == "schema" + + +def create_mock_query( + query_id: int = 1, + sql: str = "SELECT * FROM table", + status: str = "success", + start_time: float = 1700000000.0, + end_time: float = 1700000001.0, + rows: int = 100, + database_id: int = 1, + schema: str = "public", + tab_name: str = "SQL Lab 1", + error_message: str | None = None, + client_id: str = "abc123", +) -> MagicMock: + """Factory function to create mock query objects with sensible defaults.""" + query = MagicMock() + query.id = query_id + query.sql = sql + query.status = status + query.start_time = start_time + query.end_time = end_time + query.rows = rows + query.database_id = database_id + query.schema = schema + query.tab_name = tab_name + query.error_message = error_message + query.client_id = client_id + query.limit = 1000 + query.progress = 100 + query.changed_on = None + return query + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + from unittest.mock import Mock, patch + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_basic(mock_list, mcp_server): + """Test basic query listing functionality.""" + query = create_mock_query() + query._mapping = { + "id": query.id, + "sql": query.sql, + "status": query.status, + "start_time": query.start_time, + "database_id": query.database_id, + "schema": query.schema, + } + mock_list.return_value = ([query], 1) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] is not None + assert len(data["queries"]) == 1 + assert data["queries"][0]["id"] == 1 + assert data["queries"][0]["status"] == "success" + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_with_status_filter(mock_list, mcp_server): + """Test query listing with status filter.""" + query = create_mock_query(status="failed", error_message="Syntax error") + query._mapping = { + "id": query.id, + "sql": query.sql, + "status": query.status, + "error_message": query.error_message, + } + mock_list.return_value = ([query], 1) + async with Client(mcp_server) as client: + request = ListQueriesRequest( + page=1, + page_size=10, + filters=[ + {"col": "status", "opr": "eq", "value": "failed"}, + ], + ) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] is not None + assert len(data["queries"]) == 1 + assert data["queries"][0]["status"] == "failed" + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_default_page_size(mock_list, mcp_server): + """Test that default page size is 25 for query history.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + result = await client.call_tool("list_queries", {}) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["page_size"] == 25 + + +def test_list_queries_request_rejects_both_search_and_filters(): + """Cannot use search and filters simultaneously.""" + with pytest.raises(ValidationError): + ListQueriesRequest( + search="test", + filters=[{"col": "status", "opr": "eq", "value": "success"}], + ) + + +@patch("superset.daos.query.QueryDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_query_info_basic(mock_find, mcp_server): + """Test basic get query info functionality.""" + query = create_mock_query() + mock_find.return_value = query + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_query_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["status"] == "success" + assert data["database_id"] == 1 + + +@patch("superset.daos.query.QueryDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_query_info_not_found(mock_find, mcp_server): + """Test get query info when query does not exist.""" + mock_find.return_value = None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_query_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found" + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_empty(mock_list, mcp_server): + """Test query listing returns empty list when no results.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["queries"] == [] + assert data["count"] == 0 + assert data["total_count"] == 0 + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_pagination_info(mock_list, mcp_server): + """Test that pagination info is correctly returned.""" + queries = [create_mock_query(query_id=i) for i in range(1, 4)] + for q in queries: + q._mapping = {"id": q.id, "sql": q.sql, "status": q.status} + mock_list.return_value = (queries, 100) + async with Client(mcp_server) as client: + request = ListQueriesRequest(page=1, page_size=3) + result = await client.call_tool( + "list_queries", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["total_count"] == 100 + assert data["page_size"] == 3 + assert data["has_next"] is True + assert data["has_previous"] is False + + +@patch("superset.daos.query.QueryDAO.list") +@pytest.mark.asyncio +async def test_list_queries_default_order_is_start_time_desc(mock_list, mcp_server): + """Test that default ordering is start_time descending.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + result = await client.call_tool("list_queries", {}) + assert result.content is not None + mock_list.assert_called_once() + call_kwargs = mock_list.call_args + assert call_kwargs.kwargs.get("order_column") == "start_time" + assert call_kwargs.kwargs.get("order_direction") == "desc" diff --git a/tests/unit_tests/mcp_service/saved_query/__init__.py b/tests/unit_tests/mcp_service/saved_query/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/saved_query/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/saved_query/tool/__init__.py b/tests/unit_tests/mcp_service/saved_query/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/saved_query/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/saved_query/tool/test_saved_query_tools.py b/tests/unit_tests/mcp_service/saved_query/tool/test_saved_query_tools.py new file mode 100644 index 000000000000..da35d8b7607a --- /dev/null +++ b/tests/unit_tests/mcp_service/saved_query/tool/test_saved_query_tools.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.app import mcp +from superset.mcp_service.saved_query.schemas import ( + ListSavedQueriesRequest, + SavedQueryFilter, +) +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TestSavedQueryFilterSchema: + """Tests for SavedQueryFilter schema — filterable columns.""" + + def test_invalid_filter_column_rejected(self): + """Columns not in the Literal set must be rejected.""" + with pytest.raises(ValidationError): + SavedQueryFilter(col="not_a_real_column", opr="eq", value="test") + + def test_user_id_is_rejected_as_filter_column(self): + """user_id is an internal field and should not be a filter column.""" + with pytest.raises(ValidationError): + SavedQueryFilter(col="user_id", opr="eq", value=1) + + def test_valid_label_filter_accepted(self): + """label is a valid filter column.""" + f = SavedQueryFilter(col="label", opr="eq", value="my query") + assert f.col == "label" + + def test_valid_db_id_filter_accepted(self): + """db_id is a valid filter column.""" + f = SavedQueryFilter(col="db_id", opr="eq", value=1) + assert f.col == "db_id" + + def test_valid_schema_filter_accepted(self): + """schema is a valid filter column.""" + f = SavedQueryFilter(col="schema", opr="eq", value="public") + assert f.col == "schema" + + +def create_mock_saved_query( + saved_query_id: int = 1, + label: str = "My Query", + sql: str = "SELECT 1", + db_id: int = 1, + schema: str = "public", + description: str = "Test query", + uuid: str = "test-uuid-1234", +) -> MagicMock: + """Factory function to create mock saved query objects with sensible defaults.""" + saved_query = MagicMock() + saved_query.id = saved_query_id + saved_query.label = label + saved_query.sql = sql + saved_query.db_id = db_id + saved_query.schema = schema + saved_query.description = description + saved_query.uuid = uuid + saved_query.changed_on = None + saved_query.created_on = None + return saved_query + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + from unittest.mock import Mock, patch + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +@patch("superset.daos.query.SavedQueryDAO.list") +@pytest.mark.asyncio +async def test_list_saved_queries_basic(mock_list, mcp_server): + """Test basic saved query listing functionality.""" + saved_query = create_mock_saved_query() + saved_query._mapping = { + "id": saved_query.id, + "label": saved_query.label, + "db_id": saved_query.db_id, + "schema": saved_query.schema, + "uuid": saved_query.uuid, + } + mock_list.return_value = ([saved_query], 1) + async with Client(mcp_server) as client: + request = ListSavedQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_saved_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["saved_queries"] is not None + assert len(data["saved_queries"]) == 1 + assert data["saved_queries"][0]["id"] == 1 + assert data["saved_queries"][0]["label"] == "My Query" + + +@patch("superset.daos.query.SavedQueryDAO.list") +@pytest.mark.asyncio +async def test_list_saved_queries_with_search(mock_list, mcp_server): + """Test saved query listing with search functionality.""" + saved_query = create_mock_saved_query(label="Production Query") + saved_query._mapping = { + "id": saved_query.id, + "label": saved_query.label, + } + mock_list.return_value = ([saved_query], 1) + async with Client(mcp_server) as client: + request = ListSavedQueriesRequest(page=1, page_size=10, search="Production") + result = await client.call_tool( + "list_saved_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["saved_queries"] is not None + assert len(data["saved_queries"]) == 1 + assert data["saved_queries"][0]["label"] == "Production Query" + + +@patch("superset.daos.query.SavedQueryDAO.list") +@pytest.mark.asyncio +async def test_list_saved_queries_with_filters(mock_list, mcp_server): + """Test saved query listing with filters.""" + saved_query = create_mock_saved_query(db_id=2) + saved_query._mapping = { + "id": saved_query.id, + "label": saved_query.label, + "db_id": saved_query.db_id, + } + mock_list.return_value = ([saved_query], 1) + async with Client(mcp_server) as client: + request = ListSavedQueriesRequest( + page=1, + page_size=10, + filters=[ + {"col": "db_id", "opr": "eq", "value": 2}, + ], + ) + result = await client.call_tool( + "list_saved_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["saved_queries"] is not None + assert len(data["saved_queries"]) == 1 + + +def test_list_saved_queries_request_rejects_both_search_and_filters(): + """Cannot use search and filters simultaneously.""" + with pytest.raises(ValidationError): + ListSavedQueriesRequest( + search="test", + filters=[{"col": "label", "opr": "eq", "value": "test"}], + ) + + +@patch("superset.daos.query.SavedQueryDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_saved_query_info_basic(mock_find, mcp_server): + """Test basic get saved query info functionality.""" + saved_query = create_mock_saved_query() + mock_find.return_value = saved_query + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_saved_query_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["label"] == "My Query" + assert data["sql"] == "SELECT 1" + assert data["db_id"] == 1 + + +@patch("superset.daos.query.SavedQueryDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_saved_query_info_not_found(mock_find, mcp_server): + """Test get saved query info when saved query does not exist.""" + mock_find.return_value = None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_saved_query_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found" + + +@patch("superset.daos.query.SavedQueryDAO.list") +@pytest.mark.asyncio +async def test_list_saved_queries_empty(mock_list, mcp_server): + """Test saved query listing returns empty list when no results.""" + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: + request = ListSavedQueriesRequest(page=1, page_size=10) + result = await client.call_tool( + "list_saved_queries", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["saved_queries"] == [] + assert data["count"] == 0 + assert data["total_count"] == 0 + + +@patch("superset.daos.query.SavedQueryDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_saved_query_info_by_uuid(mock_find, mcp_server): + """Test get saved query info by UUID string.""" + saved_query = create_mock_saved_query(uuid="a1b2c3d4-5678-90ab-cdef-1234567890ab") + mock_find.return_value = saved_query + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_saved_query_info", + {"request": {"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"}}, + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["uuid"] == "a1b2c3d4-5678-90ab-cdef-1234567890ab" + + +@patch("superset.daos.query.SavedQueryDAO.list") +@pytest.mark.asyncio +async def test_list_saved_queries_pagination_info(mock_list, mcp_server): + """Test that pagination info is correctly returned.""" + saved_queries = [create_mock_saved_query(saved_query_id=i) for i in range(1, 4)] + for sq in saved_queries: + sq._mapping = {"id": sq.id, "label": sq.label} + mock_list.return_value = (saved_queries, 25) + async with Client(mcp_server) as client: + request = ListSavedQueriesRequest(page=1, page_size=3) + result = await client.call_tool( + "list_saved_queries", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["total_count"] == 25 + assert data["page_size"] == 3 + assert data["has_next"] is True + assert data["has_previous"] is False