[SPARK-43971][CONNECT][PYTHON] Support Python's createDataFrame in streaming manner

### What changes were proposed in this pull request?
In this PR, I propose to transfer a local relation from **the Python connect client** to the server in a streaming way when it exceeds the size threshold defined by the SQL config `spark.sql.session.localRelationCacheThreshold`. The implementation is similar to #40827. In particular (a rough client-side sketch follows the list):
1. The client applies the `sha256` function over **the proto form** of the local relation;
2. It checks the presence of the relation on the server side by sending the relation hash to the server;
3. If the server doesn't have the local relation, the client transfers it as an artifact with the name `cache/<sha256>`;
4. Once the relation is already present at the server, or has just been transferred, the client transforms the logical plan by replacing the `LocalRelation` node with a `CachedLocalRelation` that carries the hash.
5. The server, in turn, converts `CachedLocalRelation` back to `LocalRelation` by retrieving the relation body from the local cache.
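
For illustration only, a minimal sketch of the naming step, assuming the serialized proto bytes of the local relation are already in hand (`artifact_name_for` is a hypothetical helper; the real wiring lives in the artifact manager and in `plan.py`/`session.py` below):

```python
import hashlib

def artifact_name_for(local_relation_proto: bytes) -> str:
    # The client hashes the serialized LocalRelation proto; if the server does not
    # already have it, the blob is transferred under the name `cache/<sha256>`.
    return "cache/" + hashlib.sha256(local_relation_proto).hexdigest()

# The plan then references the relation by this hash via CachedLocalRelation
# instead of inlining the data.
```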

### Why are the changes needed?
To fix failures when creating a large DataFrame from a local collection, for example:
```python
pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.RESOURCE_EXHAUSTED
	details = "Sent message larger than max (134218508 vs. 134217728)"
	debug_error_string = "UNKNOWN:Error received from peer localhost:50982 {grpc_message:"Sent message larger than max (134218508 vs. 134217728)", grpc_status:8, created_time:"2023-06-09T15:34:08.362797+03:00"}
```
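
For context, a minimal way to hit this limit (sizes are only illustrative; `spark` is assumed to be a Spark Connect session, and without this change the whole relation is inlined into a single gRPC message):

```python
# Roughly 200 MB of local data, which exceeds the default ~128 MB gRPC message
# limit once serialized into a single plan.
rows = [(i, "x" * 1000) for i in range(200_000)]
spark.createDataFrame(rows, ["id", "payload"]).count()
```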

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By running the new test:
```
$ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_streaming_local_relation'
```
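
The test lowers the cache threshold so that even modest data takes the streaming path; the same can be done interactively, for example:

```python
# Local relations whose Arrow data is at least this many bytes are transferred
# as `cache/<sha256>` artifacts instead of being inlined into the plan.
spark.conf.set("spark.sql.session.localRelationCacheThreshold", 1024 * 1024)
```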

Closes #41537 from MaxGekk/streaming-createDataFrame-python-4.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
MaxGekk committed Jun 9, 2023
1 parent c8fe194 commit 93e0acb
Showing 4 changed files with 79 additions and 3 deletions.
3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/client/core.py
@@ -1257,6 +1257,9 @@ def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) ->
    def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
        self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)

    def cache_artifact(self, blob: bytes) -> str:
        return self._artifact_manager.cache_artifact(blob)


class RetryState:
    """
34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -363,6 +363,10 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        plan.local_relation.schema = self._schema
        return plan

    def serialize(self, session: "SparkConnectClient") -> bytes:
        p = self.plan(session)
        return bytes(p.local_relation.SerializeToString())

    def print(self, indent: int = 0) -> str:
        return f"{' ' * indent}<LocalRelation>\n"

@@ -374,6 +378,36 @@ def _repr_html_(self) -> str:
"""


class CachedLocalRelation(LogicalPlan):
"""Creates a CachedLocalRelation plan object based on a hash of a LocalRelation."""

def __init__(self, hash: str) -> None:
super().__init__(None)

self._hash = hash

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
clr = plan.cached_local_relation

if session._user_id:
clr.userId = session._user_id
clr.sessionId = session._session_id
clr.hash = self._hash

return plan

def print(self, indent: int = 0) -> str:
return f"{' ' * indent}<CachedLocalRelation>\n"

def _repr_html_(self) -> str:
return """
<ul>
<li><b>CachedLocalRelation</b></li>
</ul>
"""


class ShowString(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool
26 changes: 23 additions & 3 deletions python/pyspark/sql/connect/session.py
@@ -51,7 +51,14 @@
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.conf import RuntimeConf
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
from pyspark.sql.connect.plan import (
    SQL,
    Range,
    LocalRelation,
    LogicalPlan,
    CachedLocalRelation,
    CachedRelation,
)
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
@@ -466,10 +473,16 @@ def createDataFrame(
            )

        if _schema is not None:
            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
            local_relation = LocalRelation(_table, schema=_schema.json())
        else:
            df = DataFrame.withPlan(LocalRelation(_table), self)
            local_relation = LocalRelation(_table)

        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
        plan: LogicalPlan = local_relation
        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
            plan = CachedLocalRelation(self._cache_local_relation(local_relation))

        df = DataFrame.withPlan(plan, self)
        if _cols is not None and len(_cols) > 0:
            df = df.toDF(*_cols)
        return df
@@ -643,6 +656,13 @@ def addArtifacts(

    addArtifact = addArtifacts

    def _cache_local_relation(self, local_relation: LocalRelation) -> str:
        """
        Cache the local relation at the server side if it has not been cached yet.
        """
        serialized = local_relation.serialize(self._client)
        return self._client.cache_artifact(serialized)

    def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
        """
        Copy file from local to cloud storage file system.
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,7 +19,9 @@
import datetime
import os
import unittest
import random
import shutil
import string
import tempfile
from collections import defaultdict

@@ -649,6 +651,23 @@ def test_with_local_rows(self):
        self.assertEqual(sdf.schema, cdf.schema)
        self.assert_eq(sdf.toPandas(), cdf.toPandas())

    def test_streaming_local_relation(self):
        threshold_conf = "spark.sql.session.localRelationCacheThreshold"
        old_threshold = self.connect.conf.get(threshold_conf)
        threshold = 1024 * 1024
        self.connect.conf.set(threshold_conf, threshold)
        try:
            suffix = "abcdef"
            letters = string.ascii_lowercase
            str = "".join(random.choice(letters) for i in range(threshold)) + suffix
            data = [[0, str], [1, str]]
            for i in range(0, 2):
                cdf = self.connect.createDataFrame(data, ["a", "b"])
                self.assert_eq(cdf.count(), len(data))
                self.assert_eq(cdf.filter(f"endsWith(b, '{suffix}')").isEmpty(), False)
        finally:
            self.connect.conf.set(threshold_conf, old_threshold)

    def test_with_atom_type(self):
        for data in [[(1), (2), (3)], [1, 2, 3]]:
            for schema in ["long", "int", "short"]:
