Skip to content

Commit

Permalink
[SPARK-37104][PYTHON] Make RDD and DStream covariant
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR changes changes `RDD[~T]` and `DStream[~T]` to `RDD[+T]` and `DStream[+T]` respectively.

### Why are the changes needed?

To improve usability of the current annotations and simplify further development of type hints.  Let's take simple `RDD` to `DataFrame` as an example. Currently, the following code will not type check

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

rdd = sc.parallelize([(1, 2)])
reveal_type(rdd)

spark.createDataFrame(rdd)
```

with

```
main.py:8: note: Revealed type is "pyspark.rdd.RDD[Tuple[builtins.int, builtins.int]]"
main.py:10: error: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "RDD[Tuple[int, int]]"; expected "Union[RDD[Tuple[Any, ...]], Iterable[Tuple[Any, ...]]]"
Found 1 error in 1 file (checked 1 source file)
```

To type check, `rdd` would have to be annotated with specific type, matching the signature of the `createDataFrame` method:

```python
rdd: RDD[Tuple[Any, ...]] = sc.parallelize([(1, 2)])
```

Alternatively, one could inline definition:

```python
spark.createDataFrame(sc.parallelize([(1, 2)]))
```

Similarly, with `pyspark.mllib`:

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.mllib.clustering import KMeans
from pyspark.mllib.linalg import SparseVector, Vectors

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

rdd = sc.parallelize([
    Vectors.sparse(10, [1, 3, 5], [1, 1, 1]),
    Vectors.sparse(10, [2, 4, 6], [1, 1, 1]),
])

KMeans.train(rdd, 2)
```

we'd get

```
main.py:14: error: Argument 1 to "train" of "KMeans" has incompatible type "RDD[SparseVector]"; expected "RDD[Union[ndarray[Any, Any], Vector, List[float], Tuple[float, ...]]]"
Found 1 error in 1 file (checked 1 source file)
```

but this time, we'd need much more complex annotation (inlining would work as well):

```python
rdd: RDD[Union[ndarray[Any, Any], Vector, List[float], Tuple[float, ...]]] = sc.parallelize([
    Vectors.sparse(10, [1, 3, 5], [1, 1, 1]),
    Vectors.sparse(10, [2, 4, 6], [1, 1, 1]),
])
```

This happens because

- RDD is invariant in terms of stored type.
- mypy doesn't look forward to infer types of objects depending on the usage context (similarly to Scala console / spark-shell, but unlike standalone Scala compiler, which allows us to have [examples like this](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala))

It not only makes things verbose, but also fragile and dependent on details of implementation. In the first example, where we have top level `Union`, we can just use `RDD[...]` and ignore other members.

In the second case, where `Union` is a type parameter we have to match all its components (it could be simpler if we didn't use `RDD[VectorLike]` but defined something like `RDD[ndarray] | RDD[Vector] | RDD[List[float]] | RDD[Tuple[float, ...]]]`, which should make it closer to the first case, though not semantically equivalent to the current signature).

Theoretically, we could partially address this with different definitions of aliases, like using type bounds (see discussion under #34354), but it doesn't scale well and requires same steps to be taken by every library that depends on PySpark.

See also related discussion about Scala counterpart ‒ SPARK-1296

### Does this PR introduce _any_ user-facing change?

Type hints only.

Users will be able to use both subclasses of `RDD` / `DStream` in certain contexts, without explicit annotations or casts (both examples will pass type checker in their original form).

### How was this patch tested?

Existing tests and not released data tests (SPARK-36989).

Closes #34374 from zero323/SPARK-37104.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: zero323 <mszymkiewicz@gmail.com>
  • Loading branch information
zero323 committed Nov 21, 2021
1 parent b9e9167 commit ef4f254
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 82 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/_typing.pyi
Expand Up @@ -20,7 +20,7 @@ from typing import Callable, Iterable, Sized, TypeVar, Union
from typing_extensions import Protocol

F = TypeVar("F", bound=Callable)
T = TypeVar("T", covariant=True)
T_co = TypeVar("T_co", covariant=True)

PrimitiveType = Union[bool, float, int, str]

Expand All @@ -30,4 +30,4 @@ class SupportsIAdd(Protocol):
class SupportsOrdering(Protocol):
def __le__(self, other: SupportsOrdering) -> bool: ...

class SizedIterable(Protocol, Sized, Iterable[T]): ...
class SizedIterable(Protocol, Sized, Iterable[T_co]): ...
107 changes: 55 additions & 52 deletions python/pyspark/rdd.pyi
Expand Up @@ -61,6 +61,7 @@ from pyspark.sql._typing import AtomicValue, RowLike
from py4j.java_gateway import JavaObject # type: ignore[import]

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
U = TypeVar("U")
K = TypeVar("K", bound=Hashable)
V = TypeVar("V")
Expand Down Expand Up @@ -96,7 +97,7 @@ class Partitioner:
def __eq__(self, other: Any) -> bool: ...
def __call__(self, k: Any) -> int: ...

class RDD(Generic[T]):
class RDD(Generic[T_co]):
is_cached: bool
is_checkpointed: bool
ctx: pyspark.context.SparkContext
Expand All @@ -111,44 +112,46 @@ class RDD(Generic[T]):
def __getnewargs__(self) -> Any: ...
@property
def context(self) -> pyspark.context.SparkContext: ...
def cache(self) -> RDD[T]: ...
def persist(self, storageLevel: StorageLevel = ...) -> RDD[T]: ...
def unpersist(self, blocking: bool = ...) -> RDD[T]: ...
def cache(self) -> RDD[T_co]: ...
def persist(self, storageLevel: StorageLevel = ...) -> RDD[T_co]: ...
def unpersist(self, blocking: bool = ...) -> RDD[T_co]: ...
def checkpoint(self) -> None: ...
def isCheckpointed(self) -> bool: ...
def localCheckpoint(self) -> None: ...
def isLocallyCheckpointed(self) -> bool: ...
def getCheckpointFile(self) -> Optional[str]: ...
def map(self, f: Callable[[T], U], preservesPartitioning: bool = ...) -> RDD[U]: ...
def map(self, f: Callable[[T_co], U], preservesPartitioning: bool = ...) -> RDD[U]: ...
def flatMap(
self, f: Callable[[T], Iterable[U]], preservesPartitioning: bool = ...
self, f: Callable[[T_co], Iterable[U]], preservesPartitioning: bool = ...
) -> RDD[U]: ...
def mapPartitions(
self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ...
self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ...
) -> RDD[U]: ...
def mapPartitionsWithIndex(
self,
f: Callable[[int, Iterable[T]], Iterable[U]],
f: Callable[[int, Iterable[T_co]], Iterable[U]],
preservesPartitioning: bool = ...,
) -> RDD[U]: ...
def mapPartitionsWithSplit(
self,
f: Callable[[int, Iterable[T]], Iterable[U]],
f: Callable[[int, Iterable[T_co]], Iterable[U]],
preservesPartitioning: bool = ...,
) -> RDD[U]: ...
def getNumPartitions(self) -> int: ...
def filter(self, f: Callable[[T], bool]) -> RDD[T]: ...
def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T]: ...
def filter(self, f: Callable[[T_co], bool]) -> RDD[T_co]: ...
def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T_co]: ...
def sample(
self, withReplacement: bool, fraction: float, seed: Optional[int] = ...
) -> RDD[T]: ...
) -> RDD[T_co]: ...
def randomSplit(
self, weights: List[Union[int, float]], seed: Optional[int] = ...
) -> List[RDD[T]]: ...
def takeSample(self, withReplacement: bool, num: int, seed: Optional[int] = ...) -> List[T]: ...
def union(self, other: RDD[U]) -> RDD[Union[T, U]]: ...
def intersection(self, other: RDD[T]) -> RDD[T]: ...
def __add__(self, other: RDD[T]) -> RDD[T]: ...
) -> List[RDD[T_co]]: ...
def takeSample(
self, withReplacement: bool, num: int, seed: Optional[int] = ...
) -> List[T_co]: ...
def union(self, other: RDD[U]) -> RDD[Union[T_co, U]]: ...
def intersection(self, other: RDD[T_co]) -> RDD[T_co]: ...
def __add__(self, other: RDD[T_co]) -> RDD[T_co]: ...
@overload
def repartitionAndSortWithinPartitions(
self: RDD[Tuple[O, V]],
Expand Down Expand Up @@ -195,55 +198,55 @@ class RDD(Generic[T]):
keyfunc: Callable[[K], O],
) -> RDD[Tuple[K, V]]: ...
def sortBy(
self: RDD[T],
keyfunc: Callable[[T], O],
self,
keyfunc: Callable[[T_co], O],
ascending: bool = ...,
numPartitions: Optional[int] = ...,
) -> RDD[T]: ...
def glom(self) -> RDD[List[T]]: ...
def cartesian(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ...
) -> RDD[T_co]: ...
def glom(self) -> RDD[List[T_co]]: ...
def cartesian(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ...
def groupBy(
self,
f: Callable[[T], K],
f: Callable[[T_co], K],
numPartitions: Optional[int] = ...,
partitionFunc: Callable[[K], int] = ...,
) -> RDD[Tuple[K, Iterable[T]]]: ...
) -> RDD[Tuple[K, Iterable[T_co]]]: ...
def pipe(
self, command: str, env: Optional[Dict[str, str]] = ..., checkCode: bool = ...
) -> RDD[str]: ...
def foreach(self, f: Callable[[T], None]) -> None: ...
def foreachPartition(self, f: Callable[[Iterable[T]], None]) -> None: ...
def collect(self) -> List[T]: ...
def foreach(self, f: Callable[[T_co], None]) -> None: ...
def foreachPartition(self, f: Callable[[Iterable[T_co]], None]) -> None: ...
def collect(self) -> List[T_co]: ...
def collectWithJobGroup(
self, groupId: str, description: str, interruptOnCancel: bool = ...
) -> List[T]: ...
def reduce(self, f: Callable[[T, T], T]) -> T: ...
def treeReduce(self, f: Callable[[T, T], T], depth: int = ...) -> T: ...
def fold(self, zeroValue: T, op: Callable[[T, T], T]) -> T: ...
) -> List[T_co]: ...
def reduce(self, f: Callable[[T_co, T_co], T_co]) -> T_co: ...
def treeReduce(self, f: Callable[[T_co, T_co], T_co], depth: int = ...) -> T_co: ...
def fold(self, zeroValue: T, op: Callable[[T_co, T_co], T_co]) -> T_co: ...
def aggregate(
self, zeroValue: U, seqOp: Callable[[U, T], U], combOp: Callable[[U, U], U]
self, zeroValue: U, seqOp: Callable[[U, T_co], U], combOp: Callable[[U, U], U]
) -> U: ...
def treeAggregate(
self,
zeroValue: U,
seqOp: Callable[[U, T], U],
seqOp: Callable[[U, T_co], U],
combOp: Callable[[U, U], U],
depth: int = ...,
) -> U: ...
@overload
def max(self: RDD[O]) -> O: ...
@overload
def max(self, key: Callable[[T], O]) -> T: ...
def max(self, key: Callable[[T_co], O]) -> T_co: ...
@overload
def min(self: RDD[O]) -> O: ...
@overload
def min(self, key: Callable[[T], O]) -> T: ...
def min(self, key: Callable[[T_co], O]) -> T_co: ...
def sum(self: RDD[NumberOrArray]) -> NumberOrArray: ...
def count(self) -> int: ...
def stats(self: RDD[NumberOrArray]) -> StatCounter: ...
def histogram(
self, buckets: Union[int, List[T], Tuple[T, ...]]
) -> Tuple[List[T], List[int]]: ...
self, buckets: Union[int, List[T_co], Tuple[T_co, ...]]
) -> Tuple[List[T_co], List[int]]: ...
def mean(self: RDD[NumberOrArray]) -> NumberOrArray: ...
def variance(self: RDD[NumberOrArray]) -> NumberOrArray: ...
def stdev(self: RDD[NumberOrArray]) -> NumberOrArray: ...
Expand All @@ -253,13 +256,13 @@ class RDD(Generic[T]):
@overload
def top(self: RDD[O], num: int) -> List[O]: ...
@overload
def top(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ...
def top(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ...
@overload
def takeOrdered(self: RDD[O], num: int) -> List[O]: ...
@overload
def takeOrdered(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ...
def take(self, num: int) -> List[T]: ...
def first(self) -> T: ...
def takeOrdered(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ...
def take(self, num: int) -> List[T_co]: ...
def first(self) -> T_co: ...
def isEmpty(self) -> bool: ...
def saveAsNewAPIHadoopDataset(
self: RDD[Tuple[K, V]],
Expand Down Expand Up @@ -408,15 +411,15 @@ class RDD(Generic[T]):
other: RDD[Tuple[K, U]],
numPartitions: Optional[int] = ...,
) -> RDD[Tuple[K, V]]: ...
def subtract(self: RDD[T], other: RDD[T], numPartitions: Optional[int] = ...) -> RDD[T]: ...
def keyBy(self: RDD[T], f: Callable[[T], K]) -> RDD[Tuple[K, T]]: ...
def repartition(self, numPartitions: int) -> RDD[T]: ...
def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T]: ...
def zip(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ...
def zipWithIndex(self) -> RDD[Tuple[T, int]]: ...
def zipWithUniqueId(self) -> RDD[Tuple[T, int]]: ...
def subtract(self, other: RDD[T_co], numPartitions: Optional[int] = ...) -> RDD[T_co]: ...
def keyBy(self, f: Callable[[T_co], K]) -> RDD[Tuple[K, T_co]]: ...
def repartition(self, numPartitions: int) -> RDD[T_co]: ...
def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T_co]: ...
def zip(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ...
def zipWithIndex(self) -> RDD[Tuple[T_co, int]]: ...
def zipWithUniqueId(self) -> RDD[Tuple[T_co, int]]: ...
def name(self) -> str: ...
def setName(self, name: str) -> RDD[T]: ...
def setName(self, name: str) -> RDD[T_co]: ...
def toDebugString(self) -> bytes: ...
def getStorageLevel(self) -> StorageLevel: ...
def lookup(self: RDD[Tuple[K, V]], key: K) -> List[V]: ...
Expand All @@ -428,9 +431,9 @@ class RDD(Generic[T]):
self: RDD[Union[float, int]], timeout: int, confidence: float = ...
) -> BoundedFloat: ...
def countApproxDistinct(self, relativeSD: float = ...) -> int: ...
def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T]: ...
def barrier(self: RDD[T]) -> RDDBarrier[T]: ...
def withResources(self: RDD[T], profile: ResourceProfile) -> RDD[T]: ...
def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T_co]: ...
def barrier(self) -> RDDBarrier[T_co]: ...
def withResources(self, profile: ResourceProfile) -> RDD[T_co]: ...
def getResourceProfile(self) -> Optional[ResourceProfile]: ...
@overload
def toDF(
Expand Down
59 changes: 31 additions & 28 deletions python/pyspark/streaming/dstream.pyi
Expand Up @@ -38,11 +38,12 @@ from py4j.java_gateway import JavaObject

S = TypeVar("S")
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
U = TypeVar("U")
K = TypeVar("K", bound=Hashable)
V = TypeVar("V")

class DStream(Generic[T]):
class DStream(Generic[T_co]):
is_cached: bool
is_checkpointed: bool
def __init__(
Expand All @@ -53,24 +54,24 @@ class DStream(Generic[T]):
) -> None: ...
def context(self) -> pyspark.streaming.context.StreamingContext: ...
def count(self) -> DStream[int]: ...
def filter(self, f: Callable[[T], bool]) -> DStream[T]: ...
def filter(self, f: Callable[[T_co], bool]) -> DStream[T_co]: ...
def flatMap(
self: DStream[T],
f: Callable[[T], Iterable[U]],
self: DStream[T_co],
f: Callable[[T_co], Iterable[U]],
preservesPartitioning: bool = ...,
) -> DStream[U]: ...
def map(
self: DStream[T], f: Callable[[T], U], preservesPartitioning: bool = ...
self: DStream[T_co], f: Callable[[T_co], U], preservesPartitioning: bool = ...
) -> DStream[U]: ...
def mapPartitions(
self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ...
self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ...
) -> DStream[U]: ...
def mapPartitionsWithIndex(
self,
f: Callable[[int, Iterable[T]], Iterable[U]],
f: Callable[[int, Iterable[T_co]], Iterable[U]],
preservesPartitioning: bool = ...,
) -> DStream[U]: ...
def reduce(self, func: Callable[[T, T], T]) -> DStream[T]: ...
def reduce(self, func: Callable[[T_co, T_co], T_co]) -> DStream[T_co]: ...
def reduceByKey(
self: DStream[Tuple[K, V]],
func: Callable[[V, V], V],
Expand All @@ -89,45 +90,45 @@ class DStream(Generic[T]):
partitionFunc: Callable[[K], int] = ...,
) -> DStream[Tuple[K, V]]: ...
@overload
def foreachRDD(self, func: Callable[[RDD[T]], None]) -> None: ...
def foreachRDD(self, func: Callable[[RDD[T_co]], None]) -> None: ...
@overload
def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T]], None]) -> None: ...
def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T_co]], None]) -> None: ...
def pprint(self, num: int = ...) -> None: ...
def mapValues(self: DStream[Tuple[K, V]], f: Callable[[V], U]) -> DStream[Tuple[K, U]]: ...
def flatMapValues(
self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]]
) -> DStream[Tuple[K, U]]: ...
def glom(self) -> DStream[List[T]]: ...
def cache(self) -> DStream[T]: ...
def persist(self, storageLevel: StorageLevel) -> DStream[T]: ...
def checkpoint(self, interval: int) -> DStream[T]: ...
def glom(self) -> DStream[List[T_co]]: ...
def cache(self) -> DStream[T_co]: ...
def persist(self, storageLevel: StorageLevel) -> DStream[T_co]: ...
def checkpoint(self, interval: int) -> DStream[T_co]: ...
def groupByKey(
self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ...
) -> DStream[Tuple[K, Iterable[V]]]: ...
def countByValue(self) -> DStream[Tuple[T, int]]: ...
def countByValue(self) -> DStream[Tuple[T_co, int]]: ...
def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> None: ...
@overload
def transform(self, func: Callable[[RDD[T]], RDD[U]]) -> TransformedDStream[U]: ...
def transform(self, func: Callable[[RDD[T_co]], RDD[U]]) -> TransformedDStream[U]: ...
@overload
def transform(
self, func: Callable[[datetime.datetime, RDD[T]], RDD[U]]
self, func: Callable[[datetime.datetime, RDD[T_co]], RDD[U]]
) -> TransformedDStream[U]: ...
@overload
def transformWith(
self,
func: Callable[[RDD[T], RDD[U]], RDD[V]],
func: Callable[[RDD[T_co], RDD[U]], RDD[V]],
other: RDD[U],
keepSerializer: bool = ...,
) -> DStream[V]: ...
@overload
def transformWith(
self,
func: Callable[[datetime.datetime, RDD[T], RDD[U]], RDD[V]],
func: Callable[[datetime.datetime, RDD[T_co], RDD[U]], RDD[V]],
other: RDD[U],
keepSerializer: bool = ...,
) -> DStream[V]: ...
def repartition(self, numPartitions: int) -> DStream[T]: ...
def union(self, other: DStream[U]) -> DStream[Union[T, U]]: ...
def repartition(self, numPartitions: int) -> DStream[T_co]: ...
def union(self, other: DStream[U]) -> DStream[Union[T_co, U]]: ...
def cogroup(
self: DStream[Tuple[K, V]],
other: DStream[Tuple[K, U]],
Expand Down Expand Up @@ -155,22 +156,24 @@ class DStream(Generic[T]):
) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ...
def slice(
self, begin: Union[datetime.datetime, int], end: Union[datetime.datetime, int]
) -> List[RDD[T]]: ...
def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T]: ...
) -> List[RDD[T_co]]: ...
def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T_co]: ...
def reduceByWindow(
self,
reduceFunc: Callable[[T, T], T],
invReduceFunc: Optional[Callable[[T, T], T]],
reduceFunc: Callable[[T_co, T_co], T_co],
invReduceFunc: Optional[Callable[[T_co, T_co], T_co]],
windowDuration: int,
slideDuration: int,
) -> DStream[T]: ...
def countByWindow(self, windowDuration: int, slideDuration: int) -> DStream[Tuple[T, int]]: ...
) -> DStream[T_co]: ...
def countByWindow(
self, windowDuration: int, slideDuration: int
) -> DStream[Tuple[T_co, int]]: ...
def countByValueAndWindow(
self,
windowDuration: int,
slideDuration: int,
numPartitions: Optional[int] = ...,
) -> DStream[Tuple[T, int]]: ...
) -> DStream[Tuple[T_co, int]]: ...
def groupByKeyAndWindow(
self: DStream[Tuple[K, V]],
windowDuration: int,
Expand Down

0 comments on commit ef4f254

Please sign in to comment.