diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 10f19aa00f23a..da7c5cf56981b 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -322,15 +322,11 @@ def _convert_measure(
     ) -> proto.Aggregate.AggregateFunction:
         exp, fun = m
         measure = proto.Aggregate.AggregateFunction()
-        measure.function.name = fun  # type: ignore[attr-defined]
+        measure.name = fun
         if type(exp) is str:
-            measure.function.arguments.append(  # type: ignore[attr-defined]
-                self.unresolved_attr(exp)
-            )
+            measure.arguments.append(self.unresolved_attr(exp))
         else:
-            measure.function.arguments.append(  # type: ignore[attr-defined]
-                cast(Expression, exp).to_plan(session)
-            )
+            measure.arguments.append(cast(Expression, exp).to_plan(session))
         return measure
 
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
@@ -339,13 +335,11 @@ def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
 
         agg = proto.Relation()
 
         agg.aggregate.input.CopyFrom(self._child.plan(session))
-        agg.aggregate.measures.extend(  # type: ignore[attr-defined]
+        agg.aggregate.result_expressions.extend(
             list(map(lambda x: self._convert_measure(x, session), self.measures))
         )
-        gs = proto.Aggregate.GroupingSet()  # type: ignore[attr-defined]
-        gs.aggregate_expressions.extend(groupings)
-        agg.aggregate.grouping_sets.append(gs)  # type: ignore[attr-defined]
+        agg.aggregate.grouping_expressions.extend(groupings)
         return agg
 
     def print(self, indent: int = 0) -> str: