[SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
### What changes were proposed in this pull request?
1. Remove the redundant typing stub `typing/__init__.pyi`.
2. Rename `ColumnOrString` to `ColumnOrName`, matching PySpark.

### Why are the changes needed?
1. There are currently two typing files, `_typing.py` and `typing/__init__.pyi`, and they are imported from different modules, which is confusing.
2. The two definitions of `LiteralType` differ, and the old one in `_typing.py` was never used.
3. Both `ColumnOrString` and `ColumnOrName` are currently in use.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing unit tests.

Closes #38757 from zhengruifeng/connect_typing.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Nov 24, 2022
1 parent 957b0bc commit 381dd79
Showing 8 changed files with 66 additions and 71 deletions.
41 changes: 37 additions & 4 deletions python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union
from datetime import date, time, datetime

PrimitiveType = Union[str, int, bool, float]
LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
import sys

if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol

from typing import Union, Optional
import datetime
import decimal

from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
from pyspark.sql.connect.function_builder import UserDefinedFunction

ExpressionOrString = Union[Expression, str]

ColumnOrName = Union[Column, str]

PrimitiveType = Union[bool, float, int, str]

OptionalPrimitiveType = Optional[PrimitiveType]

LiteralType = PrimitiveType

DecimalLiteral = decimal.Decimal

DateTimeLiteral = Union[datetime.datetime, datetime.date]


class FunctionBuilderCallable(Protocol):
def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression:
...


class UserDefinedFunctionCallable(Protocol):
def __call__(self, *_: ColumnOrName) -> UserDefinedFunction:
...
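
To illustrate the structural-typing pattern that the new `FunctionBuilderCallable`/`UserDefinedFunctionCallable` aliases rely on, here is a minimal, self-contained sketch; the `Column` stub, `NamesCallable` protocol, and `make_concat` helper are hypothetical stand-ins, not code from this patch.

```python
import sys
from typing import Union

if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class Column:
    """Hypothetical stand-in for pyspark.sql.connect.column.Column."""


# Mirrors the ColumnOrName alias defined above.
ColumnOrName = Union[Column, str]


class NamesCallable(Protocol):
    # Structural type: any callable accepting *ColumnOrName and returning str conforms.
    def __call__(self, *_: ColumnOrName) -> str:
        ...


def make_concat() -> NamesCallable:
    # The inner closure satisfies NamesCallable without subclassing it.
    def concat(*cols: ColumnOrName) -> str:
        return ",".join(c if isinstance(c, str) else type(c).__name__ for c in cols)

    return concat


print(make_concat()("id", Column()))  # prints: id,Column
```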
25 changes: 10 additions & 15 deletions python/pyspark/sql/connect/client.py
@@ -18,7 +18,6 @@

import logging
import os
import typing
import urllib.parse
import uuid

@@ -35,9 +34,7 @@
from pyspark.sql.connect.plan import SQL, Range
from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType

from typing import Optional, Any, Union

NumericType = typing.Union[int, float]
from typing import Iterable, Optional, Any, Union, List, Tuple, Dict

logging.basicConfig(level=logging.INFO)

@@ -74,7 +71,7 @@ def __init__(self, url: str) -> None:
# Python's built-in parser.
tmp_url = "http" + url[2:]
self.url = urllib.parse.urlparse(tmp_url)
self.params: typing.Dict[str, str] = {}
self.params: Dict[str, str] = {}
if len(self.url.path) > 0 and self.url.path != "/":
raise AttributeError(
f"Path component for connection URI must be empty: {self.url.path}"
@@ -102,7 +99,7 @@ def _extract_attributes(self) -> None:
f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
)

def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
def metadata(self) -> Iterable[Tuple[str, str]]:
"""
Builds the GRPC specific metadata list to be injected into the request. All
parameters will be converted to metadata except ones that are explicitly used
@@ -198,7 +195,7 @@ def toChannel(self) -> grpc.Channel:


class MetricValue:
def __init__(self, name: str, value: NumericType, type: str):
def __init__(self, name: str, value: Union[int, float], type: str):
self._name = name
self._type = type
self._value = value
@@ -211,7 +208,7 @@ def name(self) -> str:
return self._name

@property
def value(self) -> NumericType:
def value(self) -> Union[int, float]:
return self._value

@property
@@ -220,7 +217,7 @@ def metric_type(self) -> str:


class PlanMetrics:
def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
self._name = name
self._id = id
self._parent_id = parent
@@ -242,7 +239,7 @@ def parent_plan_id(self) -> int:
return self._parent_id

@property
def metrics(self) -> typing.List[MetricValue]:
def metrics(self) -> List[MetricValue]:
return self._metrics


@@ -252,7 +249,7 @@ def __init__(self, schema: pb2.DataType, explain: str):
self.explain_string = explain

@classmethod
def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
def fromProto(cls, pb: Any) -> "AnalyzeResult":
return AnalyzeResult(pb.schema, pb.explain_string)


@@ -306,9 +303,7 @@ def register_udf(
self._execute_and_fetch(req)
return name

def _build_metrics(
self, metrics: "pb2.ExecutePlanResponse.Metrics"
) -> typing.List[PlanMetrics]:
def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
return [
PlanMetrics(
x.name,
@@ -450,7 +445,7 @@ def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
return rd.read_pandas()
return None

def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
import pandas as pd

m: Optional[pb2.ExecutePlanResponse.Metrics] = None
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/column.py
@@ -22,7 +22,6 @@
import datetime

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect._typing import PrimitiveType

if TYPE_CHECKING:
from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,6 +32,8 @@ def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["Column", Any], "Expression"]:
def _(self: "Column", other: Any) -> "Expression":
from pyspark.sql.connect._typing import PrimitiveType

if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
@@ -70,6 +71,8 @@ def __eq__(self, other: Any) -> "Expression": # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
from pyspark.sql.connect._typing import PrimitiveType

if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("==", self, other)
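
The `column.py` hunks above move the `PrimitiveType` import from module scope into the function bodies: since `_typing.py` now imports `Column` from `column.py`, a module-level import in the other direction would be circular. A minimal sketch of the deferred-lookup plus `get_args` dispatch pattern follows; the `wrap` helper and its string results are illustrative only.

```python
from typing import Union, get_args

# Stand-in for pyspark.sql.connect._typing.PrimitiveType.
PrimitiveType = Union[bool, float, int, str]


def wrap(other: object) -> str:
    # get_args(Union[...]) yields the member types as a tuple, so the alias
    # doubles as the isinstance() check used when deciding whether to wrap a
    # plain Python value as a literal expression.
    if isinstance(other, get_args(PrimitiveType)):
        return f"LiteralExpression({other!r})"
    return f"Expression({other!r})"


print(wrap(1.5))        # LiteralExpression(1.5)
print(wrap([1, 2, 3]))  # Expression([1, 2, 3])
```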
12 changes: 5 additions & 7 deletions python/pyspark/sql/connect/dataframe.py
@@ -44,11 +44,9 @@
)

if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType
from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
from pyspark.sql.connect.client import RemoteSparkSession

ColumnOrName = Union[Column, str]


class GroupingFrame(object):

@@ -308,7 +306,7 @@ def distinct(self) -> "DataFrame":
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)

def drop(self, *cols: "ColumnOrString") -> "DataFrame":
def drop(self, *cols: "ColumnOrName") -> "DataFrame":
_cols = list(cols)
if any(not isinstance(c, (str, Column)) for c in _cols):
raise TypeError(
@@ -342,7 +340,7 @@ def first(self) -> Optional[Row]:
"""
return self.head()

def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
return GroupingFrame(self, *cols)

@overload
@@ -414,13 +412,13 @@ def limit(self, n: int) -> "DataFrame":
def offset(self, n: int) -> "DataFrame":
return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session)

def sort(self, *cols: "ColumnOrString") -> "DataFrame":
def sort(self, *cols: "ColumnOrName") -> "DataFrame":
"""Sort by a specific column"""
return DataFrame.withPlan(
plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
)

def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
"""Sort within each partition by a specific column"""
return DataFrame.withPlan(
plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
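
Because `ColumnOrName` is only a static alias, `DataFrame.drop` still validates its varargs at runtime, as the hunk above shows. A self-contained sketch of that check follows; the local `Column` class and `drop` function are stand-ins, not PySpark code.

```python
from typing import List, Union


class Column:
    """Hypothetical stand-in for the Spark Connect Column type."""


ColumnOrName = Union[Column, str]


def drop(*cols: ColumnOrName) -> List[ColumnOrName]:
    # Mirrors the runtime guard in DataFrame.drop: every argument must be a
    # Column or a column name, since the type alias is not enforced at runtime.
    _cols = list(cols)
    if any(not isinstance(c, (str, Column)) for c in _cols):
        raise TypeError("all arguments must be Column or str")
    return _cols


print(drop("a", Column(), "b"))  # ['a', <Column object>, 'b']
```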
10 changes: 7 additions & 3 deletions python/pyspark/sql/connect/function_builder.py
@@ -28,9 +28,13 @@


if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
from pyspark.sql.connect._typing import (
ColumnOrName,
ExpressionOrString,
FunctionBuilderCallable,
UserDefinedFunctionCallable,
)
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable


def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -103,7 +107,7 @@ def __str__(self) -> str:
def _create_udf(
function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> "UserDefinedFunctionCallable":
def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction:
def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction:
return UserDefinedFunction(func=function, return_type=return_type, args=cols)

return wrapper
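
The wrapper returned by `_create_udf` is what satisfies the new `UserDefinedFunctionCallable` protocol: it is never declared as a subclass, it simply has the matching call signature. A rough, self-contained sketch of that factory shape; the `UserDefinedFunction` stub and `to_upper` example are hypothetical.

```python
from typing import Any, Callable, Tuple, Union


class UserDefinedFunction:
    """Hypothetical stand-in for the Spark Connect UDF wrapper."""

    def __init__(self, func: Callable[..., Any], return_type: str, args: Tuple[Any, ...]) -> None:
        self.func, self.return_type, self.args = func, return_type, args


def _create_udf(function: Callable[..., Any], return_type: str) -> Callable[..., UserDefinedFunction]:
    # The closure structurally matches UserDefinedFunctionCallable: it takes
    # *ColumnOrName-style arguments and returns a UserDefinedFunction.
    def wrapper(*cols: Union[str, Any]) -> UserDefinedFunction:
        return UserDefinedFunction(func=function, return_type=return_type, args=cols)

    return wrapper


to_upper = _create_udf(lambda s: s.upper(), "string")
print(to_upper("name").return_type)  # string
```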
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -36,7 +36,7 @@


if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString
from pyspark.sql.connect.client import RemoteSparkSession


@@ -58,7 +58,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression:
return exp

def to_attr_or_expression(
self, col: "ColumnOrString", session: "RemoteSparkSession"
self, col: "ColumnOrName", session: "RemoteSparkSession"
) -> proto.Expression:
"""Returns either an instance of an unresolved attribute or the serialized
expression value of the column."""
5 changes: 1 addition & 4 deletions python/pyspark/sql/connect/readwriter.py
@@ -18,17 +18,14 @@

from typing import Dict, Optional

from pyspark.sql.connect.column import PrimitiveType
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, DataSource
from pyspark.sql.utils import to_str


OptionalPrimitiveType = Optional[PrimitiveType]

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.client import RemoteSparkSession


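
`readwriter.py` now pulls `OptionalPrimitiveType` in under `TYPE_CHECKING` instead of redefining it locally, so the alias exists only for static analysis. A short sketch of that pattern under the same assumption; the inline alias and the `option` helper are illustrative, not the real reader API.

```python
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Evaluated only by type checkers; skipped at runtime, which avoids
    # import cycles and keeps module start-up cheap.
    OptionalPrimitiveType = Optional[Union[bool, float, int, str]]


def option(key: str, value: "OptionalPrimitiveType") -> str:
    # The quoted annotation is resolved lazily, so the name may be undefined
    # at runtime without breaking the call.
    return key if value is None else f"{key}={value}"


print(option("header", True))  # header=True
```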
35 changes: 0 additions & 35 deletions python/pyspark/sql/connect/typing/__init__.pyi

This file was deleted.
