5 changes: 5 additions & 0 deletions connector/connect/client/jvm/pom.xml
@@ -45,6 +45,11 @@
<artifactId>spark-sql-api_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
@@ -22,7 +22,9 @@ import java.util.Properties
import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType
@@ -140,6 +142,14 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.Data
def json(jsonDataset: Dataset[String]): DataFrame =
parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON)

/** @inheritdoc */
override def json(jsonRDD: JavaRDD[String]): Dataset[Row] =
throwRddNotSupportedException()

/** @inheritdoc */
override def json(jsonRDD: RDD[String]): Dataset[Row] =
throwRddNotSupportedException()

/** @inheritdoc */
override def csv(path: String): DataFrame = super.csv(path)

@@ -26,8 +26,10 @@ import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
@@ -1463,4 +1465,10 @@ class Dataset[T] private[sql] (
func: MapFunction[T, K],
encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]

/** @inheritdoc */
override def rdd: RDD[T] = throwRddNotSupportedException()

/** @inheritdoc */
override def toJavaRDD: JavaRDD[T] = throwRddNotSupportedException()
}
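
For context rather than as part of the diff: with these overrides in place, RDD access on a Connect Dataset fails at call time while the rest of the API keeps working. A minimal sketch, assuming a reachable Connect server; the endpoint is a placeholder:

import org.apache.spark.sql.SparkSession

// Sketch only: "sc://localhost" is a placeholder Connect endpoint.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
val ds = spark.range(10)
// ds.rdd        // throws UnsupportedOperationException: RDDs are not supported in Spark Connect.
// ds.toJavaRDD  // same error
val rows = ds.collect() // still works: results reach the client without exposing an RDD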
@@ -29,10 +29,13 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
@@ -84,10 +87,14 @@ class SparkSession private[sql] (

private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()

private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
client.hijackServerSideSessionIdForTesting(suffix)
}

/** @inheritdoc */
override def sparkContext: SparkContext =
throw new UnsupportedOperationException("sparkContext is not supported in Spark Connect.")

Contributor Author: Make this use the error framework.

/** @inheritdoc */
val conf: RuntimeConfig = new ConnectRuntimeConfig(client)

@@ -144,6 +151,30 @@
createDataset(data.asScala.toSeq)
}

/** @inheritdoc */
override def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataset[T: Encoder](data: RDD[T]): Dataset[T] =
throwRddNotSupportedException()

/** @inheritdoc */
@Experimental
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
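As an illustration of what these SparkSession overrides mean for callers (a sketch, not part of the diff; `spark` is assumed to be a Connect SparkSession): local-collection construction keeps working, while each RDD-based overload now raises UnsupportedOperationException.

case class Person(name: String, age: Int)

// Works on Connect: the input is a local Seq on the client.
val people = spark.createDataFrame(Seq(Person("Ann", 30), Person("Bob", 25)))

// The RDD-based overloads now throw UnsupportedOperationException on Connect, e.g.:
// spark.createDataFrame(personRdd) // personRdd: RDD[Person], hypothetical
// spark.createDataset(intRdd)      // intRdd: RDD[Int], hypothetical
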
@@ -19,4 +19,7 @@ package org.apache.spark

package object sql {
type DataFrame = Dataset[Row]

private[sql] def throwRddNotSupportedException(): Nothing =
throw new UnsupportedOperationException("RDDs are not supported in Spark Connect.")
@hvanhovell (Contributor Author), Sep 10, 2024: Make this use the error framework.

}
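
A possible follow-up for the reviewer note above, routing the error through Spark's error framework instead of a raw UnsupportedOperationException. This is a sketch only: SparkUnsupportedOperationException exists in org.apache.spark, but the error class name used here is hypothetical and not defined by this PR.

import org.apache.spark.SparkUnsupportedOperationException

private[sql] def throwRddNotSupportedException(): Nothing =
  throw new SparkUnsupportedOperationException(
    "UNSUPPORTED_CONNECT_FEATURE.RDD", // hypothetical error class
    Map.empty[String, String])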
1 change: 1 addition & 0 deletions pom.xml
@@ -84,6 +84,7 @@
<module>common/utils</module>
<module>common/variant</module>
<module>common/tags</module>
<module>sql/connect/shims</module>
<module>core</module>
<module>graphx</module>
<module>mllib</module>
58 changes: 45 additions & 13 deletions project/SparkBuild.scala
@@ -45,24 +45,24 @@ object BuildCommons {

private val buildLocation = file(".").getAbsoluteFile.getParentFile

val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq(
"catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf"
).map(ProjectRef(buildLocation, _))
val sqlProjects@Seq(sqlApi, catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) =
Seq("sql-api", "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10",
"sql-kafka-0-10", "avro", "protobuf").map(ProjectRef(buildLocation, _))

val streamingProjects@Seq(streaming, streamingKafka010) =
Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _))

val connectCommon = ProjectRef(buildLocation, "connect-common")
val connect = ProjectRef(buildLocation, "connect")
val connectClient = ProjectRef(buildLocation, "connect-client-jvm")
val connectProjects@Seq(connectCommon, connect, connectClient, connectShims) =
Seq("connect-common", "connect", "connect-client-jvm", "connect-shims")
.map(ProjectRef(buildLocation, _))

val allProjects@Seq(
core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore,
commonUtils, sqlApi, variant, _*
commonUtils, variant, _*
) = Seq(
"core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
"tags", "sketch", "kvstore", "common-utils", "sql-api", "variant"
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connectCommon, connect, connectClient)
"tags", "sketch", "kvstore", "common-utils", "variant"
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ connectProjects

val optionallyEnabledProjects@Seq(kubernetes, yarn,
sparkGangliaLgpl, streamingKinesisAsl,
@@ -360,7 +360,7 @@ object SparkBuild extends PomBuild {
/* Enable shared settings on all projects */
(allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools))
.foreach(enable(sharedSettings ++ DependencyOverrides.settings ++
ExcludedDependencies.settings ++ Checkstyle.settings))
ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings))

/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -369,7 +369,7 @@
Seq(
spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn,
unsafe, tags, tokenProviderKafka010, sqlKafka010, connectCommon, connect, connectClient,
variant
variant, connectShims
).contains(x)
}

@@ -1087,6 +1087,36 @@ object ExcludedDependencies {
)
}

/**
* This excludes the spark-connect-shims module from a module when it is not part of the connect
* client dependencies.
*/
object ExcludeShims {
val shimmedProjects = Set("spark-sql-api", "spark-connect-common", "spark-connect-client-jvm")
val classPathFilter = TaskKey[Classpath => Classpath]("filter for classpath")
lazy val settings = Seq(
classPathFilter := {
if (!shimmedProjects(moduleName.value)) {
cp => cp.filterNot(_.data.name.contains("spark-connect-shims"))
} else {
identity _
}
},
Compile / internalDependencyClasspath :=
classPathFilter.value((Compile / internalDependencyClasspath).value),
Compile / internalDependencyAsJars :=
classPathFilter.value((Compile / internalDependencyAsJars).value),
Runtime / internalDependencyClasspath :=
classPathFilter.value((Runtime / internalDependencyClasspath).value),
Runtime / internalDependencyAsJars :=
classPathFilter.value((Runtime / internalDependencyAsJars).value),
Test / internalDependencyClasspath :=
classPathFilter.value((Test / internalDependencyClasspath).value),
Test / internalDependencyAsJars :=
classPathFilter.value((Test / internalDependencyAsJars).value),
)
}
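
A standalone illustration of the rule implemented above (plain Scala, outside sbt; the jar names are illustrative): for modules not in shimmedProjects, any internal dependency whose file name contains spark-connect-shims is filtered out, while shimmed modules keep their classpath unchanged via identity.

// Toy stand-in for an sbt Classpath (real entries are Attributed[File] values).
val entries = Seq("spark-connect-shims_2.13.jar", "spark-sql-api_2.13.jar", "spark-core_2.13.jar")
val forRegularModule = entries.filterNot(_.contains("spark-connect-shims"))
// forRegularModule == Seq("spark-sql-api_2.13.jar", "spark-core_2.13.jar")
val forShimmedModule = identity(entries) // unchanged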

/**
* Project to pull previous artifacts of Spark for generating Mima excludes.
*/
@@ -1456,10 +1486,12 @@ object SparkUnidoc extends SharedUnidocSettings {
lazy val settings = baseSettings ++ Seq(
(ScalaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf),
yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient,
connectShims, protobuf),
(JavaUnidoc / unidoc / unidocProjectFilter) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf),
yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient,
connectShims, protobuf),
)
}

6 changes: 6 additions & 0 deletions sql/api/pom.xml
@@ -58,6 +58,12 @@
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.json4s</groupId>
<artifactId>json4s-jackson_${scala.binary.version}</artifactId>
@@ -21,6 +21,8 @@ import scala.jdk.CollectionConverters._
import _root_.java.util

import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils}
@@ -309,6 +311,38 @@ abstract class DataFrameReader {
*/
def json(jsonDataset: DS[String]): Dataset[Row]

/**
* Loads a `JavaRDD[String]` storing JSON objects (<a href="http://jsonlines.org/">JSON Lines
* text format or newline-delimited JSON</a>) and returns the result as a `DataFrame`.
*
* Unless the schema is specified using `schema` function, this function goes through the input
* once to determine the input schema.
*
* @note
* this method is not supported in Spark Connect.
* @param jsonRDD
* input RDD with one JSON object per record
* @since 1.4.0
*/
@deprecated("Use json(Dataset[String]) instead.", "2.2.0")
def json(jsonRDD: JavaRDD[String]): DS[Row]

/**
* Loads an `RDD[String]` storing JSON objects (<a href="http://jsonlines.org/">JSON Lines text
* format or newline-delimited JSON</a>) and returns the result as a `DataFrame`.
*
* Unless the schema is specified using `schema` function, this function goes through the input
* once to determine the input schema.
*
* @note
* this method is not supported in Spark Connect.
* @param jsonRDD
* input RDD with one JSON object per record
* @since 1.4.0
*/
@deprecated("Use json(Dataset[String]) instead.", "2.2.0")
def json(jsonRDD: RDD[String]): DS[Row]
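
A usage note, not part of the diff: both RDD overloads are deprecated and, per the notes above, unsupported on Spark Connect; the Dataset[String] overload is the portable path. A small sketch, assuming `spark` is an active session:

import org.apache.spark.sql.Encoders

// Portable on classic Spark and on Spark Connect.
val jsonLines = spark.createDataset(Seq("""{"name":"Ann"}""", """{"name":"Bob"}"""))(Encoders.STRING)
val df = spark.read.json(jsonLines)

// Deprecated since 2.2.0 and unsupported on Connect:
// spark.read.json(jsonRdd) // jsonRdd: RDD[String], classic Spark only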

/**
* Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
* overloaded `csv()` method for more details.
32 changes: 32 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -22,7 +22,9 @@ import scala.reflect.runtime.universe.TypeTag
import _root_.java.util

import org.apache.spark.annotation.{DeveloperApi, Stable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn}
import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
import org.apache.spark.sql.types.{Metadata, StructType}
@@ -3098,4 +3100,34 @@ abstract class Dataset[T] extends Serializable {
* @since 1.6.0
*/
def write: DataFrameWriter[T]

/**
* Represents the content of the Dataset as an `RDD` of `T`.
*
* @note
* this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def rdd: RDD[T]

/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
*
* @note
* this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def toJavaRDD: JavaRDD[T]

/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
*
* @note
* this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def javaRDD: JavaRDD[T] = toJavaRDD
}
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag

import _root_.java

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ColumnName, DatasetHolder, Encoder, Encoders}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
@@ -278,6 +279,14 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable {
new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]])
}

/**
* Creates a [[Dataset]] from an RDD.
*
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] =
new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]])

/**
* An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]].
* @since 1.3.0
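
For context, not part of the diff: this implicit is what backs the familiar toDS()/toDF() calls on RDDs in classic Spark; on Connect the underlying createDataset(RDD) now throws, as added above. A classic-Spark sketch, assuming `spark` is a classic SparkSession:

import spark.implicits._

// rddToDatasetHolder wraps the RDD in a DatasetHolder, which exposes toDS()/toDF().
val ds = spark.sparkContext.parallelize(Seq(1, 2, 3)).toDS()
val df = ds.toDF("value")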