Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 65 additions & 16 deletions superset/mcp_service/chart/chart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
generation that can be used by both generate_chart and generate_explore_link tools.
"""

import hashlib
import logging
from dataclasses import dataclass
from typing import Any, Dict
Expand Down Expand Up @@ -488,10 +489,28 @@ def create_metric_object(col: ColumnRef) -> Dict[str, Any] | str:

For saved metrics, returns the metric name as a plain string which
Superset's query engine resolves via its metrics_by_name lookup.
For ad-hoc metrics, returns a SIMPLE expression dict.
For custom SQL metrics, returns a SQL adhoc dict (expressionType="SQL").
For ad-hoc column metrics, returns a SIMPLE expression dict.
"""
if col.sql_expression:
return {
"aggregate": None,
"column": None,
"expressionType": "SQL",
"sqlExpression": col.sql_expression,
"label": col.label,
"optionName": (
"metric_sql_"
+ hashlib.md5(
col.sql_expression.encode("utf-8"), usedforsecurity=False
).hexdigest()[:8]
),
"hasCustomLabel": True,
"datasourceWarning": False,
}

if col.saved_metric:
return col.name
return col.name # type: ignore[return-value]

# Ensure aggregate is valid - default to SUM if not specified or invalid
valid_aggregates = {
Expand Down Expand Up @@ -684,17 +703,20 @@ def _add_xy_limits(form_data: Dict[str, Any], config: XYChartConfig) -> None:
form_data["series_limit"] = config.series_limit


def map_xy_config(
def map_xy_config( # noqa: C901
config: XYChartConfig, dataset_id: int | str | None = None
) -> Dict[str, Any]:
"""Map XY chart config to form_data with defensive validation."""
# Early validation to prevent empty charts
if not config.y:
raise ValueError("XY chart must have at least one Y-axis metric")

# Resolve x-axis default: use dataset's main_dttm_col when x is omitted
# Resolve x-axis default: use dataset's main_dttm_col when x is omitted.
config = _resolve_default_x_axis(config, dataset_id)
assert config.x is not None # _resolve_default_x_axis guarantees x is set

# ``_resolve_default_x_axis`` guarantees x is set.
if config.x is None or config.x.name is None:
raise ValueError("XY chart requires an x-axis with a resolvable column name")

# Check if x-axis column is truly temporal (based on actual SQL type)
x_is_temporal = is_column_truly_temporal(config.x.name, dataset_id)
Expand All @@ -719,7 +741,8 @@ def map_xy_config(
# Convert Y columns to metrics with validation
metrics = []
for col in config.y:
if not col.name.strip(): # Validate column name is not empty
# SQL metrics carry sql_expression instead of name.
if not col.sql_expression and not (col.name and col.name.strip()):
raise ValueError("Y-axis column name cannot be empty")
metrics.append(create_metric_object(col))

Expand Down Expand Up @@ -972,7 +995,9 @@ def map_mixed_timeseries_config(
if not config.y_secondary:
raise ValueError("Mixed timeseries must have at least one secondary metric")

# Check if x-axis column is truly temporal
# x rejects sql_expression at validation, so name is set.
if config.x.name is None:
raise ValueError("Mixed timeseries chart requires an x-axis column name")
x_is_temporal = is_column_truly_temporal(config.x.name, dataset_id)

form_data: Dict[str, Any] = {
Expand Down Expand Up @@ -1052,7 +1077,9 @@ def _humanize_column(col: ColumnRef) -> str:
"""Return a human-readable label for a column reference."""
if col.label:
return col.label
name = col.name.replace("_", " ").title()
if col.sql_expression:
return col.sql_expression
name = (col.name or "").replace("_", " ").title()
if col.saved_metric:
return name
if col.aggregate:
Expand Down Expand Up @@ -1144,21 +1171,32 @@ def _xy_chart_context(config: XYChartConfig) -> str | None:
def _pie_chart_what(config: PieChartConfig) -> str:
"""Build the 'what' portion for a pie chart name."""
dim = config.dimension.name
metric_label = config.metric.label or config.metric.name
metric_label = (
config.metric.label or config.metric.name or config.metric.sql_expression
)
return f"{dim} by {metric_label}"


def _pivot_table_what(config: PivotTableChartConfig) -> str:
"""Build the 'what' portion for a pivot table chart name."""
row_names = ", ".join(r.name for r in config.rows)
# Pivot rows reject sql_expression at validation, so name is set.
row_names = ", ".join(r.name or "" for r in config.rows)
return f"Pivot Table \u2013 {row_names}"


def _mixed_timeseries_what(config: MixedTimeseriesChartConfig) -> str:
"""Build the 'what' portion for a mixed timeseries chart name."""
primary = config.y[0].label or config.y[0].name if config.y else "primary"
primary = (
(config.y[0].label or config.y[0].name or config.y[0].sql_expression)
if config.y
else "primary"
)
secondary = (
config.y_secondary[0].label or config.y_secondary[0].name
(
config.y_secondary[0].label
or config.y_secondary[0].name
or config.y_secondary[0].sql_expression
)
if config.y_secondary
else "secondary"
)
Expand All @@ -1172,10 +1210,16 @@ def _handlebars_chart_what(config: HandlebarsChartConfig) -> str:
``generate_chart_name``'s ``\u2013`` context separator.
"""
if config.query_mode == "raw" and config.columns:
cols = ", ".join(col.name for col in config.columns[:3])
# Raw columns reject sql_expression at validation, so col.name is set.
cols = ", ".join(col.name or "" for col in config.columns[:3])
return f"Handlebars ({cols})"
elif config.metrics:
metrics = ", ".join(col.name for col in config.metrics[:3])
# Prefer raw column name for back-compat with existing chart names;
# SQL metrics fall back to label, then the expression itself.
metrics = ", ".join(
col.name or col.label or col.sql_expression or ""
for col in config.metrics[:3]
)
return f"Handlebars ({metrics})"
return "Handlebars Chart"

Expand All @@ -1188,10 +1232,12 @@ def _big_number_chart_what(config: BigNumberChartConfig) -> str:
"""
if config.metric.label:
metric_label = config.metric.label
elif config.metric.sql_expression:
metric_label = config.metric.sql_expression
elif config.metric.aggregate:
metric_label = f"{config.metric.aggregate}({config.metric.name})"
else:
metric_label = config.metric.name
metric_label = config.metric.name or ""
if config.show_trendline:
return f"Big Number ({metric_label}, trendline)"
return f"Big Number ({metric_label})"
Expand Down Expand Up @@ -1390,7 +1436,10 @@ def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics:
if hasattr(config, "x") and config.x:
columns.append(config.x.name)
if hasattr(config, "y") and config.y:
columns.extend([col.name for col in config.y])
# SQL metrics have no name; fall back to label or the expression.
columns.extend(
[col.name or col.label or col.sql_expression for col in config.y]
)

if columns:
ellipsis = "..." if len(columns) > 3 else ""
Expand Down
Loading
Loading