[SPARK-45927][PYTHON] Update path handling for Python data source
### What changes were proposed in this pull request?

This PR updates how `path` values passed to the `load()` method are handled.
It changes the `DataSource` class constructor and adds `path` as a key-value pair in the options field.

Also, this PR blocks loading multiple paths.
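
For illustration, the new constructor contract looks roughly like this (a minimal sketch; `MyJsonSource` is a hypothetical subclass, not part of the patch, and it assumes a PySpark build that includes this change):

```python
from pyspark.sql.datasource import DataSource


class MyJsonSource(DataSource):
    """Hypothetical do-nothing subclass, used only to show the constructor."""
    ...


# Before: MyJsonSource(paths=[...], userSpecifiedSchema=None, options={...})
# After:  only the options dict is passed; the path handed to
#         spark.read.format(...).load("/tmp/people.json") arrives as options["path"].
ds = MyJsonSource(options={"path": "/tmp/people.json"})
print(ds.options["path"])  # /tmp/people.json
print(ds.name())           # MyJsonSource (defaults to the class name)
```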

### Why are the changes needed?

To make the behavior consistent with the existing data source APIs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43809 from allisonwang-db/spark-45927-path.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
allisonwang-db authored and dongjoon-hyun committed Nov 20, 2023
1 parent 5ded567 commit 25ee62e
Showing 8 changed files with 34 additions and 70 deletions.
17 changes: 3 additions & 14 deletions python/pyspark/sql/datasource.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING

from pyspark import since
from pyspark.sql import Row
@@ -45,30 +45,19 @@ class DataSource(ABC):
"""

@final
def __init__(
self,
paths: List[str],
userSpecifiedSchema: Optional[StructType],
options: Dict[str, "OptionalPrimitiveType"],
) -> None:
def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
"""
Initializes the data source with user-provided information.
Initializes the data source with user-provided options.
Parameters
----------
paths : list
A list of paths to the data source.
userSpecifiedSchema : StructType, optional
The user-specified schema of the data source.
options : dict
A dictionary representing the options for this data source.
Notes
-----
This method should not be overridden.
"""
self.paths = paths
self.userSpecifiedSchema = userSpecifiedSchema
self.options = options

@classmethod
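
For a fuller picture of what a migrated implementation looks like, here is a sketch modeled on the test sources below, using a hypothetical `TextSource`; the old-signature equivalents are noted in comments:

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class TextReader(DataSourceReader):
    def __init__(self, options):
        # Previously the reader would also have been handed a list of paths;
        # now everything arrives through the options dictionary.
        self.options = options

    def read(self, partition):
        path = self.options.get("path")  # was: self.paths[0]
        if path is None:
            raise Exception("path is not specified")
        with open(path, "r") as file:
            for line in file:
                yield (line.rstrip("\n"),)


class TextSource(DataSource):
    @classmethod
    def name(cls):
        return "my-text"

    def schema(self):
        return "value STRING"

    def reader(self, schema) -> DataSourceReader:
        return TextReader(self.options)  # was: TextReader(self.paths, self.options)
```
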
36 changes: 12 additions & 24 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -30,7 +30,7 @@ class MyDataSource(DataSource):
...

options = dict(a=1, b=2)
ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
ds = MyDataSource(options=options)
self.assertEqual(ds.options, options)
self.assertEqual(ds.name(), "MyDataSource")
with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ def test_in_memory_data_source(self):
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3

def __init__(self, paths, options):
self.paths = paths
def __init__(self, options):
self.options = options

def partitions(self):
@@ -76,7 +75,7 @@ def schema(self):
return "x INT, y STRING"

def reader(self, schema) -> "DataSourceReader":
return InMemDataSourceReader(self.paths, self.options)
return InMemDataSourceReader(self.options)

self.spark.dataSource.register(InMemoryDataSource)
df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ def test_custom_json_data_source(self):
import json

class JsonDataSourceReader(DataSourceReader):
def __init__(self, paths, options):
self.paths = paths
def __init__(self, options):
self.options = options

def partitions(self):
return iter(self.paths)

def read(self, path):
def read(self, partition):
path = self.options.get("path")
if path is None:
raise Exception("path is not specified")
with open(path, "r") as file:
for line in file.readlines():
if line.strip():
@@ -114,28 +112,18 @@ def schema(self):
return "name STRING, age INT"

def reader(self, schema) -> "DataSourceReader":
return JsonDataSourceReader(self.paths, self.options)
return JsonDataSourceReader(self.options)

self.spark.dataSource.register(JsonDataSource)
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
df1 = self.spark.read.format("my-json").load(path1)
self.assertEqual(df1.rdd.getNumPartitions(), 1)
assertDataFrameEqual(
df1,
self.spark.read.format("my-json").load(path1),
[Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
)

df2 = self.spark.read.format("my-json").load([path1, path2])
self.assertEqual(df2.rdd.getNumPartitions(), 2)
assertDataFrameEqual(
df2,
[
Row(name="Michael", age=None),
Row(name="Andy", age=30),
Row(name="Justin", age=19),
Row(name="Jonathan", age=None),
],
self.spark.read.format("my-json").load(path2),
[Row(name="Jonathan", age=None)],
)


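At the session level, the updated tests only ever pass a single path per `load()`. A usage sketch for reading both sample files with the `JsonDataSource` defined above (assuming an active `SparkSession` named `spark`) would be:

```python
spark.dataSource.register(JsonDataSource)

df1 = spark.read.format("my-json").load(path1)  # reader sees options["path"] == path1
df2 = spark.read.format("my-json").load(path2)  # reader sees options["path"] == path2

# One way to cover what the removed load([path1, path2]) case used to exercise:
combined = df1.unionByName(df2)
```
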
15 changes: 2 additions & 13 deletions python/pyspark/sql/worker/create_data_source.py
@@ -17,7 +17,7 @@
import inspect
import os
import sys
from typing import IO, List
from typing import IO

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
The JVM sends the following information to this process:
- a `DataSource` class representing the data source to be created.
- a provider name in string.
- a list of paths in string.
- an optional user-specified schema in json string.
- a dictionary of options in string.
@@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None:
},
)

# Receive the paths.
num_paths = read_int(infile)
paths: List[str] = []
for _ in range(num_paths):
paths.append(utf8_deserializer.loads(infile))

# Receive the user-specified schema
user_specified_schema = None
if read_bool(infile):
@@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None:

# Instantiate a data source.
try:
data_source = data_source_cls(
paths=paths,
userSpecifiedSchema=user_specified_schema, # type: ignore
options=options,
)
data_source = data_source_cls(options=options)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
@@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
// Unless the legacy path option behavior is enabled, the extraOptions here
// should not include "path" or "paths" as keys.
val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
}

@@ -35,7 +35,6 @@ class DataSourceManager {
private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Seq[String], // paths
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan
@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
}

private lazy val objectMapper = new ObjectMapper()
private def getOptionsWithPaths(
def getOptionsWithPaths(
extraOptions: CaseInsensitiveMap[String],
paths: String*): CaseInsensitiveMap[String] = {
if (paths.isEmpty) {
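
The now-public `DataSourceV2Utils.getOptionsWithPaths` is what folds `load()` paths into the option map. As an illustration only, its observable mapping can be mirrored in Python roughly like this (the real method works on a `CaseInsensitiveMap` and uses Jackson's `ObjectMapper` for the JSON encoding):

```python
import json


def options_with_paths(options: dict, *paths: str) -> dict:
    """Rough Python mirror of the path-to-option mapping: no paths leaves the
    options untouched, a single path becomes the "path" option, and several
    paths become a JSON-encoded "paths" option (which the Python suite below
    decodes with json.loads)."""
    if not paths:
        return options
    if len(paths) == 1:
        return {**options, "path": paths[0]}
    return {**options, "paths": json.dumps(list(paths))}


print(options_with_paths({}, "/data/a.json"))
# {'path': '/data/a.json'}
print(options_with_paths({}, "/data/a.json", "/data/b.json"))
# {'paths': '["/data/a.json", "/data/b.json"]'}
```
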
@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
def builder(
sparkSession: SparkSession,
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String]): LogicalPlan = {

val runner = new UserDefinedPythonDataSourceRunner(
dataSourceCls, provider, paths, userSpecifiedSchema, options)
dataSourceCls, provider, userSpecifiedSchema, options)

val result = runner.runInPython()
val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
def apply(
sparkSession: SparkSession,
provider: String,
paths: Seq[String] = Seq.empty,
userSpecifiedSchema: Option[StructType] = None,
options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
Dataset.ofRows(sparkSession, plan)
}
}
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
class UserDefinedPythonDataSourceRunner(
dataSourceCls: PythonFunction,
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String])
extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
// Send the provider name
PythonWorkerUtils.writeUTF(provider, dataOut)

// Send the paths
dataOut.writeInt(paths.length)
paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))

// Send the user-specified schema, if provided
dataOut.writeBoolean(userSpecifiedSchema.isDefined)
userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))
@@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
|import json
|
|class SimpleDataSourceReader(DataSourceReader):
| def __init__(self, paths, options):
| self.paths = paths
| def __init__(self, options):
| self.options = options
|
| def partitions(self):
| return iter(self.paths)
| if "paths" in self.options:
| paths = json.loads(self.options["paths"])
| elif "path" in self.options:
| paths = [self.options["path"]]
| else:
| paths = []
| return paths
|
| def read(self, path):
| yield (path, 1)
@@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
| return "id STRING, value INT"
|
| def reader(self, schema):
| return SimpleDataSourceReader(self.paths, self.options)
| return SimpleDataSourceReader(self.options)
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython("test", dataSource)

checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
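
To trace the three `checkAnswer` cases without running a cluster, the partition-selection logic from `SimpleDataSourceReader` can be exercised standalone (an illustrative copy; the option values shown are the ones implied by the mapping sketched earlier, not captured output):

```python
import json


def partitions_from_options(options: dict) -> list:
    # Same branching as SimpleDataSourceReader.partitions() in the suite above.
    if "paths" in options:
        return json.loads(options["paths"])
    if "path" in options:
        return [options["path"]]
    return []


print(partitions_from_options({}))                       # [] -> the suite still gets Row(null, 1)
print(partitions_from_options({"path": "1"}))            # ['1'] -> Row("1", 1)
print(partitions_from_options({"paths": '["1", "2"]'}))  # ['1', '2'] -> Row("1", 1), Row("2", 1)
```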
