, and Spark does not support statistics collection on this column type."
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index fe738f4149..825d9ce779 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -25,7 +25,7 @@ import scala.concurrent.Future
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config.Network
-import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
@@ -65,7 +65,7 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
- extends SparkListener with IsolatedRpcEndpoint with Logging {
+ extends SparkListener with IsolatedThreadSafeRpcEndpoint with Logging {
def this(sc: SparkContext) = {
this(sc, new SystemClock)
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 783cf47df1..73acfedd8b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy
import java.io._
import java.lang.reflect.{InvocationTargetException, UndeclaredThrowableException}
import java.net.{URI, URL}
+import java.nio.file.Files
import java.security.PrivilegedExceptionAction
import java.text.ParseException
import java.util.{ServiceLoader, UUID}
@@ -383,43 +384,55 @@ private[spark] class SparkSubmit extends Logging {
}.orNull
if (isKubernetesClusterModeDriver) {
- // Replace with the downloaded local jar path to avoid propagating hadoop compatible uris.
- // Executors will get the jars from the Spark file server.
- // Explicitly download the related files here
- args.jars = localJars
- val filesLocalFiles = Option(args.files).map {
- downloadFileList(_, targetDir, sparkConf, hadoopConf)
- }.orNull
- val archiveLocalFiles = Option(args.archives).map { uris =>
+ // SPARK-33748: this mimics the behaviour of Yarn cluster mode. If the driver is running
+ // in cluster mode, the archives should be available in the driver's current working
+ // directory too.
+ // SPARK-33782: This downloads all the files, jars, archives and pyFiles to the current
+ // working directory.
+ def downloadResourcesToCurrentDirectory(uris: String, isArchive: Boolean = false):
+ String = {
val resolvedUris = Utils.stringToSeq(uris).map(Utils.resolveURI)
- val localArchives = downloadFileList(
+ val localResources = downloadFileList(
resolvedUris.map(
UriBuilder.fromUri(_).fragment(null).build().toString).mkString(","),
targetDir, sparkConf, hadoopConf)
-
- // SPARK-33748: this mimics the behaviour of Yarn cluster mode. If the driver is running
- // in cluster mode, the archives should be available in the driver's current working
- // directory too.
- Utils.stringToSeq(localArchives).map(Utils.resolveURI).zip(resolvedUris).map {
- case (localArchive, resolvedUri) =>
- val source = new File(localArchive.getPath)
+ Utils.stringToSeq(localResources).map(Utils.resolveURI).zip(resolvedUris).map {
+ case (localResources, resolvedUri) =>
+ val source = new File(localResources.getPath)
val dest = new File(
".",
if (resolvedUri.getFragment != null) resolvedUri.getFragment else source.getName)
logInfo(
- s"Unpacking an archive $resolvedUri " +
+ s"Files $resolvedUri " +
s"from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
Utils.deleteRecursively(dest)
- Utils.unpack(source, dest)
-
+ if (isArchive) {
+ Utils.unpack(source, dest)
+ } else {
+ Files.copy(source.toPath, dest.toPath)
+ }
// Keep the URIs of local files with the given fragments.
UriBuilder.fromUri(
- localArchive).fragment(resolvedUri.getFragment).build().toString
+ localResources).fragment(resolvedUri.getFragment).build().toString
}.mkString(",")
+ }
+
+ val filesLocalFiles = Option(args.files).map {
+ downloadResourcesToCurrentDirectory(_)
+ }.orNull
+ val jarsLocalJars = Option(args.jars).map {
+ downloadResourcesToCurrentDirectory(_)
+ }.orNull
+ val archiveLocalFiles = Option(args.archives).map {
+ downloadResourcesToCurrentDirectory(_, true)
+ }.orNull
+ val pyLocalFiles = Option(args.pyFiles).map {
+ downloadResourcesToCurrentDirectory(_)
}.orNull
args.files = filesLocalFiles
args.archives = archiveLocalFiles
- args.pyFiles = localPyFiles
+ args.pyFiles = pyLocalFiles
+ args.jars = jarsLocalJars
}
}
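For illustration, a hedged sketch of what this change means for a Kubernetes cluster-mode submission; the resource names and paths below are made up. Every `--files`, `--jars`, `--py-files`, and `--archives` entry is downloaded into the driver's current working directory, with archives unpacked and the other resources copied.

```bash
# Illustrative only: in K8s cluster mode the driver now materializes these resources
# in its working directory before the user application starts.
./bin/spark-submit \
  --master k8s://https://<api-server>:443 \
  --deploy-mode cluster \
  --files hdfs:///configs/app.conf \
  --py-files hdfs:///deps/helpers.zip \
  --archives hdfs:///deps/env.tar.gz#environment \
  local:///opt/spark/examples/jars/spark-examples.jar
# Inside the driver pod: ./app.conf and ./helpers.zip are copied files, while
# ./environment/ is the unpacked archive (named after the URI fragment).
```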
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index a94e63656e..d8f33a0612 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -35,6 +35,8 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.util.NettyUtils
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.resource.ResourceProfile._
@@ -54,7 +56,7 @@ private[spark] class CoarseGrainedExecutorBackend(
env: SparkEnv,
resourcesFileOpt: Option[String],
resourceProfile: ResourceProfile)
- extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
+ extends IsolatedThreadSafeRpcEndpoint with ExecutorBackend with Logging {
import CoarseGrainedExecutorBackend._
@@ -85,7 +87,8 @@ private[spark] class CoarseGrainedExecutorBackend(
logInfo("Connecting to driver: " + driverUrl)
try {
- if (PlatformDependent.directBufferPreferred() &&
+ val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(env.conf, "shuffle")
+ if (NettyUtils.preferDirectBufs(shuffleClientTransportConf) &&
PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) {
throw new SparkException(s"Netty direct memory should at least be bigger than " +
s"'${MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key}', but got " +
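A hedged illustration of the constraint this startup check enforces; the configuration key and JVM flag below are my reading of the standard settings, and the sizes are examples only. When Netty prefers direct buffers for the shuffle client, the JVM's maximum direct memory has to exceed the fetch-to-memory threshold, otherwise the executor now fails fast with a SparkException.

```bash
# Example only: keep the JVM direct memory limit above the fetch-to-memory threshold.
--conf spark.executor.extraJavaOptions="-XX:MaxDirectMemorySize=1g" \
--conf spark.network.maxRemoteBlockSizeFetchToMem=512m
```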
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala
index 657842c620..6ba6713b69 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala
@@ -47,11 +47,22 @@ object SparkHadoopWriterUtils {
* @return a job ID
*/
def createJobID(time: Date, id: Int): JobID = {
+ val jobTrackerID = createJobTrackerID(time)
+ createJobID(jobTrackerID, id)
+ }
+
+ /**
+ * Create a job ID.
+ *
+ * @param jobTrackerID unique job tracker ID
+ * @param id job number
+ * @return a job ID
+ */
+ def createJobID(jobTrackerID: String, id: Int): JobID = {
if (id < 0) {
throw new IllegalArgumentException("Job number is negative")
}
- val jobtrackerID = createJobTrackerID(time)
- new JobID(jobtrackerID, id)
+ new JobID(jobTrackerID, id)
}
/**
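A minimal sketch, not part of the patch, of how the new overload can reuse one precomputed job tracker ID across several job numbers; the values are illustrative and the utility is Spark-internal, so this only compiles from within the `org.apache.spark` namespace.

```scala
import java.util.Date

import org.apache.spark.internal.io.SparkHadoopWriterUtils

// Compute the tracker ID once and reuse it for multiple job IDs.
val trackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date())
val jobId0 = SparkHadoopWriterUtils.createJobID(trackerId, 0)
val jobId1 = SparkHadoopWriterUtils.createJobID(trackerId, 1)
```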
diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala
index 9a59b6bf67..989ef8f2ed 100644
--- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala
@@ -19,14 +19,14 @@ package org.apache.spark.internal.plugin
import org.apache.spark.api.plugin.DriverPlugin
import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv}
case class PluginMessage(pluginName: String, message: AnyRef)
private class PluginEndpoint(
plugins: Map[String, DriverPlugin],
override val rpcEnv: RpcEnv)
- extends IsolatedRpcEndpoint with Logging {
+ extends IsolatedThreadSafeRpcEndpoint with Logging {
override def receive: PartialFunction[Any, Unit] = {
case PluginMessage(pluginName, message) =>
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index 4728759e7f..627f17f886 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -153,12 +153,25 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
/**
- * How many threads to use for delivering messages. By default, use a single thread.
+ * How many threads to use for delivering messages.
*
* Note that requesting more than one thread means that the endpoint should be able to handle
* messages arriving from many threads at once, and all the things that entails (including
* messages being delivered to the endpoint out of order).
*/
- def threadCount(): Int = 1
+ def threadCount(): Int
+
+}
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages and
+ * is guaranteed to be thread-safe.
+ */
+private[spark] trait IsolatedThreadSafeRpcEndpoint extends IsolatedRpcEndpoint {
+
+ /**
+ * Limit the threadCount to 1 so that messages are guaranteed to be handled in a thread-safe way.
+ */
+ final def threadCount(): Int = 1
}
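A minimal sketch, illustrative and not part of the patch, of what opting into the new trait looks like for a hypothetical endpoint: `threadCount()` is now fixed to 1, so the handlers below can assume single-threaded, in-order message delivery. The trait is Spark-internal (`private[spark]`).

```scala
import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv}

// Hypothetical endpoint: a dedicated, single-threaded inbox delivers its messages.
class ExampleEndpoint(override val rpcEnv: RpcEnv) extends IsolatedThreadSafeRpcEndpoint {
  private var counter = 0 // safe without synchronization: only one delivery thread

  override def receive: PartialFunction[Any, Unit] = {
    case "tick" => counter += 1
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case "count" => context.reply(counter)
  }
}
```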
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 225dd1d75b..2d3cf2ebc4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -127,7 +127,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
ThreadUtils.newDaemonSingleThreadScheduledExecutor("cleanup-decommission-execs")
}
- class DriverEndpoint extends IsolatedRpcEndpoint with Logging {
+ class DriverEndpoint extends IsolatedThreadSafeRpcEndpoint with Logging {
override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv
diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
index ea028dfd11..287bf2165c 100644
--- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala
@@ -74,7 +74,7 @@ private[spark] class AppStatusListener(
private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]()
private val liveJobs = new HashMap[Int, LiveJob]()
private[spark] val liveExecutors = new HashMap[String, LiveExecutor]()
- private val deadExecutors = new HashMap[String, LiveExecutor]()
+ private[spark] val deadExecutors = new HashMap[String, LiveExecutor]()
private val liveTasks = new HashMap[Long, LiveTask]()
private val liveRDDs = new HashMap[Int, LiveRDD]()
private val pools = new HashMap[String, SchedulerPool]()
@@ -674,22 +674,30 @@ private[spark] class AppStatusListener(
delta
}.orNull
- val (completedDelta, failedDelta, killedDelta) = event.reason match {
+ // SPARK-41187: A `SparkListenerTaskEnd` with a `Resubmitted` reason, which is raised when an
+ // executor is lost, can lead to a negative `LiveStage.activeTasks` since there is no
+ // corresponding `SparkListenerTaskStart` event for it. A negative activeTasks makes the
+ // stage remain in the live stage list forever, as it can never meet the condition
+ // activeTasks == 0. This in turn causes the dead executor to never be cleaned up
+ // if that live stage's submissionTime is less than the dead executor's removeTime.
+ val (completedDelta, failedDelta, killedDelta, activeDelta) = event.reason match {
case Success =>
- (1, 0, 0)
+ (1, 0, 0, 1)
case _: TaskKilled =>
- (0, 0, 1)
+ (0, 0, 1, 1)
case _: TaskCommitDenied =>
- (0, 0, 1)
+ (0, 0, 1, 1)
+ case _ @ Resubmitted =>
+ (0, 1, 0, 0)
case _ =>
- (0, 1, 0)
+ (0, 1, 0, 1)
}
Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage =>
if (metricsDelta != null) {
stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta)
}
- stage.activeTasks -= 1
+ stage.activeTasks -= activeDelta
stage.completedTasks += completedDelta
if (completedDelta > 0) {
stage.completedIndices.add(event.taskInfo.index)
@@ -699,7 +707,7 @@ private[spark] class AppStatusListener(
if (killedDelta > 0) {
stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary)
}
- stage.activeTasksPerExecutor(event.taskInfo.executorId) -= 1
+ stage.activeTasksPerExecutor(event.taskInfo.executorId) -= activeDelta
stage.peakExecutorMetrics.compareAndUpdatePeakValues(event.taskExecutorMetrics)
stage.executorSummary(event.taskInfo.executorId).peakExecutorMetrics
@@ -718,7 +726,7 @@ private[spark] class AppStatusListener(
// Store both stage ID and task index in a single long variable for tracking at job level.
val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index
stage.jobs.foreach { job =>
- job.activeTasks -= 1
+ job.activeTasks -= activeDelta
job.completedTasks += completedDelta
if (completedDelta > 0) {
job.completedIndices.add(taskIndex)
@@ -774,7 +782,7 @@ private[spark] class AppStatusListener(
}
liveExecutors.get(event.taskInfo.executorId).foreach { exec =>
- exec.activeTasks -= 1
+ exec.activeTasks -= activeDelta
exec.completedTasks += completedDelta
exec.failedTasks += failedDelta
exec.totalDuration += event.taskInfo.duration
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index d5fde96b14..1067ee1556 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -637,9 +637,14 @@ private[spark] class BlockManager(
def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
logInfo(s"BlockManager $blockManagerId re-registering with master")
- master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString, maxOnHeapMemory,
- maxOffHeapMemory, storageEndpoint)
- reportAllBlocks()
+ val id = master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString,
+ maxOnHeapMemory, maxOffHeapMemory, storageEndpoint, isReRegister = true)
+ if (id.executorId != BlockManagerId.INVALID_EXECUTOR_ID) {
+ reportAllBlocks()
+ } else {
+ logError("Exiting executor due to block manager re-registration failure")
+ System.exit(-1)
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index c6a4457d8f..12e416bbb3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -147,4 +147,6 @@ private[spark] object BlockManagerId {
}
private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"
+
+ private[spark] val INVALID_EXECUTOR_ID = "invalid"
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 40008e6afb..0ee3dc249d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -74,11 +74,25 @@ class BlockManagerMaster(
localDirs: Array[String],
maxOnHeapMemSize: Long,
maxOffHeapMemSize: Long,
- storageEndpoint: RpcEndpointRef): BlockManagerId = {
+ storageEndpoint: RpcEndpointRef,
+ isReRegister: Boolean = false): BlockManagerId = {
logInfo(s"Registering BlockManager $id")
val updatedId = driverEndpoint.askSync[BlockManagerId](
- RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint))
- logInfo(s"Registered BlockManager $updatedId")
+ RegisterBlockManager(
+ id,
+ localDirs,
+ maxOnHeapMemSize,
+ maxOffHeapMemSize,
+ storageEndpoint,
+ isReRegister
+ )
+ )
+ if (updatedId.executorId == BlockManagerId.INVALID_EXECUTOR_ID) {
+ assert(isReRegister, "Got invalid executor id from non re-register case")
+ logInfo(s"Re-register BlockManager $id failed")
+ } else {
+ logInfo(s"Registered BlockManager $updatedId")
+ }
updatedId
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index adeb507941..d30272c51b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -33,7 +33,7 @@ import org.apache.spark.{MapOutputTrackerMaster, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.network.shuffle.ExternalBlockStoreClient
-import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
import org.apache.spark.shuffle.ShuffleManager
@@ -41,8 +41,8 @@ import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
/**
- * BlockManagerMasterEndpoint is an [[IsolatedRpcEndpoint]] on the master node to track statuses
- * of all the storage endpoints' block managers.
+ * BlockManagerMasterEndpoint is an [[IsolatedThreadSafeRpcEndpoint]] on the master node to
+ * track statuses of all the storage endpoints' block managers.
*/
private[spark]
class BlockManagerMasterEndpoint(
@@ -55,7 +55,7 @@ class BlockManagerMasterEndpoint(
mapOutputTracker: MapOutputTrackerMaster,
shuffleManager: ShuffleManager,
isDriver: Boolean)
- extends IsolatedRpcEndpoint with Logging {
+ extends IsolatedThreadSafeRpcEndpoint with Logging {
// Mapping from executor id to the block manager's local disk directories.
private val executorIdToLocalDirs =
@@ -117,8 +117,10 @@ class BlockManagerMasterEndpoint(
RpcUtils.makeDriverRef(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, conf, rpcEnv)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint) =>
- context.reply(register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint))
+ case RegisterBlockManager(
+ id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint, isReRegister) =>
+ context.reply(
+ register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint, isReRegister))
case _updateBlockInfo @
UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
@@ -572,7 +574,8 @@ class BlockManagerMasterEndpoint(
localDirs: Array[String],
maxOnHeapMemSize: Long,
maxOffHeapMemSize: Long,
- storageEndpoint: RpcEndpointRef): BlockManagerId = {
+ storageEndpoint: RpcEndpointRef,
+ isReRegister: Boolean): BlockManagerId = {
// the dummy id is not expected to contain the topology information.
// we get that info here and respond back with a more fleshed out block manager id
val id = BlockManagerId(
@@ -583,7 +586,12 @@ class BlockManagerMasterEndpoint(
val time = System.currentTimeMillis()
executorIdToLocalDirs.put(id.executorId, localDirs)
- if (!blockManagerInfo.contains(id)) {
+ // SPARK-41360: For the block manager re-registration, we should only allow it when
+ // the executor is recognized as active by the scheduler backend. Otherwise, this kind
+ // of re-registration from the terminating/stopped executor is meaningless and harmful.
+ lazy val isExecutorAlive =
+ driverEndpoint.askSync[Boolean](CoarseGrainedClusterMessages.IsExecutorAlive(id.executorId))
+ if (!blockManagerInfo.contains(id) && (!isReRegister || isExecutorAlive)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(oldId) =>
// A block manager of the same executor already exists, so remove it (assumed dead)
@@ -616,10 +624,29 @@ class BlockManagerMasterEndpoint(
if (pushBasedShuffleEnabled) {
addMergerLocation(id)
}
+ listenerBus.post(SparkListenerBlockManagerAdded(time, id,
+ maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
}
- listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize,
- Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
- id
+ val updatedId = if (isReRegister && !isExecutorAlive) {
+ assert(!blockManagerInfo.contains(id),
+ "BlockManager re-registration shouldn't succeed when the executor is lost")
+
+ logInfo(s"BlockManager ($id) re-registration is rejected since " +
+ s"the executor (${id.executorId}) has been lost")
+
+ // Use "invalid" as the returned executor id to indicate to the block manager that
+ // re-registration failed. It's a bit hacky but fine since the returned block
+ // manager id won't be accessed in the case of re-registration. And we'll use
+ // this "invalid" executor id to print better logs and avoid reporting blocks.
+ BlockManagerId(
+ BlockManagerId.INVALID_EXECUTOR_ID,
+ id.host,
+ id.port,
+ id.topologyInfo)
+ } else {
+ id
+ }
+ updatedId
}
private def updateShuffleBlockInfo(blockId: BlockId, blockManagerId: BlockManagerId)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index afe416a55e..e047b61fcb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -63,7 +63,8 @@ private[spark] object BlockManagerMessages {
localDirs: Array[String],
maxOnHeapMemSize: Long,
maxOffHeapMemSize: Long,
- sender: RpcEndpointRef)
+ sender: RpcEndpointRef,
+ isReRegister: Boolean)
extends ToBlockManagerMaster
case class UpdateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
index 54a72568b1..71c7a4de4c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
@@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future}
import org.apache.spark.{MapOutputTracker, SparkEnv}
import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv}
+import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -34,7 +34,7 @@ class BlockManagerStorageEndpoint(
override val rpcEnv: RpcEnv,
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
- extends IsolatedRpcEndpoint with Logging {
+ extends IsolatedThreadSafeRpcEndpoint with Logging {
private val asyncThreadPool =
ThreadUtils.newDaemonCachedThreadPool("block-manager-storage-async-thread-pool", 100)
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 7a08de9c18..27198039fd 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -32,7 +32,7 @@ import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext}
import org.apache.logging.log4j.core.appender.AbstractAppender
import org.apache.logging.log4j.core.config.Property
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome}
-import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.internal.Logging
@@ -64,7 +64,7 @@ import org.apache.spark.util.{AccumulatorContext, Utils}
* }
*/
abstract class SparkFunSuite
- extends AnyFunSuite
+ extends AnyFunSuite // scalastyle:ignore funsuite
with BeforeAndAfterAll
with BeforeAndAfterEach
with ThreadAudit
diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
index c15ae9504c..64703b0b04 100644
--- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -147,6 +147,18 @@ class SparkThrowableSuite extends SparkFunSuite {
assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap)
}
+ test("Error class names should contain only capital letters, numbers and underscores") {
+ val allowedChars = "[A-Z0-9_]*"
+ errorReader.errorInfoMap.foreach { e =>
+ assert(e._1.matches(allowedChars), s"Error class: ${e._1} is invalid")
+ e._2.subClass.map { s =>
+ s.keys.foreach { k =>
+ assert(k.matches(allowedChars), s"Error sub-class: $k is invalid")
+ }
+ }
+ }
+ }
+
test("Check if error class is missing") {
val ex1 = intercept[SparkException] {
getMessage("", Map.empty[String, String])
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 6bd3a49576..76311d0ab1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -486,6 +486,41 @@ class SparkSubmitSuite
conf.get("spark.kubernetes.driver.container.image") should be ("bar")
}
+ test("SPARK-33782: handles k8s files download to current directory") {
+ val clArgs = Seq(
+ "--deploy-mode", "client",
+ "--proxy-user", "test.user",
+ "--master", "k8s://host:port",
+ "--executor-memory", "5g",
+ "--class", "org.SomeClass",
+ "--driver-memory", "4g",
+ "--conf", "spark.kubernetes.namespace=spark",
+ "--conf", "spark.kubernetes.driver.container.image=bar",
+ "--conf", "spark.kubernetes.submitInDriver=true",
+ "--files", "src/test/resources/test_metrics_config.properties",
+ "--py-files", "src/test/resources/test_metrics_system.properties",
+ "--archives", "src/test/resources/log4j2.properties",
+ "--jars", "src/test/resources/TestUDTF.jar",
+ "/home/thejar.jar",
+ "arg1")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs)
+ conf.get("spark.master") should be ("k8s://https://host:port")
+ conf.get("spark.executor.memory") should be ("5g")
+ conf.get("spark.driver.memory") should be ("4g")
+ conf.get("spark.kubernetes.namespace") should be ("spark")
+ conf.get("spark.kubernetes.driver.container.image") should be ("bar")
+
+ Files.exists(Paths.get("test_metrics_config.properties")) should be (true)
+ Files.exists(Paths.get("test_metrics_system.properties")) should be (true)
+ Files.exists(Paths.get("log4j2.properties")) should be (true)
+ Files.exists(Paths.get("TestUDTF.jar")) should be (true)
+ Files.delete(Paths.get("test_metrics_config.properties"))
+ Files.delete(Paths.get("test_metrics_system.properties"))
+ Files.delete(Paths.get("log4j2.properties"))
+ Files.delete(Paths.get("TestUDTF.jar"))
+ }
+
/**
* Helper function for testing main class resolution on remote JAR files.
*
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index c70dde79b3..6e5eb77322 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -962,7 +962,8 @@ abstract class RpcEnvSuite extends SparkFunSuite {
val singleThreadedEnv = createRpcEnv(
new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0)
try {
- val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint {
+ val blockingEndpoint = singleThreadedEnv
+ .setupEndpoint("blocking", new IsolatedThreadSafeRpcEndpoint {
override val rpcEnv: RpcEnv = singleThreadedEnv
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
index 24a8a6844f..5d0c25aa86 100644
--- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala
@@ -1849,6 +1849,68 @@ abstract class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter
checkInfoPopulated(listener, logUrlMap, processId)
}
+ test("SPARK-41187: Stage should be removed from liveStages to avoid deadExecutors accumulated") {
+
+ val listener = new AppStatusListener(store, conf, true)
+
+ listener.onExecutorAdded(createExecutorAddedEvent(1))
+ listener.onExecutorAdded(createExecutorAddedEvent(2))
+ val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details",
+ resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)
+ listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null))
+
+ time += 1
+ stage.submissionTime = Some(time)
+ listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties()))
+
+ val tasks = createTasks(2, Array("1", "2"))
+ tasks.foreach { task =>
+ listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task))
+ }
+
+ time += 1
+ tasks(0).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType",
+ Success, tasks(0), new ExecutorMetrics, null))
+
+ // executor lost, success task will be resubmitted
+ time += 1
+ listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType",
+ Resubmitted, tasks(0), new ExecutorMetrics, null))
+
+ // executor lost, running task will be failed and rerun
+ time += 1
+ tasks(1).markFinished(TaskState.FAILED, time)
+ listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType",
+ ExecutorLostFailure("1", true, Some("Lost executor")), tasks(1), new ExecutorMetrics,
+ null))
+
+ tasks.foreach { task =>
+ listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task))
+ }
+
+ time += 1
+ tasks(0).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType",
+ Success, tasks(0), new ExecutorMetrics, null))
+
+ time += 1
+ tasks(1).markFinished(TaskState.FINISHED, time)
+ listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType",
+ Success, tasks(1), new ExecutorMetrics, null))
+
+ listener.onStageCompleted(SparkListenerStageCompleted(stage))
+ time += 1
+ listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded ))
+
+ time += 1
+ listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "1", "Test"))
+ time += 1
+ listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "2", "Test"))
+
+ assert(listener.deadExecutors.size === 0)
+ }
+
private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber)
private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index c8914761b9..842b66193f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -295,7 +295,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
eventually(timeout(5.seconds)) {
// make sure both bm1 and bm2 are registered at driver side BlockManagerMaster
verify(master, times(2))
- .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any())
+ .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any())
assert(driverEndpoint.askSync[Boolean](
CoarseGrainedClusterMessages.IsExecutorAlive(bm1Id.executorId)))
assert(driverEndpoint.askSync[Boolean](
@@ -361,6 +361,44 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
master.removeShuffle(0, true)
}
+ test("SPARK-41360: Avoid block manager re-registration if the executor has been lost") {
+ // Set up a DriverEndpoint which always returns isExecutorAlive=false
+ rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ENDPOINT_NAME,
+ new RpcEndpoint {
+ override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case CoarseGrainedClusterMessages.RegisterExecutor(executorId, _, _, _, _, _, _, _) =>
+ context.reply(true)
+ case CoarseGrainedClusterMessages.IsExecutorAlive(executorId) =>
+ // always return false
+ context.reply(false)
+ }
+ }
+ )
+
+ // Set up a block manager endpoint and endpoint reference
+ val bmRef = rpcEnv.setupEndpoint(s"bm-0", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv
+
+ private def reply[T](context: RpcCallContext, response: T): Unit = {
+ context.reply(response)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RemoveRdd(_) => reply(context, 1)
+ case RemoveBroadcast(_, _) => reply(context, 1)
+ case RemoveShuffle(_) => reply(context, true)
+ }
+ })
+ val bmId = BlockManagerId(s"exec-0", "localhost", 1234, None)
+ // Register the block manager with isReRegister = true
+ val updatedId = master.registerBlockManager(
+ bmId, Array.empty, 2000, 0, bmRef, isReRegister = true)
+ // The re-registration should fail since the executor is considered as dead by DriverEndpoint
+ assert(updatedId.executorId === BlockManagerId.INVALID_EXECUTOR_ID)
+ }
+
test("StorageLevel object caching") {
val level1 = StorageLevel(false, false, false, 3)
// this should return the same object as level1
@@ -669,6 +707,22 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
+ // Set up a DriverEndpoint which simulates the executor is alive (required by SPARK-41360)
+ rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ENDPOINT_NAME,
+ new RpcEndpoint {
+ override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case CoarseGrainedClusterMessages.IsExecutorAlive(executorId) =>
+ if (executorId == store.blockManagerId.executorId) {
+ context.reply(true)
+ } else {
+ context.reply(false)
+ }
+ }
+ }
+ )
+
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
assert(master.getLocations("a1").size > 0, "master was not told about a1")
@@ -2207,7 +2261,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
}.getMessage
assert(e.contains("TimeoutException"))
verify(master, times(0))
- .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any())
+ .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any())
server.close()
}
}
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3
index ad7a8a1a4c..ae7cc9d592 100644
--- a/dev/deps/spark-deps-hadoop-2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-2-hive-2.3
@@ -101,12 +101,11 @@ hive-shims-common/2.3.9//hive-shims-common-2.3.9.jar
hive-shims-scheduler/2.3.9//hive-shims-scheduler-2.3.9.jar
hive-shims/2.3.9//hive-shims-2.3.9.jar
hive-storage-api/2.7.3//hive-storage-api-2.7.3.jar
-hive-vector-code-gen/2.3.9//hive-vector-code-gen-2.3.9.jar
hk2-api/2.6.1//hk2-api-2.6.1.jar
hk2-locator/2.6.1//hk2-locator-2.6.1.jar
hk2-utils/2.6.1//hk2-utils-2.6.1.jar
htrace-core/3.1.0-incubating//htrace-core-3.1.0-incubating.jar
-httpclient/4.5.13//httpclient-4.5.13.jar
+httpclient/4.5.14//httpclient-4.5.14.jar
httpcore/4.4.14//httpcore-4.4.14.jar
istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar
ivy/2.5.1//ivy-2.5.1.jar
@@ -261,7 +260,6 @@ threeten-extra/1.7.1//threeten-extra-1.7.1.jar
tink/1.7.0//tink-1.7.0.jar
transaction-api/1.1//transaction-api-1.1.jar
univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar
-velocity/1.5//velocity-1.5.jar
xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar
xercesImpl/2.12.2//xercesImpl-2.12.2.jar
xml-apis/1.4.01//xml-apis-1.4.01.jar
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index cac2e9f305..f70abedd34 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -89,11 +89,10 @@ hive-shims-common/2.3.9//hive-shims-common-2.3.9.jar
hive-shims-scheduler/2.3.9//hive-shims-scheduler-2.3.9.jar
hive-shims/2.3.9//hive-shims-2.3.9.jar
hive-storage-api/2.7.3//hive-storage-api-2.7.3.jar
-hive-vector-code-gen/2.3.9//hive-vector-code-gen-2.3.9.jar
hk2-api/2.6.1//hk2-api-2.6.1.jar
hk2-locator/2.6.1//hk2-locator-2.6.1.jar
hk2-utils/2.6.1//hk2-utils-2.6.1.jar
-httpclient/4.5.13//httpclient-4.5.13.jar
+httpclient/4.5.14//httpclient-4.5.14.jar
httpcore/4.4.14//httpcore-4.4.14.jar
ini4j/0.5.4//ini4j-0.5.4.jar
istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar
@@ -248,7 +247,6 @@ threeten-extra/1.7.1//threeten-extra-1.7.1.jar
tink/1.7.0//tink-1.7.0.jar
transaction-api/1.1//transaction-api-1.1.jar
univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar
-velocity/1.5//velocity-1.5.jar
wildfly-openssl/1.0.7.Final//wildfly-openssl-1.0.7.Final.jar
xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar
xz/1.9//xz-1.9.jar
diff --git a/dev/lint-scala b/dev/lint-scala
index ea3b98464b..48ecf57ef4 100755
--- a/dev/lint-scala
+++ b/dev/lint-scala
@@ -29,14 +29,14 @@ ERRORS=$(./build/mvn \
-Dscalafmt.skip=false \
-Dscalafmt.validateOnly=true \
-Dscalafmt.changedOnly=false \
- -pl connector/connect \
+ -pl connector/connect/server \
2>&1 | grep -e "^Requires formatting" \
)
if test ! -z "$ERRORS"; then
echo -e "The scalafmt check failed on connector/connect at following occurrences:\n\n$ERRORS\n"
echo "Before submitting your change, please make sure to format your code using the following command:"
- echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect"
+ echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect/server"
exit 1
else
echo -e "Scalafmt checks passed."
diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py
index 37e023aaa6..a3df038b55 100755
--- a/dev/sparktestsupport/utils.py
+++ b/dev/sparktestsupport/utils.py
@@ -34,19 +34,22 @@ def determine_modules_for_files(filenames):
Given a list of filenames, return the set of modules that contain those files.
If a file is not associated with a more specific submodule, then this method will consider that
file to belong to the 'root' module. `.github` directory is counted only in GitHub Actions,
- and `appveyor.yml` is always ignored because this file is dedicated only to AppVeyor builds.
+ and `appveyor.yml` is always ignored because this file is dedicated only to AppVeyor builds,
+ and `README.md` is always ignored too.
>>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"]))
['pyspark-core', 'sql']
>>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])]
['root']
- >>> [x.name for x in determine_modules_for_files(["appveyor.yml"])]
+ >>> [x.name for x in determine_modules_for_files(["appveyor.yml", "sql/README.md"])]
[]
"""
changed_modules = set()
for filename in filenames:
if filename in ("appveyor.yml",):
continue
+ if filename.endswith("README.md"):
+ continue
if ("GITHUB_ACTIONS" not in os.environ) and filename.startswith(".github"):
continue
matched_at_least_one_module = False
diff --git a/dev/tox.ini b/dev/tox.ini
index f44cbe54dd..15c93832c2 100644
--- a/dev/tox.ini
+++ b/dev/tox.ini
@@ -36,7 +36,8 @@ per-file-ignores =
python/pyspark/resource/tests/*.py: F403,
python/pyspark/sql/tests/*.py: F403,
python/pyspark/streaming/tests/*.py: F403,
- python/pyspark/tests/*.py: F403
+ python/pyspark/tests/*.py: F403,
+ python/pyspark/testing/*: F401
exclude =
*/target/*,
docs/.local_ruby_bundle/,
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 3e1ec771da..9b115f1ad9 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -317,3 +317,21 @@ To build and run tests on IPv6-only environment, the following configurations ar
export MAVEN_OPTS="-Djava.net.preferIPv6Addresses=true"
export SBT_OPTS="-Djava.net.preferIPv6Addresses=true"
export SERIAL_SBT_TESTS=1
+
+### Building with user-defined `protoc`
+
+When the official `protoc` binaries cannot be used to build the `core` module in the compilation environment, for example when compiling the `core` module on CentOS 6 or CentOS 7, where the default `glibc` version is lower than 2.14, you can compile and test by specifying a user-defined `protoc` binary as follows:
+
+```bash
+export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe
+./build/mvn -Puser-defined-protoc -DskipDefaultProtoc clean package
+```
+
+or
+
+```bash
+export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe
+./build/sbt -Puser-defined-protoc clean package
+```
+
+A user-defined `protoc` binary can be produced in the user's compilation environment by building it from source; for the build steps, please refer to [protobuf](https://github.com/protocolbuffers/protobuf).
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
index 95be32a819..711e828bd8 100644
--- a/docs/mllib-isotonic-regression.md
+++ b/docs/mllib-isotonic-regression.md
@@ -43,7 +43,14 @@ best fitting the original data points.
which uses an approach to
[parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10).
The training input is an RDD of tuples of three double values that represent
-label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one
+label, feature and weight in this order. If there are multiple tuples with
+the same feature, these tuples are aggregated into a single tuple as follows:
+
+* Aggregated label is the weighted average of all labels.
+* Aggregated feature is the unique feature value.
+* Aggregated weight is the sum of all weights.
+
+Additionally, the IsotonicRegression algorithm has one
optional parameter called $isotonic$ defaulting to true.
This argument specifies if the isotonic regression is
isotonic (monotonically increasing) or antitonic (monotonically decreasing).
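A small worked example of the aggregation rule above (values are illustrative): two training tuples that share the feature value 2.0, say (label, feature, weight) = (1.0, 2.0, 1.0) and (3.0, 2.0, 3.0), are merged into the single tuple ((1.0 * 1.0 + 3.0 * 3.0) / (1.0 + 3.0), 2.0, 1.0 + 3.0) = (2.5, 2.0, 4.0) before the algorithm runs.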
@@ -53,17 +60,12 @@ labels for both known and unknown features. The result of isotonic regression
is treated as piecewise linear function. The rules for prediction therefore are:
* If the prediction input exactly matches a training feature
- then associated prediction is returned. In case there are multiple predictions with the same
- feature then one of them is returned. Which one is undefined
- (same as java.util.Arrays.binarySearch).
+ then associated prediction is returned.
* If the prediction input is lower or higher than all training features
then prediction with lowest or highest feature is returned respectively.
- In case there are multiple predictions with the same feature
- then the lowest or highest is returned respectively.
* If the prediction input falls between two training features then prediction is treated
as piecewise linear function and interpolated value is calculated from the
- predictions of the two closest features. In case there are multiple values
- with the same feature then the same rules as in previous point are used.
+ predictions of the two closest features.
### Examples
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 08580a77eb..21c81c508e 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -204,6 +204,26 @@ When this property is set, it's highly recommended to make it unique across all
Use the exact prefix `spark.kubernetes.authenticate` for Kubernetes authentication parameters in client mode.
+## IPv4 and IPv6
+
+Starting with 3.4.0, Spark additionally supports IPv6-only environments via the
+[IPv4/IPv6 dual-stack network](https://kubernetes.io/docs/concepts/services-networking/dual-stack/)
+feature, which enables the allocation of both IPv4 and IPv6 addresses to Pods and Services.
+Depending on the K8s cluster capability, `spark.kubernetes.driver.service.ipFamilyPolicy` and
+`spark.kubernetes.driver.service.ipFamilies` can be set to one of `SingleStack`, `PreferDualStack`,
+and `RequireDualStack`, and to one of `IPv4`, `IPv6`, `IPv4,IPv6`, and `IPv6,IPv4`, respectively.
+By default, Spark uses `spark.kubernetes.driver.service.ipFamilyPolicy=SingleStack` and
+`spark.kubernetes.driver.service.ipFamilies=IPv4`.
+
+To use only `IPv6`, you can submit your jobs with the following.
+```bash
+...
+ --conf spark.kubernetes.driver.service.ipFamilies=IPv6 \
+```
+
+In a `DualStack` environment, you may additionally need `java.net.preferIPv6Addresses=true` for
+the JVM and `SPARK_PREFER_IPV6=true` for Python in order to use `IPv6`.
+
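A hedged illustration of one way to pass these settings; `extraJavaOptions` and `spark.executorEnv.*` are standard Spark configuration mechanisms, but whether your cluster needs them depends on its network setup.

```bash
# Illustrative only: prefer IPv6 addresses in a DualStack cluster.
--conf spark.driver.extraJavaOptions="-Djava.net.preferIPv6Addresses=true" \
--conf spark.executor.extraJavaOptions="-Djava.net.preferIPv6Addresses=true" \
--conf spark.executorEnv.SPARK_PREFER_IPV6=true \
```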
## Dependency Management
If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to
@@ -1418,7 +1438,8 @@ See the [configuration page](configuration.html) for information on Spark config
spark.kubernetes.driver.service.ipFamilyPolicy |
SingleStack |
- K8s IP Family Policy for Driver Service.
+ K8s IP Family Policy for Driver Service. Valid values are
+ SingleStack , PreferDualStack , and RequireDualStack .
|
3.4.0 |
@@ -1426,7 +1447,8 @@ See the [configuration page](configuration.html) for information on Spark config
spark.kubernetes.driver.service.ipFamilies |
IPv4 |
- A list of IP families for K8s Driver Service.
+ A list of IP families for K8s Driver Service. Valid values are
+ IPv4 and IPv6 .
|
3.4.0 |
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
index 649f9816e6..fbf0dc9c35 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.mllib.regression
import java.io.Serializable
@@ -272,8 +271,8 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
* @param input RDD of tuples (label, feature, weight) where label is dependent variable
* for which we calculate isotonic regression, feature is independent variable
* and weight represents number of measures with default 1.
- * If multiple labels share the same feature value then they are ordered before
- * the algorithm is executed.
+ * If multiple labels share the same feature value then they are aggregated using
+ * the weighted average before the algorithm is executed.
* @return Isotonic regression model.
*/
@Since("1.3.0")
@@ -298,8 +297,8 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
* @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable
* for which we calculate isotonic regression, feature is independent variable
* and weight represents number of measures with default 1.
- * If multiple labels share the same feature value then they are ordered before
- * the algorithm is executed.
+ * If multiple labels share the same feature value then they are aggregated using
+ * the weighted average before the algorithm is executed.
* @return Isotonic regression model.
*/
@Since("1.3.0")
@@ -307,6 +306,58 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]])
}
+ /**
+ * Aggregates points of duplicate feature values into a single point using as label the weighted
+ * average of the labels of the points with duplicate feature values. All points for a unique
+ * feature value are aggregated as:
+ *
+ * - Aggregated label is the weighted average of all labels.
+ * - Aggregated feature is the unique feature value.
+ * - Aggregated weight is the sum of all weights.
+ *
+ * @param input Input data of tuples (label, feature, weight). Weights must be non-negative.
+ * @return Points with unique feature values.
+ */
+ private[regression] def makeUnique(
+ input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
+
+ val cleanInput = input.filter { case (y, x, weight) =>
+ require(
+ weight >= 0.0,
+ s"Negative weight at point ($y, $x, $weight). Weights must be non-negative")
+ weight > 0
+ }
+
+ if (cleanInput.length <= 1) {
+ cleanInput
+ } else {
+ val pointsAccumulator = new IsotonicRegression.PointsAccumulator
+
+ // Go through input points, merging all points with equal feature values into a single point.
+ // Equality of features is defined by shouldAccumulate method. The label of the accumulated
+ // points is the weighted average of the labels of all points of equal feature value.
+
+ // Initialize with first point
+ pointsAccumulator := cleanInput.head
+ // Accumulate the rest
+ cleanInput.tail.foreach { case point @ (_, feature, _) =>
+ if (pointsAccumulator.shouldAccumulate(feature)) {
+ // Still on a duplicate feature, accumulate
+ pointsAccumulator += point
+ } else {
+ // A new unique feature encountered:
+ // - append the last accumulated point to unique features output
+ pointsAccumulator.appendToOutput()
+ // - and reset
+ pointsAccumulator := point
+ }
+ }
+ // Append the last accumulated point to unique features output
+ pointsAccumulator.appendToOutput()
+ pointsAccumulator.getOutput
+ }
+ }
+
/**
* Performs a pool adjacent violators algorithm (PAV). Implements the algorithm originally
* described in [1], using the formulation from [2, 3]. Uses an array to keep track of start
@@ -322,35 +373,27 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
* functions subject to simple chain constraints." SIAM Journal on Optimization 10.3 (2000):
* 658-672.
*
- * @param input Input data of tuples (label, feature, weight). Weights must
- be non-negative.
+ * @param cleanUniqueInput Input data of tuples (label, feature, weight). Features must be
+ * unique and weights must be non-negative.
* @return Result tuples (label, feature, weight) where labels were updated
* to form a monotone sequence as per isotonic regression definition.
*/
private def poolAdjacentViolators(
- input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
+ cleanUniqueInput: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
- val cleanInput = input.filter{ case (y, x, weight) =>
- require(
- weight >= 0.0,
- s"Negative weight at point ($y, $x, $weight). Weights must be non-negative"
- )
- weight > 0
- }
-
- if (cleanInput.isEmpty) {
+ if (cleanUniqueInput.isEmpty) {
return Array.empty
}
// Keeps track of the start and end indices of the blocks. if [i, j] is a valid block from
// cleanInput(i) to cleanInput(j) (inclusive), then blockBounds(i) = j and blockBounds(j) = i
// Initially, each data point is its own block.
- val blockBounds = Array.range(0, cleanInput.length)
+ val blockBounds = Array.range(0, cleanUniqueInput.length)
// Keep track of the sum of weights and sum of weight * y for each block. weights(start)
// gives the values for the block. Entries that are not at the start of a block
// are meaningless.
- val weights: Array[(Double, Double)] = cleanInput.map { case (y, _, weight) =>
+ val weights: Array[(Double, Double)] = cleanUniqueInput.map { case (y, _, weight) =>
(weight, weight * y)
}
@@ -392,10 +435,10 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
// Merge on >= instead of > because it eliminates adjacent blocks with the same average, and we
// want to compress our output as much as possible. Both give correct results.
var i = 0
- while (nextBlock(i) < cleanInput.length) {
+ while (nextBlock(i) < cleanUniqueInput.length) {
if (average(i) >= average(nextBlock(i))) {
merge(i, nextBlock(i))
- while((i > 0) && (average(prevBlock(i)) >= average(i))) {
+ while ((i > 0) && (average(prevBlock(i)) >= average(i))) {
i = merge(prevBlock(i), i)
}
} else {
@@ -406,15 +449,15 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
// construct the output by walking through the blocks in order
val output = ArrayBuffer.empty[(Double, Double, Double)]
i = 0
- while (i < cleanInput.length) {
+ while (i < cleanUniqueInput.length) {
// If block size is > 1, a point at the start and end of the block,
// each receiving half the weight. Otherwise, a single point with
// all the weight.
- if (cleanInput(blockEnd(i))._2 > cleanInput(i)._2) {
- output += ((average(i), cleanInput(i)._2, weights(i)._1 / 2))
- output += ((average(i), cleanInput(blockEnd(i))._2, weights(i)._1 / 2))
+ if (cleanUniqueInput(blockEnd(i))._2 > cleanUniqueInput(i)._2) {
+ output += ((average(i), cleanUniqueInput(i)._2, weights(i)._1 / 2))
+ output += ((average(i), cleanUniqueInput(blockEnd(i))._2, weights(i)._1 / 2))
} else {
- output += ((average(i), cleanInput(i)._2, weights(i)._1))
+ output += ((average(i), cleanUniqueInput(i)._2, weights(i)._1))
}
i = nextBlock(i)
}
@@ -434,12 +477,58 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali
input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
val keyedInput = input.keyBy(_._2)
val parallelStepResult = keyedInput
+ // Points with same or adjacent features must collocate within the same partition.
.partitionBy(new RangePartitioner(keyedInput.getNumPartitions, keyedInput))
.values
- .mapPartitions(p => Iterator(p.toArray.sortBy(x => (x._2, x._1))))
+ // Lexicographically sort points by features.
+ .mapPartitions(p => Iterator(p.toArray.sortBy(_._2)))
+ // Aggregate points with equal features into a single point.
+ .map(makeUnique)
.flatMap(poolAdjacentViolators)
.collect()
- .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering.
+ // Sort again because collect() doesn't promise ordering.
+ .sortBy(_._2)
poolAdjacentViolators(parallelStepResult)
}
}
+
+object IsotonicRegression {
+ /**
+ * Utility class that holds a buffer of all points with unique features so far and performs
+ * weighted sum accumulation of points. Hides these details for better readability of the
+ * main algorithm.
+ */
+ class PointsAccumulator {
+ private val output = ArrayBuffer[(Double, Double, Double)]()
+ private var (currentLabel: Double, currentFeature: Double, currentWeight: Double) =
+ (0d, 0d, 0d)
+
+ /** Whether or not this feature exactly equals the current accumulated feature. */
+ @inline def shouldAccumulate(feature: Double): Boolean = currentFeature == feature
+
+ /** Resets the current value of the point accumulator using the provided point. */
+ @inline def :=(point: (Double, Double, Double)): Unit = {
+ val (label, feature, weight) = point
+ currentLabel = label * weight
+ currentFeature = feature
+ currentWeight = weight
+ }
+
+ /** Accumulates the provided point into the current value of the point accumulator. */
+ @inline def +=(point: (Double, Double, Double)): Unit = {
+ val (label, _, weight) = point
+ currentLabel += label * weight
+ currentWeight += weight
+ }
+
+ /** Appends the current value of the point accumulator to the output. */
+ @inline def appendToOutput(): Unit =
+ output += ((
+ currentLabel / currentWeight,
+ currentFeature,
+ currentWeight))
+
+ /** Returns all accumulated points so far. */
+ @inline def getOutput: Array[(Double, Double, Double)] = output.toArray
+ }
+}
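A minimal sketch, illustrative and not part of the patch, of the contract the new `PointsAccumulator` provides to `makeUnique`: points arrive sorted by feature, points with an equal feature are folded into one weighted-average point, and each group is appended to the output exactly once.

```scala
// Hypothetical driver code exercising the accumulator added above; the numbers are made up.
val acc = new IsotonicRegression.PointsAccumulator
val p1 = (1.0, 2.0, 1.0) // (label, feature, weight)
val p2 = (3.0, 2.0, 3.0) // same feature as p1, so it is accumulated into the same group
val p3 = (5.0, 7.0, 1.0) // new feature, starts a new group

acc := p1
if (acc.shouldAccumulate(p2._2)) acc += p2
acc.appendToOutput() // emits (2.5, 2.0, 4.0): weighted-average label, summed weight
acc := p3
acc.appendToOutput() // emits (5.0, 7.0, 1.0)
val unique = acc.getOutput // Array((2.5, 2.0, 4.0), (5.0, 7.0, 1.0))
```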
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
index 8066900dfa..a206e922e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
@@ -24,6 +24,24 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
+/**
+ * Tests can be verified through the following python snippet:
+ *
+ * {{{
+ * from sklearn.isotonic import IsotonicRegression
+ *
+ * def test(x, y, x_test, isotonic=True):
+ * ir = IsotonicRegression(out_of_bounds='clip', increasing=isotonic).fit(x, y)
+ * y_test = ir.predict(x_test)
+ *
+ * def print_array(label, a):
+ * print(f"{label}: [{', '.join([str(i) for i in a])}]")
+ *
+ * print_array("boundaries", ir.X_thresholds_)
+ * print_array("predictions", ir.y_thresholds_)
+ * print_array("y_test", y_test)
+ * }}}
+ */
class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers {
private def round(d: Double) = {
@@ -44,8 +62,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
labels: Seq[Double],
weights: Seq[Double],
isotonic: Boolean): IsotonicRegressionModel = {
- val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache()
- new IsotonicRegression().setIsotonic(isotonic).run(trainRDD)
+ runIsotonicRegressionOnInput(generateIsotonicInput(labels, weights), isotonic)
}
private def runIsotonicRegression(
@@ -54,17 +71,37 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic)
}
+ private def runIsotonicRegression(
+ labels: Seq[Double],
+ features: Seq[Double],
+ weights: Seq[Double],
+ isotonic: Boolean): IsotonicRegressionModel = {
+ runIsotonicRegressionOnInput(
+ labels.indices.map(i => (labels(i), features(i), weights(i))),
+ isotonic)
+ }
+
+ private def runIsotonicRegressionOnInput(
+ input: Seq[(Double, Double, Double)],
+ isotonic: Boolean,
+ slices: Int = sc.defaultParallelism): IsotonicRegressionModel = {
+ val trainRDD = sc.parallelize(input, slices).cache()
+ new IsotonicRegression().setIsotonic(isotonic).run(trainRDD)
+ }
+
test("increasing isotonic regression") {
/*
The following result could be re-produced with sklearn.
- > from sklearn.isotonic import IsotonicRegression
- > x = range(9)
- > y = [1, 2, 3, 1, 6, 17, 16, 17, 18]
- > ir = IsotonicRegression(x, y)
- > print ir.predict(x)
+ > test(
+ > x = range(9),
+ > y = [1, 2, 3, 1, 6, 17, 16, 17, 18],
+ > x_test = range(9)
+ > )
- array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18. ])
+ boundaries: [0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
+ predictions: [1.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0]
+ y_test: [1.0, 2.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0]
*/
val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true)
@@ -142,9 +179,9 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
}
test("isotonic regression with unordered input") {
- val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache()
- val model = new IsotonicRegression().run(trainRDD)
+ val model =
+ runIsotonicRegressionOnInput(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, true, 2)
assert(model.predictions === Array(1, 2, 3, 4, 5))
}
@@ -159,7 +196,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true)
assert(model.boundaries === Array(0, 1, 2, 4))
- assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2))
+ assert(model.predictions.map(round) === Array(1, 2, 3.3 / 1.2, 3.3 / 1.2))
}
test("weighted isotonic regression with negative weights") {
@@ -176,16 +213,31 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
}
test("SPARK-16426 isotonic regression with duplicate features that produce NaNs") {
- val trainRDD = sc.parallelize(Seq[(Double, Double, Double)]((2, 1, 1), (1, 1, 1), (0, 2, 1),
- (1, 2, 1), (0.5, 3, 1), (0, 3, 1)),
- 2)
-
- val model = new IsotonicRegression().run(trainRDD)
+ val model = runIsotonicRegressionOnInput(
+ Seq((2, 1, 1), (1, 1, 1), (0, 2, 1), (1, 2, 1), (0.5, 3, 1), (0, 3, 1)),
+ true,
+ 2)
assert(model.boundaries === Array(1.0, 3.0))
assert(model.predictions === Array(0.75, 0.75))
}
+ test("SPARK-41008 isotonic regression with duplicate features differs from sklearn") {
+ val model = runIsotonicRegressionOnInput(
+ Seq((1, 0.6, 1), (0, 0.6, 1),
+ (0, 1.0 / 3, 1), (1, 1.0 / 3, 1), (0, 1.0 / 3, 1),
+ (1, 0.2, 1), (0, 0.2, 1), (0, 0.2, 1), (0, 0.2, 1)),
+ true,
+ 2)
+
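+ // Duplicate features are first merged into single points whose label is the weighted mean:
+ // feature 0.2 -> (1 + 0 + 0 + 0) / 4 = 0.25, feature 1/3 -> (0 + 1 + 0) / 3 = 1/3,
+ // feature 0.6 -> (1 + 0) / 2 = 0.5. That sequence is already non-decreasing, so PAV
+ // leaves it unchanged; per the test name, these expectations differ from sklearn's output.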
+ assert(model.boundaries === Array(0.2, 1.0 / 3, 0.6))
+ assert(model.predictions === Array(0.25, 1.0 / 3, 0.5))
+
+ assert(model.predict(0.6) === 0.5)
+ assert(model.predict(1.0 / 3) === 1.0 / 3)
+ assert(model.predict(0.2) === 0.25)
+ }
+
test("isotonic regression prediction") {
val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true)
@@ -194,32 +246,38 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
assert(model.predict(0.5) === 1.5)
assert(model.predict(0.75) === 1.75)
assert(model.predict(1) === 2)
- assert(model.predict(2) === 10d/3)
- assert(model.predict(9) === 10d/3)
+ assert(model.predict(2) === 10.0 / 3)
+ assert(model.predict(9) === 10.0 / 3)
}
test("isotonic regression prediction with duplicate features") {
- val trainRDD = sc.parallelize(
- Seq[(Double, Double, Double)](
- (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache()
- val model = new IsotonicRegression().run(trainRDD)
-
- assert(model.predict(0) === 1)
- assert(model.predict(1.5) === 2)
- assert(model.predict(2.5) === 4.5)
- assert(model.predict(4) === 6)
+ val model = runIsotonicRegressionOnInput(
+ Seq((2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)),
+ true,
+ 2)
+
+ assert(model.boundaries === Array(1.0, 2.0, 3.0))
+ assert(model.predictions === Array(1.5, 3.0, 5.5))
+
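+ // After merging duplicates: feature 1 -> (2 + 1) / 2 = 1.5, feature 2 -> (4 + 2) / 2 = 3.0,
+ // feature 3 -> (6 + 5) / 2 = 5.5. Predictions outside the boundaries are clipped and
+ // predictions between boundaries are linearly interpolated, e.g. predict(1.5) = 2.25.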
+ assert(model.predict(0) === 1.5)
+ assert(model.predict(1.5) === 2.25)
+ assert(model.predict(2.5) === 4.25)
+ assert(model.predict(4) === 5.5)
}
test("antitonic regression prediction with duplicate features") {
- val trainRDD = sc.parallelize(
- Seq[(Double, Double, Double)](
- (5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache()
- val model = new IsotonicRegression().setIsotonic(false).run(trainRDD)
-
- assert(model.predict(0) === 6)
- assert(model.predict(1.5) === 4.5)
- assert(model.predict(2.5) === 2)
- assert(model.predict(4) === 1)
+ val model = runIsotonicRegressionOnInput(
+ Seq((5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)),
+ false,
+ 2)
+
+ assert(model.boundaries === Array(1.0, 2.0, 3.0))
+ assert(model.predictions === Array(5.5, 3.0, 1.5))
+
+ assert(model.predict(0) === 5.5)
+ assert(model.predict(1.5) === 4.25)
+ assert(model.predict(2.5) === 2.25)
+ assert(model.predict(4) === 1.5)
}
test("isotonic regression RDD prediction") {
@@ -227,7 +285,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache()
val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2)
- assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3))
+ assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0 / 3, 10.0 / 3))
}
test("antitonic regression prediction") {
@@ -270,4 +328,63 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false)
}
}
+
+ test("makeUnique: handle duplicate features") {
+ val regressor = new IsotonicRegression()
+ import regressor.makeUnique
+
+ // Note: input must be lexicographically sorted by feature
+
+ // empty
+ assert(makeUnique(Array.empty) === Array.empty)
+
+ // single
+ assert(makeUnique(Array((1.0, 1.0, 1.0))) === Array((1.0, 1.0, 1.0)))
+
+ // two and duplicate
+ assert(makeUnique(Array((1.0, 1.0, 1.0), (1.0, 1.0, 1.0))) === Array((1.0, 1.0, 2.0)))
+
+ // two and unique
+ assert(
+ makeUnique(Array((1.0, 1.0, 1.0), (1.0, 2.0, 1.0))) ===
+ Array((1.0, 1.0, 1.0), (1.0, 2.0, 1.0)))
+
+ // generic with duplicates
+ assert(
+ makeUnique(
+ Array(
+ (10.0, 1.0, 1.0), (20.0, 1.0, 1.0),
+ (10.0, 2.0, 1.0), (20.0, 2.0, 1.0), (30.0, 2.0, 1.0),
+ (10.0, 3.0, 1.0)
+ )) === Array((15.0, 1.0, 2.0), (20.0, 2.0, 3.0), (10.0, 3.0, 1.0)))
+
+ // generic unique
+ assert(
+ makeUnique(Array((10.0, 1.0, 1.0), (10.0, 2.0, 1.0), (10.0, 3.0, 1.0))) === Array(
+ (10.0, 1.0, 1.0),
+ (10.0, 2.0, 1.0),
+ (10.0, 3.0, 1.0)))
+
+ // generic with duplicates and non-uniform weights
+ assert(
+ makeUnique(
+ Array(
+ (10.0, 1.0, 0.3), (20.0, 1.0, 0.7),
+ (10.0, 2.0, 0.3), (20.0, 2.0, 0.3), (30.0, 2.0, 0.4),
+ (10.0, 3.0, 1.0)
+ )) === Array(
+ (10.0 * 0.3 + 20.0 * 0.7, 1.0, 1.0),
+ (10.0 * 0.3 + 20.0 * 0.3 + 30.0 * 0.4, 2.0, 1.0),
+ (10.0, 3.0, 1.0)))
+
+ // makeUnique does not try to absorb tiny floating-point representation errors:
+ // immediately adjacent representable doubles already count as unique features
+ val adjacentDoubles = {
+ // i-th next representable double to 1.0 is java.lang.Double.longBitsToDouble(base + i)
+ val base = java.lang.Double.doubleToRawLongBits(1.0)
+ (0 until 10).map(i => java.lang.Double.longBitsToDouble(base + i))
+ .map((1.0, _, 1.0)).toArray
+ }
+ assert(makeUnique(adjacentDoubles) === adjacentDoubles)
+ }
}
diff --git a/pom.xml b/pom.xml
index b2e5979f46..da7c8eccfc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -123,7 +123,7 @@
2.5.0
- <protobuf.version>3.21.9</protobuf.version>
+ <protobuf.version>3.21.11</protobuf.version>
3.11.4
${hadoop.version}
3.6.3
@@ -161,7 +161,7 @@
0.12.8
hadoop3-2.2.7
- 4.5.13
+ 4.5.14
4.4.14
3.6.1
@@ -175,7 +175,7 @@
errors building different Hadoop versions.
See: SPARK-36547, SPARK-38394.
-->
- 4.7.2
+ 4.8.0
true
true
@@ -2042,6 +2042,10 @@
${hive.group}
hive-ant
+
+ ${hive.group}
+ hive-vector-code-gen
+
${hive.group}
@@ -3229,17 +3233,6 @@
org.spark-project.spark:unused
- org.eclipse.jetty:jetty-io
- org.eclipse.jetty:jetty-http
- org.eclipse.jetty:jetty-proxy
- org.eclipse.jetty:jetty-client
- org.eclipse.jetty:jetty-continuation
- org.eclipse.jetty:jetty-servlet
- org.eclipse.jetty:jetty-servlets
- org.eclipse.jetty:jetty-plus
- org.eclipse.jetty:jetty-security
- org.eclipse.jetty:jetty-util
- org.eclipse.jetty:jetty-server
com.google.guava:guava
org.jpmml:*
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index eed79d1f20..7ec4ef37a0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -123,7 +123,13 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"),
// [SPARK-41180][SQL] Reuse INVALID_SCHEMA instead of _LEGACY_ERROR_TEMP_1227
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.parseTypeWithFallback")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.parseTypeWithFallback"),
+
+ // [SPARK-41360][CORE] Avoid BlockManager re-registration if the executor has been lost
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.this"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.BlockManagerMessages$RegisterBlockManager$"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.apply")
)
// Default exclude rules
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e6a39714e6..556f8528ea 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -87,7 +87,7 @@ object BuildCommons {
// Google Protobuf version used for generating the protobuf.
// SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`.
- val protoVersion = "3.21.9"
+ val protoVersion = "3.21.11"
// GRPC version used for Spark Connect.
val gprcVersion = "1.47.0"
}
@@ -112,15 +112,13 @@ object SparkBuild extends PomBuild {
sys.props.put("test.jdwp.enabled", "true")
}
if (profiles.contains("user-defined-protoc")) {
- val connectProtocExecPath = Properties.envOrNone("CONNECT_PROTOC_EXEC_PATH")
+ val sparkProtocExecPath = Properties.envOrNone("SPARK_PROTOC_EXEC_PATH")
val connectPluginExecPath = Properties.envOrNone("CONNECT_PLUGIN_EXEC_PATH")
- val protobufProtocExecPath = Properties.envOrNone("PROTOBUF_PROTOC_EXEC_PATH")
- if (connectProtocExecPath.isDefined && connectPluginExecPath.isDefined) {
- sys.props.put("connect.protoc.executable.path", connectProtocExecPath.get)
- sys.props.put("connect.plugin.executable.path", connectPluginExecPath.get)
+ if (sparkProtocExecPath.isDefined) {
+ sys.props.put("spark.protoc.executable.path", sparkProtocExecPath.get)
}
- if (protobufProtocExecPath.isDefined) {
- sys.props.put("protobuf.protoc.executable.path", protobufProtocExecPath.get)
+ if (connectPluginExecPath.isDefined) {
+ sys.props.put("connect.plugin.executable.path", connectPluginExecPath.get)
}
}
profiles
@@ -644,7 +642,16 @@ object Core {
val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties"
Seq(propsFile)
}.taskValue
- )
+ ) ++ {
+ val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
+ if (sparkProtocExecPath.isDefined) {
+ Seq(
+ PB.protocExecutable := file(sparkProtocExecPath.get)
+ )
+ } else {
+ Seq.empty
+ }
+ }
}
object SparkConnectCommon {
@@ -709,15 +716,15 @@ object SparkConnectCommon {
case _ => MergeStrategy.first
}
) ++ {
- val connectProtocExecPath = sys.props.get("connect.protoc.executable.path")
+ val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
val connectPluginExecPath = sys.props.get("connect.plugin.executable.path")
- if (connectProtocExecPath.isDefined && connectPluginExecPath.isDefined) {
+ if (sparkProtocExecPath.isDefined && connectPluginExecPath.isDefined) {
Seq(
(Compile / PB.targets) := Seq(
PB.gens.java -> (Compile / sourceManaged).value,
PB.gens.plugin(name = "grpc-java", path = connectPluginExecPath.get) -> (Compile / sourceManaged).value
),
- PB.protocExecutable := file(connectProtocExecPath.get)
+ PB.protocExecutable := file(sparkProtocExecPath.get)
)
} else {
Seq(
@@ -867,10 +874,10 @@ object SparkProtobuf {
case _ => MergeStrategy.first
},
) ++ {
- val protobufProtocExecPath = sys.props.get("protobuf.protoc.executable.path")
- if (protobufProtocExecPath.isDefined) {
+ val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
+ if (sparkProtocExecPath.isDefined) {
Seq(
- PB.protocExecutable := file(protobufProtocExecPath.get)
+ PB.protocExecutable := file(sparkProtocExecPath.get)
)
} else {
Seq.empty
diff --git a/python/mypy.ini b/python/mypy.ini
index 927254d3b3..603647bd3c 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -26,7 +26,7 @@ warn_redundant_casts = True
[mypy-pyspark.sql.connect.proto.*]
ignore_errors = True
-; Allow untyped def in internal modules and tests
+; Allow untyped def in internal modules
[mypy-pyspark.daemon]
disallow_untyped_defs = False
@@ -46,33 +46,18 @@ disallow_untyped_defs = False
[mypy-pyspark.join]
disallow_untyped_defs = False
-[mypy-pyspark.ml.tests.*]
-disallow_untyped_defs = False
-
-[mypy-pyspark.mllib.tests.*]
-disallow_untyped_defs = False
-
[mypy-pyspark.rddsampler]
disallow_untyped_defs = False
-[mypy-pyspark.resource.tests.*]
-disallow_untyped_defs = False
-
[mypy-pyspark.serializers]
disallow_untyped_defs = False
[mypy-pyspark.shuffle]
disallow_untyped_defs = False
-[mypy-pyspark.streaming.tests.*]
-disallow_untyped_defs = False
-
[mypy-pyspark.streaming.util]
disallow_untyped_defs = False
-[mypy-pyspark.sql.tests.*]
-disallow_untyped_defs = False
-
[mypy-pyspark.sql.pandas.serializers]
disallow_untyped_defs = False
@@ -88,20 +73,37 @@ disallow_untyped_defs = False
[mypy-pyspark.pandas.usage_logging.*]
disallow_untyped_defs = False
-[mypy-pyspark.pandas.tests.*]
+[mypy-pyspark.traceback_utils]
disallow_untyped_defs = False
-[mypy-pyspark.tests.*]
+[mypy-pyspark.worker]
disallow_untyped_defs = False
-[mypy-pyspark.testing.*]
-disallow_untyped_defs = False
+; Ignore errors in tests
-[mypy-pyspark.traceback_utils]
-disallow_untyped_defs = False
+[mypy-pyspark.ml.tests.*]
+ignore_errors = True
-[mypy-pyspark.worker]
-disallow_untyped_defs = False
+[mypy-pyspark.mllib.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.resource.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.streaming.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.sql.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.pandas.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.tests.*]
+ignore_errors = True
+
+[mypy-pyspark.testing.*]
+ignore_errors = True
; Allow non-strict optional for pyspark.pandas
@@ -145,6 +147,9 @@ ignore_missing_imports = True
[mypy-google.protobuf.*]
ignore_missing_imports = True
+[mypy-grpc.*]
+ignore_missing_imports = True
+
; Ignore errors for proto generated code
[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto]
ignore_errors = True
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index e677e79cec..accdddb29c 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -413,7 +413,7 @@ def test_linear_regression_with_huber_loss(self):
from pyspark.ml.tests.test_algorithms import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py
index b95b8fbdd5..6c3c51d1c0 100644
--- a/python/pyspark/ml/tests/test_base.py
+++ b/python/pyspark/ml/tests/test_base.py
@@ -88,7 +88,7 @@ def testDefaultFitMultiple(self):
from pyspark.ml.tests.test_base import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py
index d2fd369624..3c5ae3fbe7 100644
--- a/python/pyspark/ml/tests/test_evaluation.py
+++ b/python/pyspark/ml/tests/test_evaluation.py
@@ -69,7 +69,7 @@ def test_clustering_evaluator_with_cosine_distance(self):
from pyspark.ml.tests.test_evaluation import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 6cf3175865..0051d47ae3 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -393,7 +393,7 @@ def test_apply_binary_term_freqs(self):
from pyspark.ml.tests.test_feature import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py
index 8a155ab56a..86fa46c324 100644
--- a/python/pyspark/ml/tests/test_image.py
+++ b/python/pyspark/ml/tests/test_image.py
@@ -74,7 +74,7 @@ def test_read_images(self):
from pyspark.ml.tests.test_image import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py
index a6e9f4e752..6632d100ea 100644
--- a/python/pyspark/ml/tests/test_linalg.py
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -401,7 +401,7 @@ def test_infer_schema(self):
from pyspark.ml.tests.test_linalg import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py
index 64ed2f6dbe..8df50a5963 100644
--- a/python/pyspark/ml/tests/test_param.py
+++ b/python/pyspark/ml/tests/test_param.py
@@ -433,7 +433,7 @@ def test_java_params(self):
from pyspark.ml.tests.test_param import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py
index 0b54540f06..406180d9a6 100644
--- a/python/pyspark/ml/tests/test_persistence.py
+++ b/python/pyspark/ml/tests/test_persistence.py
@@ -538,7 +538,7 @@ def test_save_and_load_on_nested_list_params(self):
from pyspark.ml.tests.test_persistence import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index 1f73fdd344..afc900cec4 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -63,7 +63,7 @@ def doTransform(pipeline):
from pyspark.ml.tests.test_pipeline import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py
index 16ce1bc7da..6bab41b567 100644
--- a/python/pyspark/ml/tests/test_stat.py
+++ b/python/pyspark/ml/tests/test_stat.py
@@ -44,7 +44,7 @@ def test_chisquaretest(self):
from pyspark.ml.tests.test_stat import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py
index 27d9c182cf..5704d71867 100644
--- a/python/pyspark/ml/tests/test_training_summary.py
+++ b/python/pyspark/ml/tests/test_training_summary.py
@@ -486,7 +486,7 @@ def test_kmeans_summary(self):
from pyspark.ml.tests.test_training_summary import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index c4273f36d7..d9a5c51fd5 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -1027,7 +1027,7 @@ def test_copy(self):
from pyspark.ml.tests.test_tuning import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_util.py b/python/pyspark/ml/tests/test_util.py
index 4d5c6a4727..55c973831b 100644
--- a/python/pyspark/ml/tests/test_util.py
+++ b/python/pyspark/ml/tests/test_util.py
@@ -77,7 +77,7 @@ def _check_uid_set_equal(stages, expected_stages):
from pyspark.ml.tests.test_util import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
index 02ce6f3192..33d93c02ac 100644
--- a/python/pyspark/ml/tests/test_wrapper.py
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -130,7 +130,7 @@ def test_new_java_array(self):
from pyspark.ml.tests.test_wrapper import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py
index 8882242259..6a9be99ecd 100644
--- a/python/pyspark/mllib/tests/test_algorithms.py
+++ b/python/pyspark/mllib/tests/test_algorithms.py
@@ -338,7 +338,7 @@ def test_fpgrowth(self):
from pyspark.mllib.tests.test_algorithms import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py
index 080a2bf1f5..ca06f39da2 100644
--- a/python/pyspark/mllib/tests/test_feature.py
+++ b/python/pyspark/mllib/tests/test_feature.py
@@ -184,7 +184,7 @@ def test_pca(self):
from pyspark.mllib.tests.test_feature import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
index 007f42d3c2..d137c88836 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -672,7 +672,7 @@ def test_regression(self):
from pyspark.mllib.tests.test_linalg import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py
index 7a33d773d1..cef1294ada 100644
--- a/python/pyspark/mllib/tests/test_stat.py
+++ b/python/pyspark/mllib/tests/test_stat.py
@@ -198,7 +198,7 @@ def test_R_implementation_equivalence(self):
from pyspark.mllib.tests.test_stat import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py
index 779fff7090..5a06742ba7 100644
--- a/python/pyspark/mllib/tests/test_streaming_algorithms.py
+++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py
@@ -463,7 +463,7 @@ def condition():
from pyspark.mllib.tests.test_streaming_algorithms import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py
index aad1349c71..28a53af0aa 100644
--- a/python/pyspark/mllib/tests/test_util.py
+++ b/python/pyspark/mllib/tests/test_util.py
@@ -100,7 +100,7 @@ def test_to_java_object_rdd(self): # SPARK-6660
from pyspark.mllib.tests.test_util import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_base.py b/python/pyspark/pandas/tests/data_type_ops/test_base.py
index db4724b982..9b40d15db6 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_base.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_base.py
@@ -95,7 +95,7 @@ def test_bool_ext_ops(self):
from pyspark.pandas.tests.data_type_ops.test_base import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
index 7135800bd9..6eca20d2db 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
@@ -212,7 +212,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_binary_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
index 7376120226..ad7ead6316 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
@@ -813,7 +813,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_boolean_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
index 992e3ed70f..41e6c4885d 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
@@ -550,7 +550,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_categorical_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
index bbdf837ce2..2b85e7bb26 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
@@ -356,7 +356,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_complex_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
index b457ab2cc8..2fe8a4c688 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
@@ -235,7 +235,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_date_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
index de9c6acb2c..55d06c07cd 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
@@ -248,7 +248,7 @@ def setUpClass(cls):
from pyspark.pandas.tests.data_type_ops.test_datetime_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
index fc6cdd1a43..44ea159f2a 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
@@ -165,7 +165,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_null_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
index cb678ff585..22d4e8d8ff 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
@@ -694,7 +694,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
index cc448dc42d..cf785f1ebb 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
@@ -342,7 +342,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py
index eeaba4d277..3889520ad8 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py
@@ -207,7 +207,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_timedelta_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
index 81767af76f..beebc1f320 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
@@ -180,7 +180,7 @@ def test_ge(self):
from pyspark.pandas.tests.data_type_ops.test_udt_ops import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py
index 9ca31923d5..0e2c640979 100644
--- a/python/pyspark/pandas/tests/indexes/test_base.py
+++ b/python/pyspark/pandas/tests/indexes/test_base.py
@@ -2590,7 +2590,7 @@ def test_multi_index_nunique(self):
from pyspark.pandas.tests.indexes.test_base import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py
index ba737eb520..10c822a3ca 100644
--- a/python/pyspark/pandas/tests/indexes/test_category.py
+++ b/python/pyspark/pandas/tests/indexes/test_category.py
@@ -459,7 +459,7 @@ def test_map(self):
from pyspark.pandas.tests.indexes.test_category import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py b/python/pyspark/pandas/tests/indexes/test_datetime.py
index f715518743..8f8e283f3a 100644
--- a/python/pyspark/pandas/tests/indexes/test_datetime.py
+++ b/python/pyspark/pandas/tests/indexes/test_datetime.py
@@ -254,7 +254,7 @@ def test_map(self):
from pyspark.pandas.tests.indexes.test_datetime import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/indexes/test_timedelta.py b/python/pyspark/pandas/tests/indexes/test_timedelta.py
index b191ff8bfb..654f5ee3a0 100644
--- a/python/pyspark/pandas/tests/indexes/test_timedelta.py
+++ b/python/pyspark/pandas/tests/indexes/test_timedelta.py
@@ -110,7 +110,7 @@ def test_properties(self):
from pyspark.pandas.tests.indexes.test_timedelta import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot.py b/python/pyspark/pandas/tests/plot/test_frame_plot.py
index 5d265ff2ee..817ea896e7 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot.py
@@ -158,7 +158,7 @@ def check_box_multi_columns(psdf):
from pyspark.pandas.tests.plot.test_frame_plot import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
index bb400996e2..7c63371098 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py
@@ -477,7 +477,7 @@ def check_kde_plot(pdf, psdf, *args, **kwargs):
from pyspark.pandas.tests.plot.test_frame_plot_matplotlib import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
index d169326b7b..f7cf1fc349 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
@@ -273,7 +273,7 @@ def test_kde_plot(self):
from pyspark.pandas.tests.plot.test_frame_plot_plotly import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py b/python/pyspark/pandas/tests/plot/test_series_plot.py
index f3d4ef553b..fab04bac21 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot.py
@@ -94,7 +94,7 @@ def check_box_summary(psdf, pdf):
from pyspark.pandas.tests.plot.test_series_plot import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
index 680eee13de..c17290c44b 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py
@@ -397,7 +397,7 @@ def test_single_value_hist(self):
from pyspark.pandas.tests.plot.test_series_plot_matplotlib import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
index 8a50b1829d..7bd612c1a8 100644
--- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py
@@ -235,7 +235,7 @@ def test_kde_plot(self):
from pyspark.pandas.tests.plot.test_series_plot_plotly import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py
index 99f315a43a..d5a660a66e 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -436,7 +436,7 @@ def test_groupby_transform_without_shortcut(self):
pdf, psdf = self.df_pair
- def identity(x) -> ps.Series[psdf.b.dtype]: # type: ignore[name-defined]
+ def identity(x) -> ps.Series[psdf.b.dtype]:
return x
self.assert_eq(
@@ -796,7 +796,7 @@ def test_set_categories(self):
from pyspark.pandas.tests.test_categorical import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_config.py b/python/pyspark/pandas/tests/test_config.py
index d3900e216c..c1c2299240 100644
--- a/python/pyspark/pandas/tests/test_config.py
+++ b/python/pyspark/pandas/tests/test_config.py
@@ -148,7 +148,7 @@ def test_dir_options(self):
from pyspark.pandas.tests.test_config import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_csv.py b/python/pyspark/pandas/tests/test_csv.py
index 6bdc989c5d..a94125e648 100644
--- a/python/pyspark/pandas/tests/test_csv.py
+++ b/python/pyspark/pandas/tests/test_csv.py
@@ -435,7 +435,7 @@ def test_to_csv_with_partition_cols(self):
from pyspark.pandas.tests.test_csv import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 4e80c680b6..1b06d321e1 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -7074,6 +7074,10 @@ def test_cov(self):
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.cov(), psdf.cov())
+ @unittest.skipIf(
+ LooseVersion(pd.__version__) < LooseVersion("1.3.0"),
+ "pandas support `Styler.to_latex` since 1.3.0",
+ )
def test_style(self):
# Currently, the `style` function returns a pandas object `Styler` as it is,
# processing only the number of rows declared in `compute.max_rows`.
@@ -7102,7 +7106,7 @@ def check_style():
from pyspark.pandas.tests.test_dataframe import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_dataframe_conversion.py b/python/pyspark/pandas/tests/test_dataframe_conversion.py
index 4e4c9ac2e7..67ff40e9f1 100644
--- a/python/pyspark/pandas/tests/test_dataframe_conversion.py
+++ b/python/pyspark/pandas/tests/test_dataframe_conversion.py
@@ -262,7 +262,7 @@ def test_from_records(self):
from pyspark.pandas.tests.test_dataframe_conversion import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_dataframe_spark_io.py b/python/pyspark/pandas/tests/test_dataframe_spark_io.py
index dd83070a16..9904ff032d 100644
--- a/python/pyspark/pandas/tests/test_dataframe_spark_io.py
+++ b/python/pyspark/pandas/tests/test_dataframe_spark_io.py
@@ -475,7 +475,7 @@ def test_orc_write(self):
from pyspark.pandas.tests.test_dataframe_spark_io import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_default_index.py b/python/pyspark/pandas/tests/test_default_index.py
index dcb120aee4..ddd9e29662 100644
--- a/python/pyspark/pandas/tests/test_default_index.py
+++ b/python/pyspark/pandas/tests/test_default_index.py
@@ -97,7 +97,7 @@ def test_index_distributed_sequence_cleanup(self):
from pyspark.pandas.tests.test_default_index import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_ewm.py b/python/pyspark/pandas/tests/test_ewm.py
index 3ce0bd4507..4d3c98572d 100644
--- a/python/pyspark/pandas/tests/test_ewm.py
+++ b/python/pyspark/pandas/tests/test_ewm.py
@@ -422,7 +422,7 @@ def test_groupby_ewm_func(self):
from pyspark.pandas.tests.test_ewm import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py
index 77ced41eb8..d712f03f7d 100644
--- a/python/pyspark/pandas/tests/test_expanding.py
+++ b/python/pyspark/pandas/tests/test_expanding.py
@@ -241,7 +241,7 @@ def test_groupby_expanding_kurt(self):
from pyspark.pandas.tests.test_expanding import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py
index dd2d08dded..5d4b5dfa76 100644
--- a/python/pyspark/pandas/tests/test_extension.py
+++ b/python/pyspark/pandas/tests/test_extension.py
@@ -140,7 +140,7 @@ def __init__(self, data):
from pyspark.pandas.tests.test_extension import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_frame_spark.py b/python/pyspark/pandas/tests/test_frame_spark.py
index 9b47ceca7a..df090b74d9 100644
--- a/python/pyspark/pandas/tests/test_frame_spark.py
+++ b/python/pyspark/pandas/tests/test_frame_spark.py
@@ -148,7 +148,7 @@ def test_local_checkpoint(self):
from pyspark.pandas.tests.test_frame_spark import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py
index d476302205..72e0e47aed 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -222,7 +222,7 @@ def test_prod_precision(self):
from pyspark.pandas.tests.test_generic_functions import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index a203f77717..1c940e3abf 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -2334,7 +2334,7 @@ def add_max2(
def test_apply_negative(self):
def func(_) -> ps.Series[int]:
- return pd.Series([1]) # type: ignore[return-value]
+ return pd.Series([1])
with self.assertRaisesRegex(TypeError, "Series as a return type hint at frame groupby"):
ps.range(10).groupby("id").apply(func)
@@ -3242,7 +3242,7 @@ def test_getitem(self):
from pyspark.pandas.tests.test_groupby import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py
index c939a69929..9d52c41274 100644
--- a/python/pyspark/pandas/tests/test_indexing.py
+++ b/python/pyspark/pandas/tests/test_indexing.py
@@ -1327,7 +1327,7 @@ def test_index_operator_int(self):
from pyspark.pandas.tests.test_indexing import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_indexops_spark.py b/python/pyspark/pandas/tests/test_indexops_spark.py
index 275ef77f71..8b0b5c87c9 100644
--- a/python/pyspark/pandas/tests/test_indexops_spark.py
+++ b/python/pyspark/pandas/tests/test_indexops_spark.py
@@ -68,7 +68,7 @@ def test_series_apply_negative(self):
from pyspark.pandas.tests.test_indexops_spark import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_internal.py b/python/pyspark/pandas/tests/test_internal.py
index 2ace222ed6..30a4bdcb66 100644
--- a/python/pyspark/pandas/tests/test_internal.py
+++ b/python/pyspark/pandas/tests/test_internal.py
@@ -112,7 +112,7 @@ def test_from_pandas(self):
from pyspark.pandas.tests.test_internal import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py
index 8f73c65846..c0bda11d98 100644
--- a/python/pyspark/pandas/tests/test_namespace.py
+++ b/python/pyspark/pandas/tests/test_namespace.py
@@ -621,7 +621,7 @@ def test_missing(self):
from pyspark.pandas.tests.test_namespace import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py
index d16d9996ec..fc6e332782 100644
--- a/python/pyspark/pandas/tests/test_numpy_compat.py
+++ b/python/pyspark/pandas/tests/test_numpy_compat.py
@@ -188,7 +188,7 @@ def test_np_spark_compat_frame(self):
from pyspark.pandas.tests.test_numpy_compat import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index 71c393dcf3..734e2545d1 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -2141,7 +2141,7 @@ def test_series_eq(self):
from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
index 69621e4930..1bc1ab4772 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py
@@ -630,7 +630,7 @@ def test_fillna(self):
from pyspark.pandas.tests.test_ops_on_diff_frames_groupby import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
index 08f17745df..072a83d294 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py
@@ -99,7 +99,7 @@ def test_groupby_expanding_var(self):
from pyspark.pandas.tests.test_ops_on_diff_frames_groupby_expanding import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
index 04ea448d80..e9a42e79ab 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py
@@ -99,7 +99,7 @@ def test_groupby_rolling_var(self):
from pyspark.pandas.tests.test_ops_on_diff_frames_groupby_rolling import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_repr.py b/python/pyspark/pandas/tests/test_repr.py
index 271ed0a14c..d1ba46e63f 100644
--- a/python/pyspark/pandas/tests/test_repr.py
+++ b/python/pyspark/pandas/tests/test_repr.py
@@ -178,7 +178,7 @@ def test_repr_float_index(self):
from pyspark.pandas.tests.test_repr import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py
index 56106940f1..3b494e05e7 100644
--- a/python/pyspark/pandas/tests/test_resample.py
+++ b/python/pyspark/pandas/tests/test_resample.py
@@ -295,7 +295,7 @@ def test_resample_on(self):
from pyspark.pandas.tests.test_resample import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py
index 30550a9fba..a7574a5388 100644
--- a/python/pyspark/pandas/tests/test_reshape.py
+++ b/python/pyspark/pandas/tests/test_reshape.py
@@ -483,7 +483,7 @@ def test_merge_asof(self):
from pyspark.pandas.tests.test_reshape import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py
index be21bf16d4..6c31073d3f 100644
--- a/python/pyspark/pandas/tests/test_rolling.py
+++ b/python/pyspark/pandas/tests/test_rolling.py
@@ -242,7 +242,7 @@ def test_groupby_rolling_kurt(self):
from pyspark.pandas.tests.test_rolling import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_scalars.py b/python/pyspark/pandas/tests/test_scalars.py
index 0c8aa8508f..00900dbdd9 100644
--- a/python/pyspark/pandas/tests/test_scalars.py
+++ b/python/pyspark/pandas/tests/test_scalars.py
@@ -47,7 +47,7 @@ def test_missing(self):
from pyspark.pandas.tests.test_scalars import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index e47f716ecf..46a687b36c 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -3392,7 +3392,7 @@ def test_series_stat_fail(self):
from pyspark.pandas.tests.test_series import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_series_conversion.py b/python/pyspark/pandas/tests/test_series_conversion.py
index bc83fdacbe..79c2f1ff30 100644
--- a/python/pyspark/pandas/tests/test_series_conversion.py
+++ b/python/pyspark/pandas/tests/test_series_conversion.py
@@ -68,7 +68,7 @@ def test_to_latex(self):
from pyspark.pandas.tests.test_series_conversion import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_series_datetime.py b/python/pyspark/pandas/tests/test_series_datetime.py
index 1fe078e972..1c392644ed 100644
--- a/python/pyspark/pandas/tests/test_series_datetime.py
+++ b/python/pyspark/pandas/tests/test_series_datetime.py
@@ -287,7 +287,7 @@ def test_unsupported_type(self):
from pyspark.pandas.tests.test_series_datetime import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py
index 0b778583e7..f82f57981f 100644
--- a/python/pyspark/pandas/tests/test_series_string.py
+++ b/python/pyspark/pandas/tests/test_series_string.py
@@ -336,7 +336,7 @@ def test_string_get_dummies(self):
from pyspark.pandas.tests.test_series_string import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_spark_functions.py b/python/pyspark/pandas/tests/test_spark_functions.py
index c18dc30240..4da20f754d 100644
--- a/python/pyspark/pandas/tests/test_spark_functions.py
+++ b/python/pyspark/pandas/tests/test_spark_functions.py
@@ -34,7 +34,7 @@ def test_repeat(self):
from pyspark.pandas.tests.test_spark_functions import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py
index 5a5d6d484b..4d4afb8882 100644
--- a/python/pyspark/pandas/tests/test_sql.py
+++ b/python/pyspark/pandas/tests/test_sql.py
@@ -100,7 +100,7 @@ def test_sql_with_pandas_on_spark_objects(self):
from pyspark.pandas.tests.test_sql import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index 4fb08ee69e..fa7cff8f3c 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -554,7 +554,7 @@ def test_numeric_only_unsupported(self):
from pyspark.pandas.tests.test_stats import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py
index 1bc5c8cfdd..a5f2b2dc2b 100644
--- a/python/pyspark/pandas/tests/test_typedef.py
+++ b/python/pyspark/pandas/tests/test_typedef.py
@@ -133,7 +133,7 @@ def func() -> pd.DataFrame[np.float_]:
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
- def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
+ def func() -> pd.DataFrame[pdf.dtypes]:
pass
expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
@@ -143,14 +143,14 @@ def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
- def func() -> pd.Series[pdf.b.dtype]: # type: ignore[name-defined]
+ def func() -> pd.Series[pdf.b.dtype]:
pass
inferred = infer_return_type(func)
self.assertEqual(inferred.dtype, CategoricalDtype(categories=["a", "b", "c"]))
self.assertEqual(inferred.spark_type, LongType())
- def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
+ def func() -> pd.DataFrame[pdf.dtypes]:
pass
expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
@@ -246,7 +246,7 @@ def f() -> 'pd.DataFrame["a" : float : 1, "b":str:2]': # noqa: F405
pdf = pd.DataFrame({"a": ["a", 2, None]})
def try_infer_return_type():
- def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
+ def f() -> pd.DataFrame[pdf.dtypes]:
pass
infer_return_type(f)
@@ -254,7 +254,7 @@ def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def try_infer_return_type():
- def f() -> pd.Series[pdf.a.dtype]: # type: ignore[name-defined]
+ def f() -> pd.Series[pdf.a.dtype]:
pass
infer_return_type(f)
@@ -293,7 +293,7 @@ def f() -> 'ps.DataFrame["a" : np.float_ : 1, "b":str:2]': # noqa: F405
pdf = pd.DataFrame({"a": ["a", 2, None]})
def try_infer_return_type():
- def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
+ def f() -> ps.DataFrame[pdf.dtypes]:
pass
infer_return_type(f)
@@ -301,7 +301,7 @@ def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore[name-defined]
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def try_infer_return_type():
- def f() -> ps.Series[pdf.a.dtype]: # type: ignore[name-defined]
+ def f() -> ps.Series[pdf.a.dtype]:
pass
infer_return_type(f)
@@ -439,7 +439,7 @@ def test_as_spark_type_extension_float_dtypes(self):
from pyspark.pandas.tests.test_typedef import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py
index 11f560c6f5..cfbcb5ba0a 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -121,7 +121,7 @@ def lazy_prop(self):
from pyspark.pandas.tests.test_utils import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/pandas/tests/test_window.py b/python/pyspark/pandas/tests/test_window.py
index 49779566c9..d8bc2775fa 100644
--- a/python/pyspark/pandas/tests/test_window.py
+++ b/python/pyspark/pandas/tests/test_window.py
@@ -453,7 +453,7 @@ def test_missing_groupby(self):
from pyspark.pandas.tests.test_window import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/resource/tests/test_resources.py b/python/pyspark/resource/tests/test_resources.py
index b6babf3c6c..81a4ea4f1d 100644
--- a/python/pyspark/resource/tests/test_resources.py
+++ b/python/pyspark/resource/tests/test_resources.py
@@ -75,7 +75,7 @@ def assert_request_contents(exec_reqs, task_reqs):
from pyspark.resource.tests.test_resources import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 35c3397de5..da03110c32 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -78,9 +78,12 @@ def _get_local_dirs(sub):
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
dirs = path.split(",")
if len(dirs) > 1:
- # different order in different processes and instances
- rnd = random.Random(os.getpid() + id(dirs))
- random.shuffle(dirs, rnd.random)
+ if sys.version_info < (3, 11):
+ # different order in different processes and instances
+ rnd = random.Random(os.getpid() + id(dirs))
+ random.shuffle(dirs, rnd.random)
+ else:
+ random.shuffle(dirs)
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
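The hunk above exists because Python 3.11 removed the optional second argument of random.shuffle, which older code used to pass a seeded random() callable. A minimal, self-contained sketch of the same version gate, using a made-up directory list rather than Spark's real SPARK_LOCAL_DIRS handling:

import os
import random
import sys

dirs = ["/tmp/spark-local-1", "/tmp/spark-local-2"]  # hypothetical local dirs for illustration
if len(dirs) > 1:
    if sys.version_info < (3, 11):
        # Older interpreters still accept a custom random() callable, so keep the
        # pid-seeded shuffle to get a different order in different processes.
        rnd = random.Random(os.getpid() + id(dirs))
        random.shuffle(dirs, rnd.random)
    else:
        # Python 3.11+ dropped the second parameter; the default PRNG is used instead.
        random.shuffle(dirs)
print(dirs)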
diff --git a/python/pyspark/sql/connect/__init__.py b/python/pyspark/sql/connect/__init__.py
index 3df96963f9..4a98368c81 100644
--- a/python/pyspark/sql/connect/__init__.py
+++ b/python/pyspark/sql/connect/__init__.py
@@ -18,5 +18,13 @@
"""Currently Spark Connect is very experimental and the APIs to interact with
Spark through this API can be changed at any time without warning."""
-
from pyspark.sql.connect.dataframe import DataFrame # noqa: F401
+from pyspark.sql.pandas.utils import (
+ require_minimum_pandas_version,
+ require_minimum_pyarrow_version,
+ require_minimum_grpc_version,
+)
+
+require_minimum_pandas_version()
+require_minimum_pyarrow_version()
+require_minimum_grpc_version()
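With these imports, merely importing pyspark.sql.connect now runs the minimum-version checks for pandas, pyarrow and grpcio. A hedged usage sketch of how a caller might turn that into a friendlier failure (nothing here is part of the public API):

try:
    import pyspark.sql.connect  # noqa: F401  # triggers the require_minimum_*_version() checks
except ImportError as exc:
    raise SystemExit(f"Spark Connect requires pandas, pyarrow and grpcio: {exc}")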
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 745ca79fda..c4c74f5d6c 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -20,36 +20,19 @@
import uuid
from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
-import grpc # type: ignore
+import grpc
import pandas
import pyarrow as pa
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
+import pyspark.sql.connect.types as types
import pyspark.sql.types
from pyspark import cloudpickle
from pyspark.sql.types import (
DataType,
- ByteType,
- ShortType,
- IntegerType,
- FloatType,
- DateType,
- TimestampType,
- DayTimeIntervalType,
- MapType,
- StringType,
- CharType,
- VarcharType,
StructType,
StructField,
- ArrayType,
- DoubleType,
- LongType,
- DecimalType,
- BinaryType,
- BooleanType,
- NullType,
)
@@ -350,73 +333,7 @@ def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
return self._execute_and_fetch(req)
def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
- if schema.HasField("null"):
- return NullType()
- elif schema.HasField("boolean"):
- return BooleanType()
- elif schema.HasField("binary"):
- return BinaryType()
- elif schema.HasField("byte"):
- return ByteType()
- elif schema.HasField("short"):
- return ShortType()
- elif schema.HasField("integer"):
- return IntegerType()
- elif schema.HasField("long"):
- return LongType()
- elif schema.HasField("float"):
- return FloatType()
- elif schema.HasField("double"):
- return DoubleType()
- elif schema.HasField("decimal"):
- p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
- s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
- return DecimalType(precision=p, scale=s)
- elif schema.HasField("string"):
- return StringType()
- elif schema.HasField("char"):
- return CharType(schema.char.length)
- elif schema.HasField("var_char"):
- return VarcharType(schema.var_char.length)
- elif schema.HasField("date"):
- return DateType()
- elif schema.HasField("timestamp"):
- return TimestampType()
- elif schema.HasField("day_time_interval"):
- start: Optional[int] = (
- schema.day_time_interval.start_field
- if schema.day_time_interval.HasField("start_field")
- else None
- )
- end: Optional[int] = (
- schema.day_time_interval.end_field
- if schema.day_time_interval.HasField("end_field")
- else None
- )
- return DayTimeIntervalType(startField=start, endField=end)
- elif schema.HasField("array"):
- return ArrayType(
- self._proto_schema_to_pyspark_schema(schema.array.element_type),
- schema.array.contains_null,
- )
- elif schema.HasField("struct"):
- fields = [
- StructField(
- f.name,
- self._proto_schema_to_pyspark_schema(f.data_type),
- f.nullable,
- )
- for f in schema.struct.fields
- ]
- return StructType(fields)
- elif schema.HasField("map"):
- return MapType(
- self._proto_schema_to_pyspark_schema(schema.map.key_type),
- self._proto_schema_to_pyspark_schema(schema.map.value_type),
- schema.map.value_contains_null,
- )
- else:
- raise Exception(f"Unsupported data type {schema}")
+ return types.proto_schema_to_pyspark_data_type(schema)
def schema(self, plan: pb2.Plan) -> StructType:
proto_schema = self._analyze(plan).schema
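The long per-type branch removed above now lives in pyspark.sql.connect.types.proto_schema_to_pyspark_data_type. As an illustration only, here is a trimmed-down sketch of the primitive part of that mapping, reconstructed from the deleted branches (the real helper also covers decimal, char/varchar, intervals, arrays, structs and maps):

from pyspark.sql.types import (
    BooleanType,
    ByteType,
    DataType,
    DoubleType,
    IntegerType,
    LongType,
    StringType,
)

# proto oneof field name -> pyspark type constructor
_PRIMITIVES = {
    "boolean": BooleanType,
    "byte": ByteType,
    "integer": IntegerType,
    "long": LongType,
    "double": DoubleType,
    "string": StringType,
}

def primitive_proto_to_pyspark(schema) -> DataType:
    # `schema` is assumed to be a pb2.DataType message exposing HasField()
    for field_name, ctor in _PRIMITIVES.items():
        if schema.HasField(field_name):
            return ctor()
    raise Exception(f"Unsupported data type {schema}")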
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index e864f6c93e..58d4e3dc41 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -21,18 +21,21 @@
import decimal
import datetime
-from pyspark.sql.types import TimestampType, DayTimeIntervalType, DateType
+from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType
import pyspark.sql.connect.proto as proto
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
if TYPE_CHECKING:
- from pyspark.sql.connect._typing import ColumnOrName
+ from pyspark.sql.connect._typing import ColumnOrName, PrimitiveType
from pyspark.sql.connect.client import SparkConnectClient
import pyspark.sql.connect.proto as proto
-# TODO(SPARK-41329): solve the circular import between _typing and this class
-# if we want to reuse _type.PrimitiveType
-PrimitiveType = Union[bool, float, int, str]
+
+JVM_INT_MIN = -(1 << 31)
+JVM_INT_MAX = (1 << 31) - 1
+JVM_LONG_MIN = -(1 << 63)
+JVM_LONG_MAX = (1 << 63) - 1
def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]:
@@ -183,7 +186,12 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
elif isinstance(self._value, bool):
expr.literal.boolean = bool(self._value)
elif isinstance(self._value, int):
- expr.literal.long = int(self._value)
+ if JVM_INT_MIN <= self._value <= JVM_INT_MAX:
+ expr.literal.integer = int(self._value)
+ elif JVM_LONG_MIN <= self._value <= JVM_LONG_MAX:
+ expr.literal.long = int(self._value)
+ else:
+ raise ValueError(f"integer {self._value} out of bounds")
elif isinstance(self._value, float):
expr.literal.double = float(self._value)
elif isinstance(self._value, str):
@@ -355,6 +363,29 @@ def __repr__(self) -> str:
return f"{self._name}({', '.join([str(arg) for arg in self._args])})"
+class CastExpression(Expression):
+ def __init__(
+ self,
+ col: "Column",
+ data_type: Union[DataType, str],
+ ) -> None:
+ super().__init__()
+ self._col = col
+ self._data_type = data_type
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+ fun = proto.Expression()
+ fun.cast.expr.CopyFrom(self._col.to_plan(session))
+ if isinstance(self._data_type, str):
+ fun.cast.type_str = self._data_type
+ else:
+ fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
+ return fun
+
+ def __repr__(self) -> str:
+ return f"({self._col} ({self._data_type}))"
+
+
class Column:
"""
A column in a DataFrame. Column can refer to different things based on the
@@ -530,7 +561,7 @@ def __ne__( # type: ignore[override]
return _func_op("not")(_bin_op("==")(self, other))
# string methods
- def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column":
+ def contains(self, other: Union["PrimitiveType", "Column"]) -> "Column":
"""
Contains the other element. Returns a boolean :class:`Column` based on a string match.
@@ -674,6 +705,9 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col='Ali'), Row(col='Bob')]
"""
+ from pyspark.sql.connect.function_builder import functions as F
+ from pyspark.sql.connect.functions import lit
+
if type(startPos) != type(length):
raise TypeError(
"startPos and length must be the same type. "
@@ -682,17 +716,16 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -
length_t=type(length),
)
)
- from pyspark.sql.connect.function_builder import functions as F
if isinstance(length, int):
- length_exp = self._lit(length)
+ length_exp = lit(length)
elif isinstance(length, Column):
length_exp = length
else:
raise TypeError("Unsupported type for substr().")
if isinstance(startPos, int):
- start_exp = self._lit(startPos)
+ start_exp = lit(startPos)
else:
start_exp = startPos
@@ -702,8 +735,11 @@ def __eq__(self, other: Any) -> "Column": # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
+ from pyspark.sql.connect._typing import PrimitiveType
+ from pyspark.sql.connect.functions import lit
+
if isinstance(other, get_args(PrimitiveType)):
- other = self._lit(other)
+ other = lit(other)
return scalar_function("==", self, other)
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
@@ -733,10 +769,63 @@ def desc_nulls_last(self) -> "Column":
def name(self) -> str:
return self._expr.name()
- # TODO(SPARK-41329): solve the circular import between functions.py and
- # this class if we want to reuse functions.lit
- def _lit(self, x: Any) -> "Column":
- return Column(LiteralExpression(x))
+ def cast(self, dataType: Union[DataType, str]) -> "Column":
+ """
+ Casts the column into type ``dataType``.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ dataType : :class:`DataType` or str
+ a DataType or Python string literal with a DDL-formatted string
+ to use when parsing the column to the same type.
+
+ Returns
+ -------
+ :class:`Column`
+ Column representing whether each element of Column is cast into new type.
+ """
+ if isinstance(dataType, (DataType, str)):
+ return Column(CastExpression(col=self, data_type=dataType))
+ else:
+ raise TypeError("unexpected type: %s" % type(dataType))
def __repr__(self) -> str:
return "Column<'%s'>" % self._expr.__repr__()
+
+ def otherwise(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("otherwise() is not yet implemented.")
+
+ def over(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("over() is not yet implemented.")
+
+ def isin(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("isin() is not yet implemented.")
+
+ def when(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("when() is not yet implemented.")
+
+ def getItem(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("getItem() is not yet implemented.")
+
+ def astype(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("astype() is not yet implemented.")
+
+ def between(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("between() is not yet implemented.")
+
+ def getField(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("getField() is not yet implemented.")
+
+ def withField(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("withField() is not yet implemented.")
+
+ def dropFields(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError("dropFields() is not yet implemented.")
+
+ def __getitem__(self, k: Any) -> None:
+ raise NotImplementedError("apply() - __getitem__ is not yet implemented.")
+
+ def __iter__(self) -> None:
+ raise TypeError("Column is not iterable")
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f268dc431b..08d48bb11f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -824,6 +824,57 @@ def withColumn(self, colName: str, col: Column) -> "DataFrame":
session=self._session,
)
+ def unpivot(
+ self,
+ ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ variableColumnName: str,
+ valueColumnName: str,
+ ) -> "DataFrame":
+ """
+ Returns a new :class:`DataFrame` by unpivoting a DataFrame from wide format to long format,
+ optionally leaving identifier columns set.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ ids : list
+ Id columns.
+ values : list, optional
+ Value columns to unpivot.
+ variableColumnName : str
+ Name of the variable column.
+ valueColumnName : str
+ Name of the value column.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ """
+
+ def to_jcols(
+ cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
+ ) -> List["ColumnOrName"]:
+ if cols is None:
+ lst = []
+ elif isinstance(cols, tuple):
+ lst = list(cols)
+ elif isinstance(cols, list):
+ lst = cols
+ else:
+ lst = [cols]
+ return lst
+
+ return DataFrame.withPlan(
+ plan.Unpivot(
+ self._plan, to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+ ),
+ self._session,
+ )
+
+ melt = unpivot
+
def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
"""
Prints the first ``n`` rows to the console.
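A hedged usage sketch for the unpivot()/melt() API added above, assuming a working Spark Connect session bound to `spark` (the exact output depends on the common value type Spark chooses for the unpivoted columns):

df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])
long_df = df.unpivot(
    ids="id",
    values=["int", "double"],
    variableColumnName="var",
    valueColumnName="val",
)
long_df.show()
# Roughly one row per (id, value column) pair, e.g. (1, "int", 11.0), (1, "double", 1.1), ...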
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 8b36647ae5..dccb6d6e0c 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -23,7 +23,7 @@
SQLExpression,
)
-from typing import Any, TYPE_CHECKING, Union, List, Optional, Tuple
+from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple
if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -90,7 +90,10 @@ def col(col: str) -> Column:
def lit(col: Any) -> Column:
- return Column(LiteralExpression(col))
+ if isinstance(col, Column):
+ return col
+ else:
+ return Column(LiteralExpression(col))
# def bitwiseNOT(col: "ColumnOrName") -> Column:
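lit() is now idempotent: passing an existing Column returns it unchanged, so helpers such as substr() and __eq__ above can call lit() on user input unconditionally. A hedged sketch, assuming a PySpark build that ships the connect module:

from pyspark.sql.connect.functions import col, lit

c = col("age")
assert lit(c) is c     # Columns pass through untouched
five = lit(5)          # plain Python values still become literal expressions wrapped in a Column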
@@ -3208,136 +3211,235 @@ def variance(col: "ColumnOrName") -> Column:
return var_samp(col)
-# String/Binary functions
+# Collection Functions
-def upper(col: "ColumnOrName") -> Column:
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def aggregate(
+# col: "ColumnOrName",
+# initialValue: "ColumnOrName",
+# merge: Callable[[Column, Column], Column],
+# finish: Optional[Callable[[Column], Column]] = None,
+# ) -> Column:
+# """
+# Applies a binary operator to an initial state and all elements in the array,
+# and reduces this to a single state. The final state is converted into the final result
+# by applying a finish function.
+#
+# Both functions can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 `__).
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# initialValue : :class:`~pyspark.sql.Column` or str
+# initial value. Name of column or expression
+# merge : function
+# a binary function ``(acc: Column, x: Column) -> Column...`` returning expression
+# of the same type as ``zero``
+# finish : function
+# an optional unary function ``(x: Column) -> Column: ...``
+# used to convert accumulated value.
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# final value after aggregate function is applied.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
+# >>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+# +----+
+# | sum|
+# +----+
+# |42.0|
+# +----+
+#
+# >>> def merge(acc, x):
+# ... count = acc.count + 1
+# ... sum = acc.sum + x
+# ... return struct(count.alias("count"), sum.alias("sum"))
+# >>> df.select(
+# ... aggregate(
+# ... "values",
+# ... struct(lit(0).alias("count"), lit(0.0).alias("sum")),
+# ... merge,
+# ... lambda acc: acc.sum / acc.count,
+# ... ).alias("mean")
+# ... ).show()
+# +----+
+# |mean|
+# +----+
+# | 8.4|
+# +----+
+# """
+# if finish is not None:
+# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue],
+# [merge, finish])
+#
+# else:
+# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue],
+# [merge])
+
+
+def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
+ """Creates a new array column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ column names or :class:`~pyspark.sql.Column`\\s that have
+ the same data type.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a column of array type.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
+ >>> df.select(array('age', 'age').alias("arr")).collect()
+ [Row(arr=[2, 2]), Row(arr=[5, 5])]
+ >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+ [Row(arr=[2, 2]), Row(arr=[5, 5])]
+ >>> df.select(array('age', 'age').alias("col")).printSchema()
+ root
+ |-- col: array (nullable = false)
+ | |-- element: long (containsNull = true)
"""
- Converts a string expression to upper case.
+ if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ return _invoke_function_over_columns("array", *cols) # type: ignore[arg-type]
+
+
+def array_contains(col: "ColumnOrName", value: Any) -> Column:
+ """
+ Collection function: returns null if the array is null, true if the array contains the
+ given value, and false otherwise.
.. versionadded:: 3.4.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ name of column containing array
+ value :
+ value or column to check for in array
Returns
-------
:class:`~pyspark.sql.Column`
- upper case values.
+ a column of Boolean type.
Examples
--------
- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
- >>> df.select(upper("value")).show()
- +------------+
- |upper(value)|
- +------------+
- | SPARK|
- | PYSPARK|
- | PANDAS API|
- +------------+
+ >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+ >>> df.select(array_contains(df.data, "a")).collect()
+ [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
+ >>> df.select(array_contains(df.data, lit("a"))).collect()
+ [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)]
"""
- return _invoke_function_over_columns("upper", col)
+ return _invoke_function("array_contains", _to_col(col), lit(value))
-def lower(col: "ColumnOrName") -> Column:
+def array_distinct(col: "ColumnOrName") -> Column:
"""
- Converts a string expression to lower case.
+ Collection function: removes duplicate values from the array.
.. versionadded:: 3.4.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ name of column or expression
Returns
-------
:class:`~pyspark.sql.Column`
- lower case values.
+ an array of unique values.
Examples
--------
- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
- >>> df.select(lower("value")).show()
- +------------+
- |lower(value)|
- +------------+
- | spark|
- | pyspark|
- | pandas api|
- +------------+
+ >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
+ >>> df.select(array_distinct(df.data)).collect()
+ [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
"""
- return _invoke_function_over_columns("lower", col)
+ return _invoke_function_over_columns("array_distinct", col)
-def ascii(col: "ColumnOrName") -> Column:
+def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
- Computes the numeric value of the first character of the string column.
+ Collection function: returns an array of the elements in col1 but not in col2,
+ without duplicates.
.. versionadded:: 3.4.0
Parameters
----------
- col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ col1 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ col2 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
Returns
-------
:class:`~pyspark.sql.Column`
- numeric value.
+ an array of values from first array that are not in the second.
Examples
--------
- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
- >>> df.select(ascii("value")).show()
- +------------+
- |ascii(value)|
- +------------+
- | 83|
- | 80|
- | 80|
- +------------+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+ >>> df.select(array_except(df.c1, df.c2)).collect()
+ [Row(array_except(c1, c2)=['b'])]
"""
- return _invoke_function_over_columns("ascii", col)
+ return _invoke_function_over_columns("array_except", col1, col2)
-def base64(col: "ColumnOrName") -> Column:
+def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
- Computes the BASE64 encoding of a binary column and returns it as a string column.
+ Collection function: returns an array of the elements in the intersection of col1 and col2,
+ without duplicates.
.. versionadded:: 3.4.0
Parameters
----------
- col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ col1 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ col2 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
Returns
-------
:class:`~pyspark.sql.Column`
- BASE64 encoding of string value.
+ an array of values in the intersection of two arrays.
Examples
--------
- >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
- >>> df.select(base64("value")).show()
- +----------------+
- | base64(value)|
- +----------------+
- | U3Bhcms=|
- | UHlTcGFyaw==|
- |UGFuZGFzIEFQSQ==|
- +----------------+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+ >>> df.select(array_intersect(df.c1, df.c2)).collect()
+ [Row(array_intersect(c1, c2)=['a', 'c'])]
"""
- return _invoke_function_over_columns("base64", col)
+ return _invoke_function_over_columns("array_intersect", col1, col2)
-def unbase64(col: "ColumnOrName") -> Column:
+def array_join(
+ col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None
+) -> Column:
"""
- Decodes a BASE64 encoded string column and returns it as a binary column.
+ Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
+ `null_replacement` if set, otherwise they are ignored.
.. versionadded:: 3.4.0
@@ -3345,210 +3447,3510 @@ def unbase64(col: "ColumnOrName") -> Column:
----------
col : :class:`~pyspark.sql.Column` or str
target column to work on.
+ delimiter : str
+ delimiter used to concatenate elements
+ null_replacement : str, optional
+ if set then null values will be replaced by this value
Returns
-------
:class:`~pyspark.sql.Column`
- encoded string value.
+ a column of string type. Concatenated values.
Examples
--------
- >>> df = spark.createDataFrame(["U3Bhcms=",
- ... "UHlTcGFyaw==",
- ... "UGFuZGFzIEFQSQ=="], "STRING")
- >>> df.select(unbase64("value")).show()
- +--------------------+
- | unbase64(value)|
- +--------------------+
- | [53 70 61 72 6B]|
- |[50 79 53 70 61 7...|
- |[50 61 6E 64 61 7...|
- +--------------------+
+ >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
+ >>> df.select(array_join(df.data, ",").alias("joined")).collect()
+ [Row(joined='a,b,c'), Row(joined='a')]
+ >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
+ [Row(joined='a,b,c'), Row(joined='a,NULL')]
+ """
+ if null_replacement is None:
+ return _invoke_function("array_join", _to_col(col), lit(delimiter))
+ else:
+ return _invoke_function("array_join", _to_col(col), lit(delimiter), lit(null_replacement))
+
+
+def array_max(col: "ColumnOrName") -> Column:
"""
- return _invoke_function_over_columns("unbase64", col)
+ Collection function: returns the maximum value of the array.
+ .. versionadded:: 3.4.0
-def ltrim(col: "ColumnOrName") -> Column:
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ maximum value of an array.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+ >>> df.select(array_max(df.data).alias('max')).collect()
+ [Row(max=3), Row(max=10)]
"""
- Trim the spaces from left end for the specified string value.
+ return _invoke_function_over_columns("array_max", col)
+
+
+def array_min(col: "ColumnOrName") -> Column:
+ """
+ Collection function: returns the minimum value of the array.
.. versionadded:: 3.4.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ name of column or expression
Returns
-------
:class:`~pyspark.sql.Column`
- left trimmed values.
+ minimum value of array.
Examples
--------
- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
- >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show()
- +-------+------+
- | r|length|
- +-------+------+
- | Spark| 5|
- |Spark | 7|
- | Spark| 5|
- +-------+------+
+ >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+ >>> df.select(array_min(df.data).alias('min')).collect()
+ [Row(min=1), Row(min=-1)]
"""
- return _invoke_function_over_columns("ltrim", col)
+ return _invoke_function_over_columns("array_min", col)
-def rtrim(col: "ColumnOrName") -> Column:
+def array_position(col: "ColumnOrName", value: Any) -> Column:
"""
- Trim the spaces from right end for the specified string value.
+ Collection function: Locates the position of the first occurrence of the given value
+ in the given array. Returns null if either of the arguments is null.
.. versionadded:: 3.4.0
+ Notes
+ -----
+ The position is not zero based, but 1 based index. Returns 0 if the given
+ value could not be found in the array.
+
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
target column to work on.
+ value : Any
+ value to look for.
Returns
-------
:class:`~pyspark.sql.Column`
- right trimmed values.
+ position of the value in the given array if found and 0 otherwise.
Examples
--------
- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
- >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show()
- +--------+------+
- | r|length|
- +--------+------+
- | Spark| 8|
- | Spark| 5|
- | Spark| 6|
- +--------+------+
+ >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
+ >>> df.select(array_position(df.data, "a")).collect()
+ [Row(array_position(data, a)=3), Row(array_position(data, a)=0)]
"""
- return _invoke_function_over_columns("rtrim", col)
+ return _invoke_function("array_position", _to_col(col), lit(value))
-def trim(col: "ColumnOrName") -> Column:
+def array_remove(col: "ColumnOrName", element: Any) -> Column:
"""
- Trim the spaces from both ends for the specified string column.
+ Collection function: Remove all elements that equal to element from the given array.
.. versionadded:: 3.4.0
Parameters
----------
col : :class:`~pyspark.sql.Column` or str
- target column to work on.
+ name of column containing array
+ element :
+ element to be removed from the array
Returns
-------
:class:`~pyspark.sql.Column`
- trimmed values from both sides.
+ an array excluding given value.
Examples
--------
- >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
- >>> df.select(trim("value").alias("r")).withColumn("length", length("r")).show()
- +-----+------+
- | r|length|
- +-----+------+
- |Spark| 5|
- |Spark| 5|
- |Spark| 5|
- +-----+------+
+ >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
+ >>> df.select(array_remove(df.data, 1)).collect()
+ [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])]
"""
- return _invoke_function_over_columns("trim", col)
+ return _invoke_function("array_remove", _to_col(col), lit(element))
-def concat_ws(sep: str, *cols: "ColumnOrName") -> Column:
+def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column:
"""
- Concatenates multiple input string columns together into a single string column,
- using the given separator.
+ Collection function: creates an array containing a column repeated count times.
.. versionadded:: 3.4.0
Parameters
----------
- sep : str
- words separator.
- cols : :class:`~pyspark.sql.Column` or str
- list of columns to work on.
+ col : :class:`~pyspark.sql.Column` or str
+ column name or column that contains the element to be repeated
+ count : :class:`~pyspark.sql.Column` or str or int
+ column name, column, or int containing the number of times to repeat the first argument
Returns
-------
:class:`~pyspark.sql.Column`
- string of concatenated words.
+ an array of repeated elements.
Examples
--------
- >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
- >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
- [Row(s='abcd-123')]
+ >>> df = spark.createDataFrame([('ab',)], ['data'])
+ >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+ [Row(r=['ab', 'ab', 'ab'])]
"""
- return _invoke_function("concat_ws", lit(sep), *[_to_col(c) for c in cols])
+ _count = lit(count) if isinstance(count, int) else _to_col(count)
+
+ return _invoke_function("array_repeat", _to_col(col), _count)
-# TODO: enable with SPARK-41402
-# def decode(col: "ColumnOrName", charset: str) -> Column:
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def array_sort(
+# col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None
+# ) -> Column:
# """
-# Computes the first argument into a string from a binary using the provided character set
-# (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+# Collection function: sorts the input array in ascending order. The elements of the input array
+# must be orderable. Null elements will be placed at the end of the returned array.
#
-# .. versionadded:: 3.4.0
+# .. versionadded:: 2.4.0
+# .. versionchanged:: 3.4.0
+# Can take a `comparator` function.
#
# Parameters
# ----------
# col : :class:`~pyspark.sql.Column` or str
-# target column to work on.
-# charset : str
-# charset to use to decode to.
+# name of column or expression
+# comparator : callable, optional
+# A binary ``(Column, Column) -> Column: ...``.
+# The comparator will take two
+# arguments representing two elements of the array. It returns a negative integer, 0, or a
+# positive integer as the first element is less than, equal to, or greater than the second
+# element. If the comparator function returns null, the function will fail and raise an
+# error.
#
# Returns
# -------
# :class:`~pyspark.sql.Column`
-# the column for computed results.
+# sorted array.
#
# Examples
# --------
-# >>> df = spark.createDataFrame([('abcd',)], ['a'])
-# >>> df.select(decode("a", "UTF-8")).show()
-# +----------------------+
-# |stringdecode(a, UTF-8)|
-# +----------------------+
-# | abcd|
-# +----------------------+
+# >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
+# >>> df.select(array_sort(df.data).alias('r')).collect()
+# [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
+# >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data'])
+# >>> df.select(array_sort(
+# ... "data",
+# ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x))
+# ... ).alias("r")).collect()
+# [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])]
# """
-# return _invoke_function("decode", _to_col(col), lit(charset))
+# if comparator is None:
+# return _invoke_function_over_columns("array_sort", col)
+# else:
+# return _invoke_higher_order_function("ArraySort", [col], [comparator])
-def encode(col: "ColumnOrName", charset: str) -> Column:
+def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
"""
- Computes the first argument into a binary from a string using the provided character set
- (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ Collection function: returns an array of the elements in the union of col1 and col2,
+ without duplicates.
.. versionadded:: 3.4.0
Parameters
----------
- col : :class:`~pyspark.sql.Column` or str
- target column to work on.
- charset : str
- charset to use to encode.
+ col1 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ col2 : :class:`~pyspark.sql.Column` or str
+ name of column containing array
Returns
-------
:class:`~pyspark.sql.Column`
- the column for computed results.
+ an array of values in union of two arrays.
Examples
--------
- >>> df = spark.createDataFrame([('abcd',)], ['c'])
- >>> df.select(encode("c", "UTF-8")).show()
- +----------------+
- |encode(c, UTF-8)|
- +----------------+
- | [61 62 63 64]|
- +----------------+
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])])
+ >>> df.select(array_union(df.c1, df.c2)).collect()
+ [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])]
"""
- return _invoke_function("encode", _to_col(col), lit(charset))
+ return _invoke_function_over_columns("array_union", col1, col2)
+
+
+def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column:
+ """
+ Collection function: returns true if the arrays contain any common non-null element; if not,
+ returns null if both the arrays are non-empty and any of them contains a null element; returns
+ false otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a column of Boolean type.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
+ >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
+ [Row(overlap=True), Row(overlap=False)]
+ """
+ return _invoke_function_over_columns("arrays_overlap", a1, a2)
+
+
+def arrays_zip(*cols: "ColumnOrName") -> Column:
+ """
+ Collection function: Returns a merged array of structs in which the N-th struct contains all
+ N-th values of input arrays. If one of the arrays is shorter than others then
+ resulting struct type value will be a `null` for missing elements.
+
+ .. versionadded:: 2.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ columns of arrays to be merged.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ merged array of entries.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import arrays_zip
+ >>> df = spark.createDataFrame([(([1, 2, 3], [2, 4, 6], [3, 6]))], ['vals1', 'vals2', 'vals3'])
+ >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped'))
+ >>> df.show(truncate=False)
+ +------------------------------------+
+ |zipped |
+ +------------------------------------+
+ |[{1, 2, 3}, {2, 4, 6}, {3, 6, null}]|
+ +------------------------------------+
+ >>> df.printSchema()
+ root
+ |-- zipped: array (nullable = true)
+ | |-- element: struct (containsNull = false)
+ | | |-- vals1: long (nullable = true)
+ | | |-- vals2: long (nullable = true)
+ | | |-- vals3: long (nullable = true)
+ """
+ return _invoke_function_over_columns("arrays_zip", *cols)
+
+
+def concat(*cols: "ColumnOrName") -> Column:
+ """
+ Concatenates multiple input columns together into a single column.
+ The function works with strings, numeric, binary and compatible array columns.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ target column or columns to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ concatenated values. Type of the `Column` depends on input columns' type.
+
+ See Also
+ --------
+ :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df = df.select(concat(df.s, df.d).alias('s'))
+ >>> df.collect()
+ [Row(s='abcd123')]
+ >>> df
+ DataFrame[s: string]
+
+ >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+ >>> df = df.select(concat(df.a, df.b, df.c).alias("arr"))
+ >>> df.collect()
+ [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+ >>> df
+ DataFrame[arr: array]
+ """
+ return _invoke_function_over_columns("concat", *cols)
+
+
+def create_map(
+ *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
+ """Creates a new map column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ column names or :class:`~pyspark.sql.Column`\\s that are
+ grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
+ >>> df.select(create_map('name', 'age').alias("map")).collect()
+ [Row(map={'Alice': 2}), Row(map={'Bob': 5})]
+ >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
+ [Row(map={'Alice': 2}), Row(map={'Bob': 5})]
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ return _invoke_function_over_columns("map", *cols) # type: ignore[arg-type]
+
+
+def element_at(col: "ColumnOrName", extraction: Any) -> Column:
+ """
+ Collection function: Returns element of array at given index in `extraction` if col is array.
+ Returns value for the given key in `extraction` if col is map. If position is negative
+ then location of the element will start from end, if number is outside the
+ array boundaries then None will be returned.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing array or map
+ extraction :
+ index to check for in array or key to check for in map
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ value at given position.
+
+ Notes
+ -----
+ The position is not zero based, but 1 based index.
+
+ See Also
+ --------
+ :meth:`get`
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data'])
+ >>> df.select(element_at(df.data, 1)).collect()
+ [Row(element_at(data, 1)='a')]
+ >>> df.select(element_at(df.data, -1)).collect()
+ [Row(element_at(data, -1)='c')]
+
+ >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},)], ['data'])
+ >>> df.select(element_at(df.data, lit("a"))).collect()
+ [Row(element_at(data, a)=1.0)]
+ """
+ return _invoke_function("element_at", _to_col(col), lit(extraction))
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column:
+# """
+# Returns whether a predicate holds for one or more elements in the array.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# ``(x: Column) -> Column: ...`` returning the Boolean expression.
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 `__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# True if "any" element of an array evaluates to True when passed as an argument to
+# given function and False otherwise.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])],("key", "values"))
+# >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
+# +------------+
+# |any_negative|
+# +------------+
+# | false|
+# | true|
+# +------------+
+# """
+# return _invoke_higher_order_function("ArrayExists", [col], [f])
+
+
+def explode(col: "ColumnOrName") -> Column:
+ """
+ Returns a new row for each element in the given array or map.
+ Uses the default column name `col` for elements in the array and
+ `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ one row per array item or map key value.
+
+ See Also
+ --------
+ :meth:`pyspark.functions.posexplode`
+ :meth:`pyspark.functions.explode_outer`
+ :meth:`pyspark.functions.posexplode_outer`
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+ >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+ [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+ >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | a| b|
+ +---+-----+
+ """
+ return _invoke_function_over_columns("explode", col)
+
+
+def explode_outer(col: "ColumnOrName") -> Column:
+ """
+ Returns a new row for each element in the given array or map.
+ Unlike explode, if the array/map is null or empty then null is produced.
+ Uses the default column name `col` for elements in the array and
+ `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ one row per array item or map key value.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(
+ ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+ ... ("id", "an_array", "a_map")
+ ... )
+ >>> df.select("id", "an_array", explode_outer("a_map")).show()
+ +---+----------+----+-----+
+ | id| an_array| key|value|
+ +---+----------+----+-----+
+ | 1|[foo, bar]| x| 1.0|
+ | 2| []|null| null|
+ | 3| null|null| null|
+ +---+----------+----+-----+
+
+ >>> df.select("id", "a_map", explode_outer("an_array")).show()
+ +---+----------+----+
+ | id| a_map| col|
+ +---+----------+----+
+ | 1|{x -> 1.0}| foo|
+ | 1|{x -> 1.0}| bar|
+ | 2| {}|null|
+ | 3| null|null|
+ +---+----------+----+
+ """
+ return _invoke_function_over_columns("explode_outer", col)
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def filter(
+# col: "ColumnOrName",
+# f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]],
+# ) -> Column:
+# """
+# Returns an array of elements for which a predicate holds in a given array.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# A function that returns the Boolean expression.
+# Can take one of the following forms:
+#
+# - Unary ``(x: Column) -> Column: ...``
+# - Binary ``(x: Column, i: Column) -> Column...``, where the second argument is
+# a 0-based index of the element.
+#
+# and can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 `__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# filtered array of elements where given function evaluated to True
+# when passed as an argument.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame(
+# ... [(1, ["2018-09-20", "2019-02-03", "2019-07-01", "2020-06-01"])],
+# ... ("key", "values")
+# ... )
+# >>> def after_second_quarter(x):
+# ... return month(to_date(x)) > 6
+# >>> df.select(
+# ... filter("values", after_second_quarter).alias("after_second_quarter")
+# ... ).show(truncate=False)
+# +------------------------+
+# |after_second_quarter |
+# +------------------------+
+# |[2018-09-20, 2019-07-01]|
+# +------------------------+
+# """
+# return _invoke_higher_order_function("ArrayFilter", [col], [f])
+
+
+def flatten(col: "ColumnOrName") -> Column:
+ """
+ Collection function: creates a single array from an array of arrays.
+ If a structure of nested arrays is deeper than two levels,
+ only one level of nesting is removed.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ flattened array.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
+ >>> df.show(truncate=False)
+ +------------------------+
+ |data |
+ +------------------------+
+ |[[1, 2, 3], [4, 5], [6]]|
+ |[null, [4, 5]] |
+ +------------------------+
+ >>> df.select(flatten(df.data).alias('r')).show()
+ +------------------+
+ | r|
+ +------------------+
+ |[1, 2, 3, 4, 5, 6]|
+ | null|
+ +------------------+
+ """
+ return _invoke_function_over_columns("flatten", col)
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def forall(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column:
+# """
+# Returns whether a predicate holds for every element in the array.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# ``(x: Column) -> Column: ...`` returning the Boolean expression.
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 `__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# True if "all" elements of an array evaluates to True when passed as an argument to
+# given function and False otherwise.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame(
+# ... [(1, ["bar"]), (2, ["foo", "bar"]), (3, ["foobar", "foo"])],
+# ... ("key", "values")
+# ... )
+# >>> df.select(forall("values", lambda x: x.rlike("foo")).alias("all_foo")).show()
+# +-------+
+# |all_foo|
+# +-------+
+# | false|
+# | false|
+# | true|
+# +-------+
+# """
+# return _invoke_higher_order_function("ArrayForAll", [col], [f])
+
+
+# TODO: support options
+def from_csv(
+ col: "ColumnOrName",
+ schema: Union[Column, str],
+) -> Column:
+ """
+ Parses a column containing a CSV string to a row with the specified schema.
+ Returns `null`, in the case of an unparseable string.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ a column or column name in CSV format
+ schema :class:`~pyspark.sql.Column` or str
+ a column, or Python string literal with schema in DDL format, to use
+ when parsing the CSV column.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a column of parsed CSV values
+
+ Examples
+ --------
+ >>> data = [("1,2,3",)]
+ >>> df = spark.createDataFrame(data, ("value",))
+ >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect()
+ [Row(csv=Row(a=1, b=2, c=3))]
+ >>> value = data[0][0]
+ >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
+ [Row(csv=Row(_c0=1, _c1=2, _c2=3))]
+ >>> data = [(" abc",)]
+ >>> df = spark.createDataFrame(data, ("value",))
+ >>> options = {'ignoreLeadingWhiteSpace': True}
+ >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
+ [Row(csv=Row(s='abc'))]
+ """
+
+ if isinstance(schema, Column):
+ _schema = schema
+ elif isinstance(schema, str):
+ _schema = lit(schema)
+ else:
+ raise TypeError(f"schema should be a Column or str, but got {type(schema).__name__}")
+
+ return _invoke_function("from_csv", _to_col(col), _schema)
+
+
+# TODO: 1, support ArrayType and StructType schema; 2, support options
+def from_json(
+ col: "ColumnOrName",
+ schema: Union[Column, str],
+) -> Column:
+ """
+ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
+ as keys type, :class:`StructType` or :class:`ArrayType` with
+ the specified schema. Returns `null`, in the case of an unparseable string.
+
+ .. versionadded:: 2.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ a column or column name in JSON format
+ schema :class:`~pyspark.sql.Column` or str
+ a column, or Python string literal with schema in DDL format, to use when
+ parsing the JSON column.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a new column of complex type from given JSON object.
+
+ Examples
+ --------
+ >>> from pyspark.sql.types import *
+ >>> data = [(1, '''{"a": 1}''')]
+ >>> schema = StructType([StructField("a", IntegerType())])
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=Row(a=1))]
+ >>> df.select(from_json(df.value, "a INT").alias("json")).collect()
+ [Row(json=Row(a=1))]
+ >>> df.select(from_json(df.value, "MAP").alias("json")).collect()
+ [Row(json={'a': 1})]
+ >>> data = [(1, '''[{"a": 1}]''')]
+ >>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=[Row(a=1)])]
+ >>> schema = schema_of_json(lit('''{"a": 0}'''))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=Row(a=None))]
+ >>> data = [(1, '''[1, 2, 3]''')]
+ >>> schema = ArrayType(IntegerType())
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=[1, 2, 3])]
+ """
+
+ if isinstance(schema, Column):
+ _schema = schema
+ elif isinstance(schema, str):
+ _schema = lit(schema)
+ else:
+ raise TypeError(f"schema should be a Column or str, but got {type(schema).__name__}")
+
+ return _invoke_function("from_json", _to_col(col), _schema)
+
+
+def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column:
+ """
+ Collection function: Returns element of array at given (0-based) index.
+ If the index points outside of the array boundaries, then this function
+ returns NULL.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing array
+ index : :class:`~pyspark.sql.Column` or str or int
+ index to check for in array
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ value at given position.
+
+ Notes
+ -----
+ The position is 0-based, not 1-based.
+
+ See Also
+ --------
+ :meth:`element_at`
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])
+ >>> df.select(get(df.data, 1)).show()
+ +------------+
+ |get(data, 1)|
+ +------------+
+ | b|
+ +------------+
+
+ >>> df.select(get(df.data, -1)).show()
+ +-------------+
+ |get(data, -1)|
+ +-------------+
+ | null|
+ +-------------+
+
+ >>> df.select(get(df.data, 3)).show()
+ +------------+
+ |get(data, 3)|
+ +------------+
+ | null|
+ +------------+
+
+ >>> df.select(get(df.data, "index")).show()
+ +----------------+
+ |get(data, index)|
+ +----------------+
+ | b|
+ +----------------+
+
+ >>> df.select(get(df.data, col("index") - 1)).show()
+ +----------------------+
+ |get(data, (index - 1))|
+ +----------------------+
+ | a|
+ +----------------------+
+ """
+ index = lit(index) if isinstance(index, int) else index
+
+ return _invoke_function_over_columns("get", col, index)
+
+
+def get_json_object(col: "ColumnOrName", path: str) -> Column:
+ """
+ Extracts json object from a json string based on json `path` specified, and returns json string
+ of the extracted json object. It will return null if the input json string is invalid.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ string column in json format
+ path : str
+ path to the json object to extract
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ string representation of given JSON object value.
+
+ Examples
+ --------
+ >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
+ >>> df = spark.createDataFrame(data, ("key", "jstring"))
+ >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\
+ ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect()
+ [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
+ """
+ return _invoke_function("get_json_object", _to_col(col), lit(path))
+
+
+def inline(col: "ColumnOrName") -> Column:
+ """
+ Explodes an array of structs into a table.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ input column of values to explode.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ generator expression with the inline exploded result.
+
+ See Also
+ --------
+ :meth:`explode`
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([Row(structlist=[Row(a=1, b=2), Row(a=3, b=4)])])
+ >>> df.select(inline(df.structlist)).show()
+ +---+---+
+ | a| b|
+ +---+---+
+ | 1| 2|
+ | 3| 4|
+ +---+---+
+ """
+ return _invoke_function_over_columns("inline", col)
+
+
+def inline_outer(col: "ColumnOrName") -> Column:
+ """
+ Explodes an array of structs into a table.
+ Unlike inline, if the array is null or empty then null is produced for each nested column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ input column of values to explode.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ generator expression with the inline exploded result.
+
+ See Also
+ --------
+ :meth:`explode_outer`
+ :meth:`inline`
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> df = spark.createDataFrame([
+ ... Row(id=1, structlist=[Row(a=1, b=2), Row(a=3, b=4)]),
+ ... Row(id=2, structlist=[])
+ ... ])
+ >>> df.select('id', inline_outer(df.structlist)).show()
+ +---+----+----+
+ | id| a| b|
+ +---+----+----+
+ | 1| 1| 2|
+ | 1| 3| 4|
+ | 2|null|null|
+ +---+----+----+
+ """
+ return _invoke_function_over_columns("inline_outer", col)
+
+
+def json_tuple(col: "ColumnOrName", *fields: str) -> Column:
+ """Creates a new row for a json column according to the given field names.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ string column in json format
+ fields : str
+ a field or fields to extract
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a new row for each given field value from json object
+
+ Examples
+ --------
+ >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
+ >>> df = spark.createDataFrame(data, ("key", "jstring"))
+ >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect()
+ [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)]
+ """
+
+ return _invoke_function("json_tuple", _to_col(col), *[lit(field) for field in fields])
+
+
+def map_concat(
+ *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
+ """Returns the union of all the given maps.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ column names or :class:`~pyspark.sql.Column`\\s
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a map of merged entries from other maps.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_concat
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2")
+ >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+ +------------------------+
+ |map3 |
+ +------------------------+
+ |{1 -> a, 2 -> b, 3 -> c}|
+ +------------------------+
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ return _invoke_function_over_columns("map_concat", *cols) # type: ignore[arg-type]
+
+
+def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
+ """
+ Returns true if the map contains the key.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+ value :
+ a literal value
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ True if key is in the map and False otherwise.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_contains_key
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_contains_key("data", 1)).show()
+ +---------------------------------+
+ |array_contains(map_keys(data), 1)|
+ +---------------------------------+
+ | true|
+ +---------------------------------+
+ >>> df.select(map_contains_key("data", -1)).show()
+ +----------------------------------+
+ |array_contains(map_keys(data), -1)|
+ +----------------------------------+
+ | false|
+ +----------------------------------+
+ """
+ return array_contains(map_keys(col), lit(value))
+
+
+def map_entries(col: "ColumnOrName") -> Column:
+ """
+ Collection function: Returns an unordered array of all entries in the given map.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ an array of key-value pairs as a struct type
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_entries
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df = df.select(map_entries("data").alias("entries"))
+ >>> df.show()
+ +----------------+
+ | entries|
+ +----------------+
+ |[{1, a}, {2, b}]|
+ +----------------+
+ >>> df.printSchema()
+ root
+ |-- entries: array (nullable = false)
+ | |-- element: struct (containsNull = false)
+ | | |-- key: integer (nullable = false)
+ | | |-- value: string (nullable = false)
+ """
+ return _invoke_function_over_columns("map_entries", col)
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column:
+# """
+# Returns a map whose key-value pairs satisfy a predicate.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# a binary function ``(k: Column, v: Column) -> Column...``
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# filtered map.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, {"foo": 42.0, "bar": 1.0, "baz": 32.0})], ("id", "data"))
+# >>> df.select(map_filter(
+# ... "data", lambda _, v: v > 30.0).alias("data_filtered")
+# ... ).show(truncate=False)
+# +--------------------------+
+# |data_filtered |
+# +--------------------------+
+# |{baz -> 32.0, foo -> 42.0}|
+# +--------------------------+
+# """
+# return _invoke_higher_order_function("MapFilter", [col], [f])
+
+
+def map_from_arrays(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+ """Creates a new map from two arrays.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col1 : :class:`~pyspark.sql.Column` or str
+ name of column containing a set of keys. All elements should not be null
+ col2 : :class:`~pyspark.sql.Column` or str
+ name of column containing a set of values
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a column of map type.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
+ >>> df = df.select(map_from_arrays(df.k, df.v).alias("col"))
+ >>> df.show()
+ +----------------+
+ | col|
+ +----------------+
+ |{2 -> a, 5 -> b}|
+ +----------------+
+ >>> df.printSchema()
+ root
+ |-- col: map (nullable = true)
+ | |-- key: long
+ | |-- value: string (valueContainsNull = true)
+ """
+ return _invoke_function_over_columns("map_from_arrays", col1, col2)
+
+
+def map_from_entries(col: "ColumnOrName") -> Column:
+ """
+ Collection function: Converts an array of entries (key-value struct types) to a map
+ of values.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a map created from the given array of entries.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_from_entries
+ >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
+ >>> df.select(map_from_entries("data").alias("map")).show()
+ +----------------+
+ | map|
+ +----------------+
+ |{1 -> a, 2 -> b}|
+ +----------------+
+ """
+ return _invoke_function_over_columns("map_from_entries", col)
+
+
+def map_keys(col: "ColumnOrName") -> Column:
+ """
+ Collection function: Returns an unordered array containing the keys of the map.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ keys of the map as an array.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_keys
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_keys("data").alias("keys")).show()
+ +------+
+ | keys|
+ +------+
+ |[1, 2]|
+ +------+
+ """
+ return _invoke_function_over_columns("map_keys", col)
+
+
+def map_values(col: "ColumnOrName") -> Column:
+ """
+ Collection function: Returns an unordered array containing the values of the map.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ values of the map as an array.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import map_values
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_values("data").alias("values")).show()
+ +------+
+ |values|
+ +------+
+ |[a, b]|
+ +------+
+ """
+ return _invoke_function_over_columns("map_values", col)
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def map_zip_with(
+# col1: "ColumnOrName",
+# col2: "ColumnOrName",
+# f: Callable[[Column, Column, Column], Column],
+# ) -> Column:
+# """
+# Merge two given maps, key-wise into a single map using a function.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col1 : :class:`~pyspark.sql.Column` or str
+# name of the first column or expression
+# col2 : :class:`~pyspark.sql.Column` or str
+# name of the second column or expression
+# f : function
+# a ternary function ``(k: Column, v1: Column, v2: Column) -> Column...``
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# zipped map where entries are calculated by applying given function to each
+# pair of arguments.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([
+# ... (1, {"IT": 24.0, "SALES": 12.00}, {"IT": 2.0, "SALES": 1.4})],
+# ... ("id", "base", "ratio")
+# ... )
+# >>> df.select(map_zip_with(
+# ... "base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data")
+# ... ).show(truncate=False)
+# +---------------------------+
+# |updated_data |
+# +---------------------------+
+# |{SALES -> 16.8, IT -> 48.0}|
+# +---------------------------+
+# """
+# return _invoke_higher_order_function("MapZipWith", [col1, col2], [f])
+
+
+def posexplode(col: "ColumnOrName") -> Column:
+ """
+ Returns a new row for each element with position in the given array or map.
+ Uses the default column name `pos` for position, and `col` for elements in the
+ array and `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ one row per array item or map key value including positions as a separate column.
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+ >>> eDF.select(posexplode(eDF.intlist)).collect()
+ [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
+
+ >>> eDF.select(posexplode(eDF.mapfield)).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| b|
+ +---+---+-----+
+ """
+ return _invoke_function_over_columns("posexplode", col)
+
+
+def posexplode_outer(col: "ColumnOrName") -> Column:
+ """
+ Returns a new row for each element with position in the given array or map.
+ Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+ Uses the default column name `pos` for position, and `col` for elements in the
+ array and `key` and `value` for elements in the map unless specified otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ one row per array item or map key value including positions as a separate column.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(
+ ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+ ... ("id", "an_array", "a_map")
+ ... )
+ >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+ +---+----------+----+----+-----+
+ | id| an_array| pos| key|value|
+ +---+----------+----+----+-----+
+ | 1|[foo, bar]| 0| x| 1.0|
+ | 2| []|null|null| null|
+ | 3| null|null|null| null|
+ +---+----------+----+----+-----+
+ >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+ +---+----------+----+----+
+ | id| a_map| pos| col|
+ +---+----------+----+----+
+ | 1|{x -> 1.0}| 0| foo|
+ | 1|{x -> 1.0}| 1| bar|
+ | 2| {}|null|null|
+ | 3| null|null|null|
+ +---+----------+----+----+
+ """
+ return _invoke_function_over_columns("posexplode_outer", col)
+
+
+def reverse(col: "ColumnOrName") -> Column:
+ """
+ Collection function: returns a reversed string or an array with reverse order of elements.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ array of elements in reverse order.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
+ >>> df.select(reverse(df.data).alias('s')).collect()
+ [Row(s='LQS krapS')]
+ >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
+ >>> df.select(reverse(df.data).alias('r')).collect()
+ [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
+ """
+ return _invoke_function_over_columns("reverse", col)
+
+
+# TODO(SPARK-41493): Support options
+def schema_of_csv(csv: "ColumnOrName") -> Column:
+ """
+ Parses a CSV string and infers its schema in DDL format.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ csv : :class:`~pyspark.sql.Column` or str
+ a CSV string or a foldable string column containing a CSV string.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a string representation of a :class:`StructType` parsed from given CSV.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect()
+ [Row(csv='STRUCT<_c0: INT, _c1: STRING>')]
+ >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect()
+ [Row(csv='STRUCT<_c0: INT, _c1: STRING>')]
+ """
+
+ if isinstance(csv, Column):
+ _csv = csv
+ elif isinstance(csv, str):
+ _csv = lit(csv)
+ else:
+ raise TypeError(f"csv should be a Column or str, but got {type(csv).__name__}")
+
+ return _invoke_function("schema_of_csv", _csv)
+
+
+# TODO(SPARK-41494): Support options
+def schema_of_json(json: "ColumnOrName") -> Column:
+ """
+ Parses a JSON string and infers its schema in DDL format.
+
+ .. versionadded:: 2.4.0
+
+ Parameters
+ ----------
+ json : :class:`~pyspark.sql.Column` or str
+ a JSON string or a foldable string column containing a JSON string.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a string representation of a :class:`StructType` parsed from given JSON.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
+ [Row(json='STRUCT<a: BIGINT>')]
+ >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
+ >>> df.select(schema.alias("json")).collect()
+ [Row(json='STRUCT<a: BIGINT>')]
+ """
+
+ if isinstance(json, Column):
+ _json = json
+ elif isinstance(json, str):
+ _json = lit(json)
+ else:
+ raise TypeError(f"json should be a Column or str, but got {type(json).__name__}")
+
+ return _invoke_function("schema_of_json", _json)
+
+
+def shuffle(col: "ColumnOrName") -> Column:
+ """
+ Collection function: Generates a random permutation of the given array.
+
+ .. versionadded:: 3.4.0
+
+ Notes
+ -----
+ The function is non-deterministic.
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ an array of elements in random order.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
+ >>> df.select(shuffle(df.data).alias('s')).collect() # doctest: +SKIP
+ [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]
+ """
+ return _invoke_function_over_columns("shuffle", col)
+
+
+def size(col: "ColumnOrName") -> Column:
+ """
+ Collection function: returns the length of the array or map stored in the column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ length of the array/map.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+ >>> df.select(size(df.data)).collect()
+ [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
+ """
+ return _invoke_function_over_columns("size", col)
+
+
+def slice(
+ col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
+) -> Column:
+ """
+ Collection function: returns an array containing all the elements in `col` from index `start`
+ (array indices start at 1, or from the end if `start` is negative) with the specified `length`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ column name or column containing the array to be sliced
+ start : :class:`~pyspark.sql.Column` or str or int
+ column name, column, or int containing the starting index
+ length : :class:`~pyspark.sql.Column` or str or int
+ column name, column, or int containing the length of the slice
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a column of array type. Subset of array.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+ >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
+ [Row(sliced=[2, 3]), Row(sliced=[5])]
+ """
+ if isinstance(start, Column):
+ _start = start
+ elif isinstance(start, int):
+ _start = lit(start)
+ else:
+ raise TypeError(f"start should be a Column or int, but got {type(start).__name__}")
+
+ if isinstance(length, Column):
+ _length = length
+ elif isinstance(length, int):
+ _length = lit(length)
+ else:
+ raise TypeError(f"start should be a Column or int, but got {type(length).__name__}")
+
+ return _invoke_function("slice", _to_col(col), _start, _length)
+
+
+def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
+ """
+ Collection function: sorts the input array in ascending or descending order according
+ to the natural ordering of the array elements. Null elements will be placed at the beginning
+ of the returned array in ascending order or at the end of the returned array in descending
+ order.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column or expression
+ asc : bool, optional
+ whether to sort in ascending or descending order. If `asc` is True (default)
+ then ascending and if False then descending.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ sorted array.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
+ >>> df.select(sort_array(df.data).alias('r')).collect()
+ [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
+ >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
+ [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
+ """
+ return _invoke_function("sort_array", _to_col(col), lit(asc))
+
+
+def struct(
+ *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
+ """Creates a new struct column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : list, set, str or :class:`~pyspark.sql.Column`
+ column names or :class:`~pyspark.sql.Column`\\s to contain in the output struct.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a struct type column of given columns.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age"))
+ >>> df.select(struct('age', 'name').alias("struct")).collect()
+ [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]
+ >>> df.select(struct([df.age, df.name]).alias("struct")).collect()
+ [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))]
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
+ cols = cols[0] # type: ignore[assignment]
+ return _invoke_function_over_columns("struct", *cols) # type: ignore[arg-type]
+
+
+# TODO(SPARK-41493): Support options
+def to_csv(col: "ColumnOrName") -> Column:
+ """
+ Converts a column containing a :class:`StructType` into a CSV string.
+ Throws an exception in the case of an unsupported type.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing a struct.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a CSV string converted from given :class:`StructType`.
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> data = [(1, Row(age=2, name='Alice'))]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_csv(df.value).alias("csv")).collect()
+ [Row(csv='2,Alice')]
+ """
+
+ return _invoke_function("to_csv", _to_col(col))
+
+
+# TODO(SPARK-41494): Support options
+def to_json(col: "ColumnOrName") -> Column:
+ """
+ Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType`
+ into a JSON string. Throws an exception in the case of an unsupported type.
+
+ .. versionadded:: 2.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ name of column containing a struct, an array or a map.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ JSON object as string column.
+
+ Examples
+ --------
+ >>> from pyspark.sql import Row
+ >>> from pyspark.sql.types import *
+ >>> data = [(1, Row(age=2, name='Alice'))]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_json(df.value).alias("json")).collect()
+ [Row(json='{"age":2,"name":"Alice"}')]
+ >>> data = [(1, [Row(age=2, name='Alice'), Row(age=3, name='Bob')])]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_json(df.value).alias("json")).collect()
+ [Row(json='[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')]
+ >>> data = [(1, {"name": "Alice"})]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_json(df.value).alias("json")).collect()
+ [Row(json='{"name":"Alice"}')]
+ >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_json(df.value).alias("json")).collect()
+ [Row(json='[{"name":"Alice"},{"name":"Bob"}]')]
+ >>> data = [(1, ["Alice", "Bob"])]
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(to_json(df.value).alias("json")).collect()
+ [Row(json='["Alice","Bob"]')]
+ """
+
+ return _invoke_function("to_json", _to_col(col))
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def transform(
+# col: "ColumnOrName",
+# f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]],
+# ) -> Column:
+# """
+# Returns an array of elements after applying a transformation to each element in
+# the input array.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# a function that is applied to each element of the input array.
+# Can take one of the following forms:
+#
+# - Unary ``(x: Column) -> Column: ...``
+# - Binary ``(x: Column, i: Column) -> Column...``, where the second argument is
+# a 0-based index of the element.
+#
+# and can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# a new array of transformed elements.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))
+# >>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show()
+# +------------+
+# | doubled|
+# +------------+
+# |[2, 4, 6, 8]|
+# +------------+
+#
+# >>> def alternate(x, i):
+# ... return when(i % 2 == 0, x).otherwise(-x)
+# >>> df.select(transform("values", alternate).alias("alternated")).show()
+# +--------------+
+# | alternated|
+# +--------------+
+# |[1, -2, 3, -4]|
+# +--------------+
+# """
+# return _invoke_higher_order_function("ArrayTransform", [col], [f])
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column:
+# """
+# Applies a function to every key-value pair in a map and returns
+# a map with the results of those applications as the new keys for the pairs.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# a binary function ``(k: Column, v: Column) -> Column...``
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# a new map of entries where new keys were calculated by applying the given function to
+# each key-value pair.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data"))
+# >>> df.select(transform_keys(
+# ... "data", lambda k, _: upper(k)).alias("data_upper")
+# ... ).show(truncate=False)
+# +-------------------------+
+# |data_upper |
+# +-------------------------+
+# |{BAR -> 2.0, FOO -> -2.0}|
+# +-------------------------+
+# """
+# return _invoke_higher_order_function("TransformKeys", [col], [f])
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column:
+# """
+# Applies a function to every key-value pair in a map and returns
+# a map with the results of those applications as the new values for the pairs.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# col : :class:`~pyspark.sql.Column` or str
+# name of column or expression
+# f : function
+# a binary function ``(k: Column, v: Column) -> Column...``
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# a new map of entries where new values were calculated by applying the given function to
+# each key-value pair.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, {"IT": 10.0, "SALES": 2.0, "OPS": 24.0})], ("id", "data"))
+# >>> df.select(transform_values(
+# ... "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v)
+# ... ).alias("new_data")).show(truncate=False)
+# +---------------------------------------+
+# |new_data |
+# +---------------------------------------+
+# |{OPS -> 34.0, IT -> 20.0, SALES -> 2.0}|
+# +---------------------------------------+
+# """
+# return _invoke_higher_order_function("TransformValues", [col], [f])
+
+
+# TODO(SPARK-41434): need to support LambdaFunction Expression first
+# def zip_with(
+# left: "ColumnOrName",
+# right: "ColumnOrName",
+# f: Callable[[Column, Column], Column],
+# ) -> Column:
+# """
+# Merge two given arrays, element-wise, into a single array using a function.
+# If one array is shorter, nulls are appended at the end to match the length of the longer
+# array, before applying the function.
+#
+# .. versionadded:: 3.1.0
+#
+# Parameters
+# ----------
+# left : :class:`~pyspark.sql.Column` or str
+# name of the first column or expression
+# right : :class:`~pyspark.sql.Column` or str
+# name of the second column or expression
+# f : function
+# a binary function ``(x1: Column, x2: Column) -> Column...``
+# Can use methods of :class:`~pyspark.sql.Column`, functions defined in
+# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``.
+# Python ``UserDefinedFunctions`` are not supported
+# (`SPARK-27052 <https://issues.apache.org/jira/browse/SPARK-27052>`__).
+#
+# Returns
+# -------
+# :class:`~pyspark.sql.Column`
+# array of calculated values derived by applying given function to each pair of arguments.
+#
+# Examples
+# --------
+# >>> df = spark.createDataFrame([(1, [1, 3, 5, 8], [0, 2, 4, 6])], ("id", "xs", "ys"))
+# >>> df.select(zip_with("xs", "ys", lambda x, y: x ** y).alias("powers")).show(truncate=False)
+# +---------------------------+
+# |powers |
+# +---------------------------+
+# |[1.0, 9.0, 625.0, 262144.0]|
+# +---------------------------+
+#
+# >>> df = spark.createDataFrame([(1, ["foo", "bar"], [1, 2, 3])], ("id", "xs", "ys"))
+# >>> df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("xs_ys")).show()
+# +-----------------+
+# | xs_ys|
+# +-----------------+
+# |[foo_1, bar_2, 3]|
+# +-----------------+
+# """
+# return _invoke_higher_order_function("ZipWith", [left, right], [f])
+
+
+# String/Binary functions
+
+
+def upper(col: "ColumnOrName") -> Column:
+ """
+ Converts a string expression to upper case.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ upper case values.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
+ >>> df.select(upper("value")).show()
+ +------------+
+ |upper(value)|
+ +------------+
+ | SPARK|
+ | PYSPARK|
+ | PANDAS API|
+ +------------+
+ """
+ return _invoke_function_over_columns("upper", col)
+
+
+def lower(col: "ColumnOrName") -> Column:
+ """
+ Converts a string expression to lower case.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ lower case values.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
+ >>> df.select(lower("value")).show()
+ +------------+
+ |lower(value)|
+ +------------+
+ | spark|
+ | pyspark|
+ | pandas api|
+ +------------+
+ """
+ return _invoke_function_over_columns("lower", col)
+
+
+def ascii(col: "ColumnOrName") -> Column:
+ """
+ Computes the numeric value of the first character of the string column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ numeric value.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
+ >>> df.select(ascii("value")).show()
+ +------------+
+ |ascii(value)|
+ +------------+
+ | 83|
+ | 80|
+ | 80|
+ +------------+
+ """
+ return _invoke_function_over_columns("ascii", col)
+
+
+def base64(col: "ColumnOrName") -> Column:
+ """
+ Computes the BASE64 encoding of a binary column and returns it as a string column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ BASE64 encoding of string value.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING")
+ >>> df.select(base64("value")).show()
+ +----------------+
+ | base64(value)|
+ +----------------+
+ | U3Bhcms=|
+ | UHlTcGFyaw==|
+ |UGFuZGFzIEFQSQ==|
+ +----------------+
+ """
+ return _invoke_function_over_columns("base64", col)
+
+
+def unbase64(col: "ColumnOrName") -> Column:
+ """
+ Decodes a BASE64 encoded string column and returns it as a binary column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ encoded string value.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(["U3Bhcms=",
+ ... "UHlTcGFyaw==",
+ ... "UGFuZGFzIEFQSQ=="], "STRING")
+ >>> df.select(unbase64("value")).show()
+ +--------------------+
+ | unbase64(value)|
+ +--------------------+
+ | [53 70 61 72 6B]|
+ |[50 79 53 70 61 7...|
+ |[50 61 6E 64 61 7...|
+ +--------------------+
+ """
+ return _invoke_function_over_columns("unbase64", col)
+
+
+def ltrim(col: "ColumnOrName") -> Column:
+ """
+ Trim the spaces from left end for the specified string value.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ left trimmed values.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
+ >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show()
+ +-------+------+
+ | r|length|
+ +-------+------+
+ | Spark| 5|
+ |Spark | 7|
+ | Spark| 5|
+ +-------+------+
+ """
+ return _invoke_function_over_columns("ltrim", col)
+
+
+def rtrim(col: "ColumnOrName") -> Column:
+ """
+ Trim the spaces from right end for the specified string value.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ right trimmed values.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
+ >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show()
+ +--------+------+
+ | r|length|
+ +--------+------+
+ | Spark| 8|
+ | Spark| 5|
+ | Spark| 6|
+ +--------+------+
+ """
+ return _invoke_function_over_columns("rtrim", col)
+
+
+def trim(col: "ColumnOrName") -> Column:
+ """
+ Trim the spaces from both ends for the specified string column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ trimmed values from both sides.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING")
+ >>> df.select(trim("value").alias("r")).withColumn("length", length("r")).show()
+ +-----+------+
+ | r|length|
+ +-----+------+
+ |Spark| 5|
+ |Spark| 5|
+ |Spark| 5|
+ +-----+------+
+ """
+ return _invoke_function_over_columns("trim", col)
+
+
+def concat_ws(sep: str, *cols: "ColumnOrName") -> Column:
+ """
+ Concatenates multiple input string columns together into a single string column,
+ using the given separator.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ sep : str
+ words separator.
+ cols : :class:`~pyspark.sql.Column` or str
+ list of columns to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ string of concatenated words.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect()
+ [Row(s='abcd-123')]
+ """
+ return _invoke_function("concat_ws", lit(sep), *[_to_col(c) for c in cols])
+
+
+def decode(col: "ColumnOrName", charset: str) -> Column:
+ """
+ Computes the first argument into a string from a binary using the provided character set
+ (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+ charset : str
+ charset to use to decode the column.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('abcd',)], ['a'])
+ >>> df.select(decode("a", "UTF-8")).show()
+ +----------------+
+ |decode(a, UTF-8)|
+ +----------------+
+ | abcd|
+ +----------------+
+ """
+ return _invoke_function("decode", _to_col(col), lit(charset))
+
+
+def encode(col: "ColumnOrName", charset: str) -> Column:
+ """
+ Computes the first argument into a binary from a string using the provided character set
+ (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to work on.
+ charset : str
+ charset to use to encode.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('abcd',)], ['c'])
+ >>> df.select(encode("c", "UTF-8")).show()
+ +----------------+
+ |encode(c, UTF-8)|
+ +----------------+
+ | [61 62 63 64]|
+ +----------------+
+ """
+ return _invoke_function("encode", _to_col(col), lit(charset))
+
+
+# Date/Timestamp functions
+# TODO(SPARK-41283): Resolve dtypes inconsistencies for:
+# to_timestamp, from_utc_timestamp, to_utc_timestamp,
+# timestamp_seconds, current_timestamp, date_trunc
+
+
+def current_date() -> Column:
+ """
+ Returns the current date at the start of query evaluation as a :class:`DateType` column.
+ All calls of current_date within the same query return the same value.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ current date.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(current_date()).show() # doctest: +SKIP
+ +--------------+
+ |current_date()|
+ +--------------+
+ | 2022-08-26|
+ +--------------+
+ """
+ return _invoke_function("current_date")
+
+
+def current_timestamp() -> Column:
+ """
+ Returns the current timestamp at the start of query evaluation as a :class:`TimestampType`
+ column. All calls of current_timestamp within the same query return the same value.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ current date and time.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(current_timestamp()).show(truncate=False) # doctest: +SKIP
+ +-----------------------+
+ |current_timestamp() |
+ +-----------------------+
+ |2022-08-26 21:23:22.716|
+ +-----------------------+
+ """
+ return _invoke_function("current_timestamp")
+
+
+def localtimestamp() -> Column:
+ """
+ Returns the current timestamp without time zone at the start of query evaluation
+ as a timestamp without time zone column. All calls of localtimestamp within the
+ same query return the same value.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ current local date and time.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(localtimestamp()).show(truncate=False) # doctest: +SKIP
+ +-----------------------+
+ |localtimestamp() |
+ +-----------------------+
+ |2022-08-26 21:28:34.639|
+ +-----------------------+
+ """
+ return _invoke_function("localtimestamp")
+
+
+def date_format(date: "ColumnOrName", format: str) -> Column:
+ """
+ Converts a date/timestamp/string to a value of string in the format specified by the date
+ format given by the second argument.
+
+ A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All
+ pattern letters of `datetime pattern`_ can be used.
+
+ .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
+
+ .. versionadded:: 3.4.0
+
+ Notes
+ -----
+ Whenever possible, use specialized functions like `year`.
+
+ Parameters
+ ----------
+ date : :class:`~pyspark.sql.Column` or str
+ input column of values to format.
+ format: str
+ format to use to represent datetime values.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ string value representing formatted datetime.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect()
+ [Row(date='04/08/2015')]
+ """
+ return _invoke_function("date_format", _to_col(date), lit(format))
+
+
+def year(col: "ColumnOrName") -> Column:
+ """
+ Extract the year of a given date/timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ year part of the date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(year('dt').alias('year')).collect()
+ [Row(year=2015)]
+ """
+ return _invoke_function_over_columns("year", col)
+
+
+def quarter(col: "ColumnOrName") -> Column:
+ """
+ Extract the quarter of a given date/timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ quarter of the date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(quarter('dt').alias('quarter')).collect()
+ [Row(quarter=2)]
+ """
+ return _invoke_function_over_columns("quarter", col)
+
+
+def month(col: "ColumnOrName") -> Column:
+ """
+ Extract the month of a given date/timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ month part of the date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(month('dt').alias('month')).collect()
+ [Row(month=4)]
+ """
+ return _invoke_function_over_columns("month", col)
+
+
+def dayofweek(col: "ColumnOrName") -> Column:
+ """
+ Extract the day of the week of a given date/timestamp as integer.
+ Ranges from 1 for a Sunday through to 7 for a Saturday
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ day of the week for given date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(dayofweek('dt').alias('day')).collect()
+ [Row(day=4)]
+ """
+ return _invoke_function_over_columns("dayofweek", col)
+
+
+def dayofmonth(col: "ColumnOrName") -> Column:
+ """
+ Extract the day of the month of a given date/timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ day of the month for given date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(dayofmonth('dt').alias('day')).collect()
+ [Row(day=8)]
+ """
+ return _invoke_function_over_columns("dayofmonth", col)
+
+
+def dayofyear(col: "ColumnOrName") -> Column:
+ """
+ Extract the day of the year of a given date/timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ day of the year for given date/timestamp as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(dayofyear('dt').alias('day')).collect()
+ [Row(day=98)]
+ """
+ return _invoke_function_over_columns("dayofyear", col)
+
+
+def hour(col: "ColumnOrName") -> Column:
+ """
+ Extract the hours of a given timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ hour part of the timestamp as integer.
+
+ Examples
+ --------
+ >>> import datetime
+ >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts'])
+ >>> df.select(hour('ts').alias('hour')).collect()
+ [Row(hour=13)]
+ """
+ return _invoke_function_over_columns("hour", col)
+
+
+def minute(col: "ColumnOrName") -> Column:
+ """
+ Extract the minutes of a given timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ minutes part of the timestamp as integer.
+
+ Examples
+ --------
+ >>> import datetime
+ >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts'])
+ >>> df.select(minute('ts').alias('minute')).collect()
+ [Row(minute=8)]
+ """
+ return _invoke_function_over_columns("minute", col)
+
+
+def second(col: "ColumnOrName") -> Column:
+ """
+ Extract the seconds of a given timestamp as integer.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target date/timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ `seconds` part of the timestamp as integer.
+
+ Examples
+ --------
+ >>> import datetime
+ >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts'])
+ >>> df.select(second('ts').alias('second')).collect()
+ [Row(second=15)]
+ """
+ return _invoke_function_over_columns("second", col)
+
+
+def weekofyear(col: "ColumnOrName") -> Column:
+ """
+ Extract the week number of a given date as integer.
+ A week is considered to start on a Monday and week 1 is the first week with more than 3 days,
+ as defined by ISO 8601
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target timestamp column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ `week` of the year for given date as integer.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> df.select(weekofyear(df.dt).alias('week')).collect()
+ [Row(week=15)]
+ """
+ return _invoke_function_over_columns("weekofyear", col)
+
+
+def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") -> Column:
+ """
+ Returns a column with a date built from the year, month and day columns.
+
+ .. versionadded:: 3.3.0
+
+ Parameters
+ ----------
+ year : :class:`~pyspark.sql.Column` or str
+ The year to build the date
+ month : :class:`~pyspark.sql.Column` or str
+ The month to build the date
+ day : :class:`~pyspark.sql.Column` or str
+ The day to build the date
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a date built from given parts.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(2020, 6, 26)], ['Y', 'M', 'D'])
+ >>> df.select(make_date(df.Y, df.M, df.D).alias("datefield")).collect()
+ [Row(datefield=datetime.date(2020, 6, 26))]
+ """
+ return _invoke_function_over_columns("make_date", year, month, day)
+
+
+def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
+ """
+ Returns the date that is `days` days after `start`. If `days` is a negative value
+ then that many days will be deducted from `start`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ start : :class:`~pyspark.sql.Column` or str
+ date column to work on.
+ days : :class:`~pyspark.sql.Column` or str or int
+ how many days after the given date to calculate.
+ Accepts negative value as well to calculate backwards in time.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a date after/before given number of days.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add'])
+ >>> df.select(date_add(df.dt, 1).alias('next_date')).collect()
+ [Row(next_date=datetime.date(2015, 4, 9))]
+ >>> df.select(date_add(df.dt, df.add.cast('integer')).alias('next_date')).collect()
+ [Row(next_date=datetime.date(2015, 4, 10))]
+ >>> df.select(date_add('dt', -1).alias('prev_date')).collect()
+ [Row(prev_date=datetime.date(2015, 4, 7))]
+ """
+ days = lit(days) if isinstance(days, int) else days
+ return _invoke_function_over_columns("date_add", start, days)
+
+
+def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column:
+ """
+ Returns the date that is `days` days before `start`. If `days` is a negative value
+ then that many days will be added to `start`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ start : :class:`~pyspark.sql.Column` or str
+ date column to work on.
+ days : :class:`~pyspark.sql.Column` or str or int
+ how many days before the given date to calculate.
+ Accepts negative value as well to calculate forward in time.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a date before/after given number of days.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'sub'])
+ >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect()
+ [Row(prev_date=datetime.date(2015, 4, 7))]
+ >>> df.select(date_sub(df.dt, df.sub.cast('integer')).alias('prev_date')).collect()
+ [Row(prev_date=datetime.date(2015, 4, 6))]
+ >>> df.select(date_sub('dt', -1).alias('next_date')).collect()
+ [Row(next_date=datetime.date(2015, 4, 9))]
+ """
+ days = lit(days) if isinstance(days, int) else days
+ return _invoke_function_over_columns("date_sub", start, days)
+
+
+def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column:
+ """
+ Returns the number of days from `start` to `end`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ end : :class:`~pyspark.sql.Column` or str
+ to date column to work on.
+ start : :class:`~pyspark.sql.Column` or str
+ from date column to work on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ difference in days between two dates.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
+ >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect()
+ [Row(diff=32)]
+ """
+ return _invoke_function_over_columns("datediff", end, start)
+
+
+def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column:
+ """
+ Returns the date that is `months` months after `start`. If `months` is a negative value
+ then that many months will be deducted from `start`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ start : :class:`~pyspark.sql.Column` or str
+ date column to work on.
+ months : :class:`~pyspark.sql.Column` or str or int
+ how many months after the given date to calculate.
+ Accepts negative value as well to calculate backwards.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a date after/before given number of months.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add'])
+ >>> df.select(add_months(df.dt, 1).alias('next_month')).collect()
+ [Row(next_month=datetime.date(2015, 5, 8))]
+ >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect()
+ [Row(next_month=datetime.date(2015, 6, 8))]
+ >>> df.select(add_months('dt', -2).alias('prev_month')).collect()
+ [Row(prev_month=datetime.date(2015, 2, 8))]
+ """
+ months = lit(months) if isinstance(months, int) else months
+ return _invoke_function_over_columns("add_months", start, months)
+
+
+def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool = True) -> Column:
+ """
+ Returns number of months between dates date1 and date2.
+ If date1 is later than date2, then the result is positive.
+ A whole number is returned if both inputs have the same day of month or both are the last day
+ of their respective months. Otherwise, the difference is calculated assuming 31 days per month.
+ The result is rounded off to 8 digits unless `roundOff` is set to `False`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ date1 : :class:`~pyspark.sql.Column` or str
+ first date column.
+ date2 : :class:`~pyspark.sql.Column` or str
+ second date column.
+ roundOff : bool, optional
+ whether to round (to 8 digits) the final value or not (default: True).
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ number of months between two dates.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
+ >>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
+ [Row(months=3.94959677)]
+ >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect()
+ [Row(months=3.9495967741935485)]
+ """
+ return _invoke_function("months_between", _to_col(date1), _to_col(date2), lit(roundOff))
+
+
+def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column:
+ """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.DateType`
+ using the optionally specified format. Specify formats according to `datetime pattern`_.
+ By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format
+ is omitted. Equivalent to ``col.cast("date")``.
+
+ .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ input column of values to convert.
+ format: str, optional
+ format to use to convert date values.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ date value as :class:`pyspark.sql.types.DateType` type.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_date(df.t).alias('date')).collect()
+ [Row(date=datetime.date(1997, 2, 28))]
+
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect()
+ [Row(date=datetime.date(1997, 2, 28))]
+ """
+ if format is None:
+ return _invoke_function_over_columns("to_date", col)
+ else:
+ return _invoke_function("to_date", _to_col(col), lit(format))
+
+
+@overload
+def to_timestamp(col: "ColumnOrName") -> Column:
+ ...
+
+
+@overload
+def to_timestamp(col: "ColumnOrName", format: str) -> Column:
+ ...
+
+
+def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column:
+ """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.TimestampType`
+ using the optionally specified format. Specify formats according to `datetime pattern`_.
+ By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format
+ is omitted. Equivalent to ``col.cast("timestamp")``.
+
+ .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ column values to convert.
+ format: str, optional
+ format to use to convert timestamp values.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ timestamp value as :class:`pyspark.sql.types.TimestampType` type.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_timestamp(df.t).alias('dt')).collect()
+ [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+ >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect()
+ [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+ """
+ if format is None:
+ return _invoke_function_over_columns("to_timestamp", col)
+ else:
+ return _invoke_function("to_timestamp", _to_col(col), lit(format))
+
+
+def trunc(date: "ColumnOrName", format: str) -> Column:
+ """
+ Returns date truncated to the unit specified by the format.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ date : :class:`~pyspark.sql.Column` or str
+ input column of values to truncate.
+ format : str
+ 'year', 'yyyy', 'yy' to truncate by year,
+ or 'month', 'mon', 'mm' to truncate by month.
+ Other options are: 'week', 'quarter'.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ truncated date.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28',)], ['d'])
+ >>> df.select(trunc(df.d, 'year').alias('year')).collect()
+ [Row(year=datetime.date(1997, 1, 1))]
+ >>> df.select(trunc(df.d, 'mon').alias('month')).collect()
+ [Row(month=datetime.date(1997, 2, 1))]
+ """
+ return _invoke_function("trunc", _to_col(date), lit(format))
+
+
+def date_trunc(format: str, timestamp: "ColumnOrName") -> Column:
+ """
+ Returns timestamp truncated to the unit specified by the format.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ format : str
+ 'year', 'yyyy', 'yy' to truncate by year,
+ 'month', 'mon', 'mm' to truncate by month,
+ 'day', 'dd' to truncate by day,
+ Other options are:
+ 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'week', 'quarter'
+ timestamp : :class:`~pyspark.sql.Column` or str
+ input column of values to truncate.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ truncated timestamp.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t'])
+ >>> df.select(date_trunc('year', df.t).alias('year')).collect()
+ [Row(year=datetime.datetime(1997, 1, 1, 0, 0))]
+ >>> df.select(date_trunc('mon', df.t).alias('month')).collect()
+ [Row(month=datetime.datetime(1997, 2, 1, 0, 0))]
+ """
+ return _invoke_function("date_trunc", lit(format), _to_col(timestamp))
+
+
+def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column:
+ """
+ Returns the first date which is later than the value of the date column
+ based on the second `week day` argument.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ date : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+ dayOfWeek : str
+ day of the week, case-insensitive, accepts:
+ "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column of computed results.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('2015-07-27',)], ['d'])
+ >>> df.select(next_day(df.d, 'Sun').alias('date')).collect()
+ [Row(date=datetime.date(2015, 8, 2))]
+ """
+ return _invoke_function("next_day", _to_col(date), lit(dayOfWeek))
+
+
+def last_day(date: "ColumnOrName") -> Column:
+ """
+ Returns the last day of the month which the given date belongs to.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ date : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ last day of the month.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-10',)], ['d'])
+ >>> df.select(last_day(df.d).alias('date')).collect()
+ [Row(date=datetime.date(1997, 2, 28))]
+ """
+ return _invoke_function_over_columns("last_day", date)
+
+
+def from_unixtime(timestamp: "ColumnOrName", format: str = "yyyy-MM-dd HH:mm:ss") -> Column:
+ """
+ Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ representing the timestamp of that moment in the current system time zone in the given
+ format.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ timestamp : :class:`~pyspark.sql.Column` or str
+ column of unix time values.
+ format : str, optional
+ format to use to convert to (default: yyyy-MM-dd HH:mm:ss)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ formatted timestamp as string.
+
+ Examples
+ --------
+ >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
+ >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time'])
+ >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect()
+ [Row(ts='2015-04-08 00:00:00')]
+ >>> spark.conf.unset("spark.sql.session.timeZone")
+ """
+ return _invoke_function("from_unixtime", _to_col(timestamp), lit(format))
+
+
+@overload
+def unix_timestamp(timestamp: "ColumnOrName", format: str = ...) -> Column:
+ ...
+
+
+@overload
+def unix_timestamp() -> Column:
+ ...
+
+
+def unix_timestamp(
+ timestamp: Optional["ColumnOrName"] = None, format: str = "yyyy-MM-dd HH:mm:ss"
+) -> Column:
+ """
+ Converts a time string with the given pattern ('yyyy-MM-dd HH:mm:ss', by default)
+ to a Unix timestamp (in seconds), using the default timezone and the default
+ locale. Returns null if the conversion fails.
+
+ If `timestamp` is None, the current timestamp is returned.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ timestamp : :class:`~pyspark.sql.Column` or str, optional
+ timestamps of string values.
+ format : str, optional
+ alternative format to use for converting (default: yyyy-MM-dd HH:mm:ss).
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ unix time as long integer.
+
+ Examples
+ --------
+ >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
+ >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt'])
+ >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect()
+ [Row(unix_time=1428476400)]
+ >>> spark.conf.unset("spark.sql.session.timeZone")
+ """
+ if timestamp is None:
+ return _invoke_function("unix_timestamp")
+ return _invoke_function("unix_timestamp", _to_col(timestamp), lit(format))
+
+
+def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column:
+ """
+ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
+ takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and
+ renders that timestamp as a timestamp in the given time zone.
+
+ However, a timestamp in Spark represents the number of microseconds from the Unix epoch, which
+ is not timezone-agnostic. So in Spark this function simply shifts the timestamp value from the
+ UTC timezone to the given timezone.
+
+ This function may return a confusing result if the input is a string with a timezone, e.g.
+ '2018-03-13T06:18:23+00:00'. The reason is that Spark first casts the string to a timestamp
+ according to the timezone in the string, and finally displays the result by converting the
+ timestamp to a string according to the session-local timezone.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ timestamp : :class:`~pyspark.sql.Column` or str
+ the column that contains timestamps
+ tz : :class:`~pyspark.sql.Column` or str
+ A string detailing the time zone ID that the input should be adjusted to. It should
+ be in the format of either region-based zone IDs or zone offsets. Region IDs must
+ have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
+ the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
+ supported as aliases of '+00:00'. Other short names are not recommended to use
+ because they can be ambiguous.
+ `tz` can also take a :class:`~pyspark.sql.Column` containing timezone ID strings.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ timestamp value represented in given timezone.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
+ >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect()
+ [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))]
+ >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect()
+ [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))]
+ """
+ if isinstance(tz, str):
+ tz = lit(tz)
+ return _invoke_function_over_columns("from_utc_timestamp", timestamp, tz)
+
+
+def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column:
+ """
+ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
+ takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given
+ timezone, and renders that timestamp as a timestamp in UTC.
+
+ However, a timestamp in Spark represents the number of microseconds from the Unix epoch, which
+ is not timezone-agnostic. So in Spark this function simply shifts the timestamp value from the
+ given timezone to the UTC timezone.
+
+ This function may return a confusing result if the input is a string with a timezone, e.g.
+ '2018-03-13T06:18:23+00:00'. The reason is that Spark first casts the string to a timestamp
+ according to the timezone in the string, and finally displays the result by converting the
+ timestamp to a string according to the session-local timezone.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ timestamp : :class:`~pyspark.sql.Column` or str
+ the column that contains timestamps
+ tz : :class:`~pyspark.sql.Column` or str
+ A string detailing the time zone ID that the input should be adjusted to. It should
+ be in the format of either region-based zone IDs or zone offsets. Region IDs must
+ have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
+ the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
+ supported as aliases of '+00:00'. Other short names are not recommended to use
+ because they can be ambiguous.
+ `tz` can also take a :class:`~pyspark.sql.Column` containing timezone ID strings.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ timestamp value represented in UTC timezone.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz'])
+ >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect()
+ [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))]
+ >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect()
+ [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))]
+ """
+ if isinstance(tz, str):
+ tz = lit(tz)
+ return _invoke_function_over_columns("to_utc_timestamp", timestamp, tz)
+
+
+def timestamp_seconds(col: "ColumnOrName") -> Column:
+ """
+ Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z)
+ to a timestamp.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ unix time values.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ converted timestamp value.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import timestamp_seconds
+ >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
+ >>> time_df = spark.createDataFrame([(1230219000,)], ['unix_time'])
+ >>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).show()
+ +-------------------+
+ | ts|
+ +-------------------+
+ |2008-12-25 15:30:00|
+ +-------------------+
+ >>> time_df.select(timestamp_seconds('unix_time').alias('ts')).printSchema()
+ root
+ |-- ts: timestamp (nullable = true)
+ >>> spark.conf.unset("spark.sql.session.timeZone")
+ """
+
+ return _invoke_function_over_columns("timestamp_seconds", col)
+
+
+# Misc Functions
+
+
+def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column:
+ """
+ Returns `null` if the input column is `true`; throws an exception
+ with the provided error message otherwise.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ column name or column that represents the input column to test
+ errMsg : :class:`~pyspark.sql.Column` or str, optional
+ A Python string literal or column containing the error message
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ `null` if the input column is `true` otherwise throws an error with specified message.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([(0,1)], ['a', 'b'])
+ >>> df.select(assert_true(df.a < df.b).alias('r')).collect()
+ [Row(r=None)]
+ >>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect()
+ [Row(r=None)]
+ >>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect()
+ [Row(r=None)]
+ >>> df.select(assert_true(df.a > df.b, 'My error msg').alias('r')).collect() # doctest: +SKIP
+ ...
+ java.lang.RuntimeException: My error msg
+ ...
+ """
+ if errMsg is None:
+ return _invoke_function_over_columns("assert_true", col)
+ if not isinstance(errMsg, (str, Column)):
+ raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg)))
+
+ _err_msg = lit(errMsg) if isinstance(errMsg, str) else _to_col(errMsg)
+
+ return _invoke_function("assert_true", _to_col(col), _err_msg)
+
+
+def raise_error(errMsg: Union[Column, str]) -> Column:
+ """
+ Throws an exception with the provided error message.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ errMsg : :class:`~pyspark.sql.Column` or str
+ A Python string literal or column containing the error message
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ throws an error with specified message.
+
+ Examples
+ --------
+ >>> df = spark.range(1)
+ >>> df.select(raise_error("My error message")).show() # doctest: +SKIP
+ ...
+ java.lang.RuntimeException: My error message
+ ...
+ """
+ if not isinstance(errMsg, (str, Column)):
+ raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg)))
+
+ _err_msg = lit(errMsg) if isinstance(errMsg, str) else _to_col(errMsg)
+
+ return _invoke_function("raise_error", _err_msg)
+
+
+def crc32(col: "ColumnOrName") -> Column:
+ """
+ Calculates the cyclic redundancy check value (CRC32) of a binary column and
+ returns the value as a bigint.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect()
+ [Row(crc32=2743272264)]
+ """
+ return _invoke_function_over_columns("crc32", col)
+
+
+def hash(*cols: "ColumnOrName") -> Column:
+ """Calculates the hash code of given columns, and returns the result as an int column.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ one or more columns to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ hash value as int column.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('ABC', 'DEF')], ['c1', 'c2'])
+
+ Hash for one column
+
+ >>> df.select(hash('c1').alias('hash')).show()
+ +----------+
+ | hash|
+ +----------+
+ |-757602832|
+ +----------+
+
+ Two or more columns
+
+ >>> df.select(hash('c1', 'c2').alias('hash')).show()
+ +---------+
+ | hash|
+ +---------+
+ |599895104|
+ +---------+
+ """
+ return _invoke_function_over_columns("hash", *cols)
+
+
+def xxhash64(*cols: "ColumnOrName") -> Column:
+ """Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm,
+ and returns the result as a long column. The hash computation uses an initial seed of 42.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ cols : :class:`~pyspark.sql.Column` or str
+ one or more columns to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ hash value as long column.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([('ABC', 'DEF')], ['c1', 'c2'])
+
+ Hash for one column
+
+ >>> df.select(xxhash64('c1').alias('hash')).show()
+ +-------------------+
+ | hash|
+ +-------------------+
+ |4105715581806190027|
+ +-------------------+
+
+ Two or more columns
+
+ >>> df.select(xxhash64('c1', 'c2').alias('hash')).show()
+ +-------------------+
+ | hash|
+ +-------------------+
+ |3233247871021311208|
+ +-------------------+
+ """
+ return _invoke_function_over_columns("xxhash64", *cols)
+
+
+def md5(col: "ColumnOrName") -> Column:
+ """Calculates the MD5 digest and returns the value as a 32 character hex string.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect()
+ [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')]
+ """
+ return _invoke_function_over_columns("md5", col)
+
+
+def sha1(col: "ColumnOrName") -> Column:
+ """Returns the hex string result of SHA-1.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+ [Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+ """
+ return _invoke_function_over_columns("sha1", col)
+
+
+def sha2(col: "ColumnOrName", numBits: int) -> Column:
+ """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
+ and SHA-512). The numBits indicates the desired bit length of the result, which must have a
+ value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ target column to compute on.
+ numBits : int
+ the desired bit length of the result, which must have a
+ value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ the column for computed results.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"])
+ >>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False)
+ +-----+----------------------------------------------------------------+
+ |name |sha2 |
+ +-----+----------------------------------------------------------------+
+ |Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043|
+ |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961|
+ +-----+----------------------------------------------------------------+
+ """
+ return _invoke_function("sha2", _to_col(col), lit(numBits))
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 2de0dbb40c..e8b6d79943 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -17,11 +17,13 @@
from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
import functools
-import pandas
import pyarrow as pa
+
+from pyspark.sql.types import DataType
+
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column, SortOrder, ColumnReference
-
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -167,21 +169,34 @@ def _repr_html_(self) -> str:
class LocalRelation(LogicalPlan):
- """Creates a LocalRelation plan object based on a Pandas DataFrame."""
+ """Creates a LocalRelation plan object based on a PyArrow Table."""
- def __init__(self, pdf: "pandas.DataFrame") -> None:
+ def __init__(
+ self,
+ table: "pa.Table",
+ schema: Optional[Union[DataType, str]] = None,
+ ) -> None:
super().__init__(None)
- self._pdf = pdf
+ assert table is not None and isinstance(table, pa.Table)
+ self._table = table
+
+ if schema is not None:
+ assert isinstance(schema, (DataType, str))
+ self._schema = schema
def plan(self, session: "SparkConnectClient") -> proto.Relation:
sink = pa.BufferOutputStream()
- table = pa.Table.from_pandas(self._pdf)
- with pa.ipc.new_stream(sink, table.schema) as writer:
- for b in table.to_batches():
+ with pa.ipc.new_stream(sink, self._table.schema) as writer:
+ for b in self._table.to_batches():
writer.write_batch(b)
plan = proto.Relation()
plan.local_relation.data = sink.getvalue().to_pybytes()
+ if self._schema is not None:
+ if isinstance(self._schema, DataType):
+ plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema))
+ elif isinstance(self._schema, str):
+ plan.local_relation.datatype_str = self._schema
return plan
def print(self, indent: int = 0) -> str:
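
The LocalRelation.plan() change above serializes the client-side pa.Table into Arrow IPC stream bytes that travel in the LocalRelation message. A standalone sketch of that round trip using only pyarrow (the read-back half is illustrative of what a receiver could do, not code from this patch):

import pyarrow as pa

# Serialize a small table to Arrow IPC stream bytes, as LocalRelation.plan() does.
table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches():
        writer.write_batch(batch)
data = sink.getvalue().to_pybytes()  # bytes carried in local_relation.data

# Reading the stream back recovers an identical table.
restored = pa.ipc.open_stream(pa.BufferReader(data)).read_all()
assert restored.equals(table)
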
@@ -984,6 +999,65 @@ def _repr_html_(self) -> str:
"""
+class Unpivot(LogicalPlan):
+ """Logical plan object for a unpivot operation."""
+
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ ids: List["ColumnOrName"],
+ values: List["ColumnOrName"],
+ variable_column_name: str,
+ value_column_name: str,
+ ) -> None:
+ super().__init__(child)
+ self.ids = ids
+ self.values = values
+ self.variable_column_name = variable_column_name
+ self.value_column_name = value_column_name
+
+ def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+ if isinstance(col, Column):
+ return col.to_plan(session)
+ else:
+ return self.unresolved_attr(col)
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+
+ plan = proto.Relation()
+ plan.unpivot.input.CopyFrom(self._child.plan(session))
+ plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
+ plan.unpivot.values.extend([self.col_to_expr(x, session) for x in self.values])
+ plan.unpivot.variable_column_name = self.variable_column_name
+ plan.unpivot.value_column_name = self.value_column_name
+ return plan
+
+ def print(self, indent: int = 0) -> str:
+ c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+ return (
+ f"{' ' * indent}"
+ f""
+ f"\n{c_buf}"
+ )
+
+ def _repr_html_(self) -> str:
+ return f"""
+
+ -
+ Unpivot
+ ids: {self.ids}
+ values: {self.values}
+ variable_column_name: {self.variable_column_name}
+ value_column_name: {self.value_column_name}
+ {self._child._repr_html_() if self._child is not None else ""}
+
+
+ """
+
+
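
The Unpivot node mirrors the DataFrame.unpivot (a.k.a. melt) API added for Spark 3.4: id columns are kept as-is, and each value column contributes one (variable, value) row. A small sketch of the semantics it encodes, assuming a running SparkSession named spark:

df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])

# Keep "id"; turn the "int" and "double" columns into (var, val) rows.
unpivoted = df.unpivot(["id"], ["int", "double"], "var", "val")
unpivoted.show()
# Expected rows (the value column is upcast to a common type, here double):
#   (1, "int", 11.0), (1, "double", 1.1), (2, "int", 12.0), (2, "double", 1.2)
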
class NAFill(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 8510216324..91c57a9ef2 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xd2\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1ap\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x39\n\x0c\x63\x61st_to_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\ncastToType\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 
\x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 
\x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)
@@ -197,31 +197,31 @@
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 78
- _EXPRESSION._serialized_end = 2720
- _EXPRESSION_CAST._serialized_start = 639
- _EXPRESSION_CAST._serialized_end = 751
- _EXPRESSION_LITERAL._serialized_start = 754
- _EXPRESSION_LITERAL._serialized_end = 2212
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1650
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1767
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1769
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1867
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 1869
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 1936
- _EXPRESSION_LITERAL_ARRAY._serialized_start = 1938
- _EXPRESSION_LITERAL_ARRAY._serialized_end = 2004
- _EXPRESSION_LITERAL_MAP._serialized_start = 2007
- _EXPRESSION_LITERAL_MAP._serialized_end = 2196
- _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2080
- _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2196
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2214
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2284
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2287
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2491
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2493
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2543
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2545
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2585
- _EXPRESSION_ALIAS._serialized_start = 2587
- _EXPRESSION_ALIAS._serialized_end = 2707
+ _EXPRESSION._serialized_end = 2754
+ _EXPRESSION_CAST._serialized_start = 640
+ _EXPRESSION_CAST._serialized_end = 785
+ _EXPRESSION_LITERAL._serialized_start = 788
+ _EXPRESSION_LITERAL._serialized_end = 2246
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038
+ _EXPRESSION_LITERAL_MAP._serialized_start = 2041
+ _EXPRESSION_LITERAL_MAP._serialized_end = 2230
+ _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114
+ _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619
+ _EXPRESSION_ALIAS._serialized_start = 2621
+ _EXPRESSION_ALIAS._serialized_end = 2741
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index c1034a8636..2c486f62a9 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -60,27 +60,51 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
EXPR_FIELD_NUMBER: builtins.int
- CAST_TO_TYPE_FIELD_NUMBER: builtins.int
+ TYPE_FIELD_NUMBER: builtins.int
+ TYPE_STR_FIELD_NUMBER: builtins.int
@property
def expr(self) -> global___Expression:
"""(Required) the expression to be casted."""
@property
- def cast_to_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
- """(Required) the data type that the expr to be casted to."""
+ def type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+ type_str: builtins.str
+ """If this is set, Server will use Catalyst parser to parse this string to DataType."""
def __init__(
self,
*,
expr: global___Expression | None = ...,
- cast_to_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+ type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+ type_str: builtins.str = ...,
) -> None: ...
def HasField(
self,
- field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"],
+ field_name: typing_extensions.Literal[
+ "cast_to_type",
+ b"cast_to_type",
+ "expr",
+ b"expr",
+ "type",
+ b"type",
+ "type_str",
+ b"type_str",
+ ],
) -> builtins.bool: ...
def ClearField(
self,
- field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"],
+ field_name: typing_extensions.Literal[
+ "cast_to_type",
+ b"cast_to_type",
+ "expr",
+ b"expr",
+ "type",
+ b"type",
+ "type_str",
+ b"type_str",
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["cast_to_type", b"cast_to_type"]
+ ) -> typing_extensions.Literal["type", "type_str"] | None: ...
class Literal(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
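
The regenerated stubs reflect the new cast_to_type oneof on Expression.Cast: a client can supply either a structured DataType message or a raw type string to be parsed by the server-side Catalyst parser. A minimal protobuf-level sketch (field access follows standard protobuf oneof semantics; the example values are assumptions, not taken from this patch):

from pyspark.sql.connect.proto import expressions_pb2

cast = expressions_pb2.Expression.Cast()

# Variant 1: let the server parse a DDL type string.
cast.type_str = "array<bigint>"
assert cast.WhichOneof("cast_to_type") == "type_str"

# Variant 2: switch to the structured DataType message; within a oneof,
# marking `type` present automatically clears `type_str`.
cast.type.SetInParent()
assert cast.WhichOneof("cast_to_type") == "type"
assert not cast.HasField("type_str")
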
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 06cf18417d..d1651d0b72 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -30,10 +30,11 @@
from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
+from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8b\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 
\x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 
\x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 
\x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x83\x01\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x45\n\x0ename_expr_list\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x0cnameExprList"\x8c\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x41\n\nparameters\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\nparametersB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xbf\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 
\x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 
\x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"\x89\x01\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x35\n\x08\x64\x61tatype\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x08\x64\x61tatype\x12#\n\x0c\x64\x61tatype_str\x18\x03 \x01(\tH\x00R\x0b\x64\x61tatypeStrB\x08\n\x06schema"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x83\x01\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x45\n\x0ename_expr_list\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x0cnameExprList"\x8c\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x41\n\nparameters\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\nparameters"\xf6\x01\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12\x31\n\x06values\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06values\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnNameB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)
@@ -77,6 +78,7 @@
)
_WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
_HINT = DESCRIPTOR.message_types_by_name["Hint"]
+_UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
_JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
_SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
_SORT_SORTDIRECTION = _SORT.enum_types_by_name["SortDirection"]
@@ -493,6 +495,17 @@
)
_sym_db.RegisterMessage(Hint)
+Unpivot = _reflection.GeneratedProtocolMessageType(
+ "Unpivot",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _UNPIVOT,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.Unpivot)
+ },
+)
+_sym_db.RegisterMessage(Unpivot)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
@@ -501,88 +514,90 @@
_READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
- _RELATION._serialized_start = 82
- _RELATION._serialized_end = 1885
- _UNKNOWN._serialized_start = 1887
- _UNKNOWN._serialized_end = 1896
- _RELATIONCOMMON._serialized_start = 1898
- _RELATIONCOMMON._serialized_end = 1947
- _SQL._serialized_start = 1949
- _SQL._serialized_end = 1976
- _READ._serialized_start = 1979
- _READ._serialized_end = 2405
- _READ_NAMEDTABLE._serialized_start = 2121
- _READ_NAMEDTABLE._serialized_end = 2182
- _READ_DATASOURCE._serialized_start = 2185
- _READ_DATASOURCE._serialized_end = 2392
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2323
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2381
- _PROJECT._serialized_start = 2407
- _PROJECT._serialized_end = 2524
- _FILTER._serialized_start = 2526
- _FILTER._serialized_end = 2638
- _JOIN._serialized_start = 2641
- _JOIN._serialized_end = 3112
- _JOIN_JOINTYPE._serialized_start = 2904
- _JOIN_JOINTYPE._serialized_end = 3112
- _SETOPERATION._serialized_start = 3115
- _SETOPERATION._serialized_end = 3511
- _SETOPERATION_SETOPTYPE._serialized_start = 3374
- _SETOPERATION_SETOPTYPE._serialized_end = 3488
- _LIMIT._serialized_start = 3513
- _LIMIT._serialized_end = 3589
- _OFFSET._serialized_start = 3591
- _OFFSET._serialized_end = 3670
- _TAIL._serialized_start = 3672
- _TAIL._serialized_end = 3747
- _AGGREGATE._serialized_start = 3750
- _AGGREGATE._serialized_end = 3960
- _SORT._serialized_start = 3963
- _SORT._serialized_end = 4513
- _SORT_SORTFIELD._serialized_start = 4117
- _SORT_SORTFIELD._serialized_end = 4305
- _SORT_SORTDIRECTION._serialized_start = 4307
- _SORT_SORTDIRECTION._serialized_end = 4415
- _SORT_SORTNULLS._serialized_start = 4417
- _SORT_SORTNULLS._serialized_end = 4499
- _DROP._serialized_start = 4515
- _DROP._serialized_end = 4615
- _DEDUPLICATE._serialized_start = 4618
- _DEDUPLICATE._serialized_end = 4789
- _LOCALRELATION._serialized_start = 4791
- _LOCALRELATION._serialized_end = 4826
- _SAMPLE._serialized_start = 4829
- _SAMPLE._serialized_end = 5053
- _RANGE._serialized_start = 5056
- _RANGE._serialized_end = 5201
- _SUBQUERYALIAS._serialized_start = 5203
- _SUBQUERYALIAS._serialized_end = 5317
- _REPARTITION._serialized_start = 5320
- _REPARTITION._serialized_end = 5462
- _SHOWSTRING._serialized_start = 5465
- _SHOWSTRING._serialized_end = 5606
- _STATSUMMARY._serialized_start = 5608
- _STATSUMMARY._serialized_end = 5700
- _STATDESCRIBE._serialized_start = 5702
- _STATDESCRIBE._serialized_end = 5783
- _STATCROSSTAB._serialized_start = 5785
- _STATCROSSTAB._serialized_end = 5886
- _NAFILL._serialized_start = 5889
- _NAFILL._serialized_end = 6023
- _NADROP._serialized_start = 6026
- _NADROP._serialized_end = 6160
- _NAREPLACE._serialized_start = 6163
- _NAREPLACE._serialized_end = 6459
- _NAREPLACE_REPLACEMENT._serialized_start = 6318
- _NAREPLACE_REPLACEMENT._serialized_end = 6459
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6461
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6575
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6578
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6837
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6770
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6837
- _WITHCOLUMNS._serialized_start = 6840
- _WITHCOLUMNS._serialized_end = 6971
- _HINT._serialized_start = 6974
- _HINT._serialized_end = 7114
+ _RELATION._serialized_start = 109
+ _RELATION._serialized_end = 1964
+ _UNKNOWN._serialized_start = 1966
+ _UNKNOWN._serialized_end = 1975
+ _RELATIONCOMMON._serialized_start = 1977
+ _RELATIONCOMMON._serialized_end = 2026
+ _SQL._serialized_start = 2028
+ _SQL._serialized_end = 2055
+ _READ._serialized_start = 2058
+ _READ._serialized_end = 2484
+ _READ_NAMEDTABLE._serialized_start = 2200
+ _READ_NAMEDTABLE._serialized_end = 2261
+ _READ_DATASOURCE._serialized_start = 2264
+ _READ_DATASOURCE._serialized_end = 2471
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2402
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2460
+ _PROJECT._serialized_start = 2486
+ _PROJECT._serialized_end = 2603
+ _FILTER._serialized_start = 2605
+ _FILTER._serialized_end = 2717
+ _JOIN._serialized_start = 2720
+ _JOIN._serialized_end = 3191
+ _JOIN_JOINTYPE._serialized_start = 2983
+ _JOIN_JOINTYPE._serialized_end = 3191
+ _SETOPERATION._serialized_start = 3194
+ _SETOPERATION._serialized_end = 3590
+ _SETOPERATION_SETOPTYPE._serialized_start = 3453
+ _SETOPERATION_SETOPTYPE._serialized_end = 3567
+ _LIMIT._serialized_start = 3592
+ _LIMIT._serialized_end = 3668
+ _OFFSET._serialized_start = 3670
+ _OFFSET._serialized_end = 3749
+ _TAIL._serialized_start = 3751
+ _TAIL._serialized_end = 3826
+ _AGGREGATE._serialized_start = 3829
+ _AGGREGATE._serialized_end = 4039
+ _SORT._serialized_start = 4042
+ _SORT._serialized_end = 4592
+ _SORT_SORTFIELD._serialized_start = 4196
+ _SORT_SORTFIELD._serialized_end = 4384
+ _SORT_SORTDIRECTION._serialized_start = 4386
+ _SORT_SORTDIRECTION._serialized_end = 4494
+ _SORT_SORTNULLS._serialized_start = 4496
+ _SORT_SORTNULLS._serialized_end = 4578
+ _DROP._serialized_start = 4594
+ _DROP._serialized_end = 4694
+ _DEDUPLICATE._serialized_start = 4697
+ _DEDUPLICATE._serialized_end = 4868
+ _LOCALRELATION._serialized_start = 4871
+ _LOCALRELATION._serialized_end = 5008
+ _SAMPLE._serialized_start = 5011
+ _SAMPLE._serialized_end = 5235
+ _RANGE._serialized_start = 5238
+ _RANGE._serialized_end = 5383
+ _SUBQUERYALIAS._serialized_start = 5385
+ _SUBQUERYALIAS._serialized_end = 5499
+ _REPARTITION._serialized_start = 5502
+ _REPARTITION._serialized_end = 5644
+ _SHOWSTRING._serialized_start = 5647
+ _SHOWSTRING._serialized_end = 5788
+ _STATSUMMARY._serialized_start = 5790
+ _STATSUMMARY._serialized_end = 5882
+ _STATDESCRIBE._serialized_start = 5884
+ _STATDESCRIBE._serialized_end = 5965
+ _STATCROSSTAB._serialized_start = 5967
+ _STATCROSSTAB._serialized_end = 6068
+ _NAFILL._serialized_start = 6071
+ _NAFILL._serialized_end = 6205
+ _NADROP._serialized_start = 6208
+ _NADROP._serialized_end = 6342
+ _NAREPLACE._serialized_start = 6345
+ _NAREPLACE._serialized_end = 6641
+ _NAREPLACE_REPLACEMENT._serialized_start = 6500
+ _NAREPLACE_REPLACEMENT._serialized_end = 6641
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6643
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6757
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6760
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7019
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6952
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7019
+ _WITHCOLUMNS._serialized_start = 7022
+ _WITHCOLUMNS._serialized_end = 7153
+ _HINT._serialized_start = 7156
+ _HINT._serialized_end = 7296
+ _UNPIVOT._serialized_start = 7299
+ _UNPIVOT._serialized_end = 7545
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f133661368..e942a63629 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import pyspark.sql.connect.proto.expressions_pb2
+import pyspark.sql.connect.proto.types_pb2
import sys
import typing
@@ -83,6 +84,7 @@ class Relation(google.protobuf.message.Message):
TAIL_FIELD_NUMBER: builtins.int
WITH_COLUMNS_FIELD_NUMBER: builtins.int
HINT_FIELD_NUMBER: builtins.int
+ UNPIVOT_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -139,6 +141,8 @@ class Relation(google.protobuf.message.Message):
@property
def hint(self) -> global___Hint: ...
@property
+ def unpivot(self) -> global___Unpivot: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -181,6 +185,7 @@ class Relation(google.protobuf.message.Message):
tail: global___Tail | None = ...,
with_columns: global___WithColumns | None = ...,
hint: global___Hint | None = ...,
+ unpivot: global___Unpivot | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -254,6 +259,8 @@ class Relation(google.protobuf.message.Message):
b"tail",
"unknown",
b"unknown",
+ "unpivot",
+ b"unpivot",
"with_columns",
b"with_columns",
],
@@ -323,6 +330,8 @@ class Relation(google.protobuf.message.Message):
b"tail",
"unknown",
b"unknown",
+ "unpivot",
+ b"unpivot",
"with_columns",
b"with_columns",
],
@@ -353,6 +362,7 @@ class Relation(google.protobuf.message.Message):
"tail",
"with_columns",
"hint",
+ "unpivot",
"fill_na",
"drop_na",
"replace",
@@ -1159,16 +1169,45 @@ class LocalRelation(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DATA_FIELD_NUMBER: builtins.int
+ DATATYPE_FIELD_NUMBER: builtins.int
+ DATATYPE_STR_FIELD_NUMBER: builtins.int
data: builtins.bytes
"""Local collection data serialized into Arrow IPC streaming format which contains
the schema of the data.
"""
+ @property
+ def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+ datatype_str: builtins.str
+ """Server will use Catalyst parser to parse this string to DataType."""
def __init__(
self,
*,
data: builtins.bytes = ...,
+ datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+ datatype_str: builtins.str = ...,
) -> None: ...
- def ClearField(self, field_name: typing_extensions.Literal["data", b"data"]) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "datatype", b"datatype", "datatype_str", b"datatype_str", "schema", b"schema"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "data",
+ b"data",
+ "datatype",
+ b"datatype",
+ "datatype_str",
+ b"datatype_str",
+ "schema",
+ b"schema",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["schema", b"schema"]
+ ) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ...
global___LocalRelation = LocalRelation
@@ -1963,3 +2002,66 @@ class Hint(google.protobuf.message.Message):
) -> None: ...
global___Hint = Hint
+
+class Unpivot(google.protobuf.message.Message):
+ """Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ IDS_FIELD_NUMBER: builtins.int
+ VALUES_FIELD_NUMBER: builtins.int
+ VARIABLE_COLUMN_NAME_FIELD_NUMBER: builtins.int
+ VALUE_COLUMN_NAME_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def ids(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Required) Id columns."""
+ @property
+ def values(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Optional) Value columns to unpivot."""
+ variable_column_name: builtins.str
+ """(Required) Name of the variable column."""
+ value_column_name: builtins.str
+ """(Required) Name of the value column."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ ids: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+ | None = ...,
+ values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+ | None = ...,
+ variable_column_name: builtins.str = ...,
+ value_column_name: builtins.str = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "ids",
+ b"ids",
+ "input",
+ b"input",
+ "value_column_name",
+ b"value_column_name",
+ "values",
+ b"values",
+ "variable_column_name",
+ b"variable_column_name",
+ ],
+ ) -> None: ...
+
+global___Unpivot = Unpivot
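Aside: the new oneof entry can be exercised directly from the generated package (a minimal sketch, assuming the pyspark.sql.connect.proto import path used elsewhere in this change):

    from pyspark.sql.connect import proto

    rel = proto.Relation()
    rel.unpivot.variable_column_name = "variable"
    rel.unpivot.value_column_name = "value"

    # Setting a nested field selects "unpivot" within the rel_type oneof.
    assert rel.WhichOneof("rel_type") == "unpivot"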
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 45239a2fa2..778509bcf7 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -50,7 +50,7 @@ def _set_opts(
self.option(k, v) # type: ignore[attr-defined]
-class DataFrameReader:
+class DataFrameReader(OptionUtils):
"""
TODO(SPARK-40539) Achieve parity with PySpark.
"""
@@ -164,7 +164,6 @@ def load(
return self._df(plan)
def _df(self, plan: LogicalPlan) -> "DataFrame":
- # The import is needed here to avoid circular import issues.
from pyspark.sql.connect.dataframe import DataFrame
return DataFrame.withPlan(plan, self._client)
@@ -172,6 +171,164 @@ def _df(self, plan: LogicalPlan) -> "DataFrame":
def table(self, tableName: str) -> "DataFrame":
return self._df(Read(tableName))
+ def json(
+ self,
+ path: str,
+ schema: Optional[str] = None,
+ primitivesAsString: Optional[Union[bool, str]] = None,
+ prefersDecimal: Optional[Union[bool, str]] = None,
+ allowComments: Optional[Union[bool, str]] = None,
+ allowUnquotedFieldNames: Optional[Union[bool, str]] = None,
+ allowSingleQuotes: Optional[Union[bool, str]] = None,
+ allowNumericLeadingZero: Optional[Union[bool, str]] = None,
+ allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None,
+ mode: Optional[str] = None,
+ columnNameOfCorruptRecord: Optional[str] = None,
+ dateFormat: Optional[str] = None,
+ timestampFormat: Optional[str] = None,
+ multiLine: Optional[Union[bool, str]] = None,
+ allowUnquotedControlChars: Optional[Union[bool, str]] = None,
+ lineSep: Optional[str] = None,
+ samplingRatio: Optional[Union[float, str]] = None,
+ dropFieldIfAllNull: Optional[Union[bool, str]] = None,
+ encoding: Optional[str] = None,
+ locale: Optional[str] = None,
+ pathGlobFilter: Optional[Union[bool, str]] = None,
+ recursiveFileLookup: Optional[Union[bool, str]] = None,
+ modifiedBefore: Optional[Union[bool, str]] = None,
+ modifiedAfter: Optional[Union[bool, str]] = None,
+ allowNonNumericNumbers: Optional[Union[bool, str]] = None,
+ ) -> "DataFrame":
+ """
+ Loads JSON files and returns the results as a :class:`DataFrame`.
+
+ `JSON Lines <https://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
+ For JSON (one record per file), set the ``multiLine`` parameter to ``true``.
+
+ If the ``schema`` parameter is not specified, this function goes
+ through the input once to determine the input schema.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ path : str
+ string represents path to the JSON dataset
+ schema : str, optional
+ a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+
+ Other Parameters
+ ----------------
+ Extra options
+ For the extra options, refer to
+ `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
+ for the version you use.
+
+ .. # noqa
+
+ Examples
+ --------
+ Write a DataFrame into a JSON file and read it back.
+
+ >>> import tempfile
+ >>> with tempfile.TemporaryDirectory() as d:
+ ... # Write a DataFrame into a JSON file
+ ... spark.createDataFrame(
+ ... [{"age": 100, "name": "Hyukjin Kwon"}]
+ ... ).write.mode("overwrite").format("json").save(d)
+ ...
+ ... # Read the JSON file as a DataFrame.
+ ... spark.read.json(d).show()
+ +---+------------+
+ |age| name|
+ +---+------------+
+ |100|Hyukjin Kwon|
+ +---+------------+
+ """
+ self._set_opts(
+ primitivesAsString=primitivesAsString,
+ prefersDecimal=prefersDecimal,
+ allowComments=allowComments,
+ allowUnquotedFieldNames=allowUnquotedFieldNames,
+ allowSingleQuotes=allowSingleQuotes,
+ allowNumericLeadingZero=allowNumericLeadingZero,
+ allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
+ mode=mode,
+ columnNameOfCorruptRecord=columnNameOfCorruptRecord,
+ dateFormat=dateFormat,
+ timestampFormat=timestampFormat,
+ multiLine=multiLine,
+ allowUnquotedControlChars=allowUnquotedControlChars,
+ lineSep=lineSep,
+ samplingRatio=samplingRatio,
+ dropFieldIfAllNull=dropFieldIfAllNull,
+ encoding=encoding,
+ locale=locale,
+ pathGlobFilter=pathGlobFilter,
+ recursiveFileLookup=recursiveFileLookup,
+ modifiedBefore=modifiedBefore,
+ modifiedAfter=modifiedAfter,
+ allowNonNumericNumbers=allowNonNumericNumbers,
+ )
+ return self.load(path=path, format="json", schema=schema)
+
+ def parquet(self, path: str, **options: "OptionalPrimitiveType") -> "DataFrame":
+ """
+ Loads Parquet files, returning the result as a :class:`DataFrame`.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ path : str
+
+ Other Parameters
+ ----------------
+ **options
+ For the extra options, refer to
+ `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option>`_
+ for the version you use.
+
+ .. # noqa
+
+ Examples
+ --------
+ Write a DataFrame into a Parquet file and read it back.
+
+ >>> import tempfile
+ >>> with tempfile.TemporaryDirectory() as d:
+ ... # Write a DataFrame into a Parquet file
+ ... spark.createDataFrame(
+ ... [{"age": 100, "name": "Hyukjin Kwon"}]
+ ... ).write.mode("overwrite").format("parquet").save(d)
+ ...
+ ... # Read the Parquet file as a DataFrame.
+ ... spark.read.parquet(d).show()
+ +---+------------+
+ |age| name|
+ +---+------------+
+ |100|Hyukjin Kwon|
+ +---+------------+
+ """
+ mergeSchema = options.get("mergeSchema", None)
+ pathGlobFilter = options.get("pathGlobFilter", None)
+ modifiedBefore = options.get("modifiedBefore", None)
+ modifiedAfter = options.get("modifiedAfter", None)
+ recursiveFileLookup = options.get("recursiveFileLookup", None)
+ datetimeRebaseMode = options.get("datetimeRebaseMode", None)
+ int96RebaseMode = options.get("int96RebaseMode", None)
+ self._set_opts(
+ mergeSchema=mergeSchema,
+ pathGlobFilter=pathGlobFilter,
+ recursiveFileLookup=recursiveFileLookup,
+ modifiedBefore=modifiedBefore,
+ modifiedAfter=modifiedAfter,
+ datetimeRebaseMode=datetimeRebaseMode,
+ int96RebaseMode=int96RebaseMode,
+ )
+
+ return self.load(path=path, format="parquet")
+
class DataFrameWriter(OptionUtils):
"""
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 28aebbdecb..0a3d03110f 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -16,17 +16,35 @@
#
from threading import RLock
-from typing import Optional, Any, Union, Dict, cast, overload
+from collections.abc import Sized
+
+import numpy as np
import pandas as pd
+import pyarrow as pa
+
+from pyspark.sql.types import DataType, StructType
-import pyspark.sql.types
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range
+from pyspark.sql.connect.plan import SQL, Range, LocalRelation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.utils import to_str
-from . import plan
-from ._typing import OptionalPrimitiveType
+
+from typing import (
+ Optional,
+ Any,
+ Union,
+ Dict,
+ List,
+ Tuple,
+ cast,
+ overload,
+ Iterable,
+ TYPE_CHECKING,
+)
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect._typing import OptionalPrimitiveType
# TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped
@@ -240,7 +258,11 @@ def read(self) -> "DataFrameReader":
"""
return DataFrameReader(self)
- def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
+ def createDataFrame(
+ self,
+ data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
+ schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None,
+ ) -> "DataFrame":
"""
Creates a :class:`DataFrame` from a :class:`pandas.DataFrame`.
@@ -249,7 +271,15 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
Parameters
----------
- data : :class:`pandas.DataFrame`
+ data : :class:`pandas.DataFrame` or :class:`list`, or :class:`numpy.ndarray`.
+ schema : :class:`pyspark.sql.types.DataType`, str or list, optional
+
+ When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must
+ match the real data, or an exception will be thrown at runtime. If the given schema is
+ not :class:`pyspark.sql.types.StructType`, it will be wrapped into a
+ :class:`pyspark.sql.types.StructType` as its only field, and the field name will be
+ "value". Each record will also be wrapped into a tuple, which can be converted to row
+ later.
Returns
-------
@@ -264,9 +294,71 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame":
"""
assert data is not None
- if len(data) == 0:
+ if isinstance(data, DataFrame):
+ raise TypeError("data is already a DataFrame")
+ if isinstance(data, Sized) and len(data) == 0:
raise ValueError("Input data cannot be empty")
- return DataFrame.withPlan(plan.LocalRelation(data), self)
+
+ _schema: Optional[StructType] = None
+ _schema_str: Optional[str] = None
+ _cols: Optional[List[str]] = None
+
+ if isinstance(schema, StructType):
+ _schema = schema
+
+ elif isinstance(schema, str):
+ _schema_str = schema
+
+ elif isinstance(schema, (list, tuple)):
+ # Must re-encode any unicode strings to be consistent with StructField names
+ _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
+
+ # Create the Pandas DataFrame
+ if isinstance(data, pd.DataFrame):
+ pdf = data
+
+ elif isinstance(data, np.ndarray):
+ # `data` of numpy.ndarray type will be converted to a pandas DataFrame,
+ if data.ndim not in [1, 2]:
+ raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
+
+ pdf = pd.DataFrame(data)
+
+ if _cols is None:
+ if data.ndim == 1 or data.shape[1] == 1:
+ _cols = ["value"]
+ else:
+ _cols = ["_%s" % i for i in range(1, data.shape[1] + 1)]
+
+ else:
+ pdf = pd.DataFrame(list(data))
+
+ if _cols is None:
+ _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)]
+
+ # Validate number of columns
+ num_cols = pdf.shape[1]
+ if _schema is not None and len(_schema.fields) != num_cols:
+ raise ValueError(
+ f"Length mismatch: Expected axis has {num_cols} elements, "
+ f"new values have {len(_schema.fields)} elements"
+ )
+ elif _cols is not None and len(_cols) != num_cols:
+ raise ValueError(
+ f"Length mismatch: Expected axis has {num_cols} elements, "
+ f"new values have {len(_cols)} elements"
+ )
+
+ table = pa.Table.from_pandas(pdf)
+
+ if _schema is not None:
+ return DataFrame.withPlan(LocalRelation(table, schema=_schema), self)
+ elif _schema_str is not None:
+ return DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self)
+ elif _cols is not None and len(_cols) > 0:
+ return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols)
+ else:
+ return DataFrame.withPlan(LocalRelation(table), self)
@property
def client(self) -> "SparkConnectClient":
@@ -279,9 +371,7 @@ def client(self) -> "SparkConnectClient":
"""
return self._client
- def register_udf(
- self, function: Any, return_type: Union[str, pyspark.sql.types.DataType]
- ) -> str:
+ def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str:
return self._client.register_udf(function, return_type)
def sql(self, sql_string: str) -> "DataFrame":
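For reference, a minimal sketch of the input shapes createDataFrame now accepts (mirroring the new tests; `spark` denotes a Connect session):

    import numpy as np

    # pandas DataFrames, 1-D/2-D ndarrays and plain lists are all converted to a
    # pyarrow Table and shipped as a LocalRelation.
    spark.createDataFrame(np.array([[1, 2], [3, 4]]), ["a", "b"]).show()

    # A DDL string is carried in LocalRelation.datatype_str and parsed on the server.
    spark.createDataFrame([[1, "x"]], "col1 int, col2 string").show()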
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
new file mode 100644
index 0000000000..55f5953660
--- /dev/null
+++ b/python/pyspark/sql/connect/types.py
@@ -0,0 +1,143 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.sql.types import (
+ DataType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ FloatType,
+ DateType,
+ TimestampType,
+ DayTimeIntervalType,
+ MapType,
+ StringType,
+ CharType,
+ VarcharType,
+ StructType,
+ StructField,
+ ArrayType,
+ DoubleType,
+ LongType,
+ DecimalType,
+ BinaryType,
+ BooleanType,
+ NullType,
+)
+
+
+def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
+ ret = pb2.DataType()
+ if isinstance(data_type, StringType):
+ ret.string.CopyFrom(pb2.DataType.String())
+ elif isinstance(data_type, BooleanType):
+ ret.boolean.CopyFrom(pb2.DataType.Boolean())
+ elif isinstance(data_type, BinaryType):
+ ret.binary.CopyFrom(pb2.DataType.Binary())
+ elif isinstance(data_type, ByteType):
+ ret.byte.CopyFrom(pb2.DataType.Byte())
+ elif isinstance(data_type, ShortType):
+ ret.short.CopyFrom(pb2.DataType.Short())
+ elif isinstance(data_type, IntegerType):
+ ret.integer.CopyFrom(pb2.DataType.Integer())
+ elif isinstance(data_type, LongType):
+ ret.long.CopyFrom(pb2.DataType.Long())
+ elif isinstance(data_type, FloatType):
+ ret.float.CopyFrom(pb2.DataType.Float())
+ elif isinstance(data_type, DoubleType):
+ ret.double.CopyFrom(pb2.DataType.Double())
+ elif isinstance(data_type, DecimalType):
+ ret.decimal.CopyFrom(pb2.DataType.Decimal())
+ elif isinstance(data_type, DayTimeIntervalType):
+ ret.day_time_interval.start_field = data_type.startField
+ ret.day_time_interval.end_field = data_type.endField
+ else:
+ raise Exception(f"Unsupported data type {data_type}")
+ return ret
+
+
+def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
+ if schema.HasField("null"):
+ return NullType()
+ elif schema.HasField("boolean"):
+ return BooleanType()
+ elif schema.HasField("binary"):
+ return BinaryType()
+ elif schema.HasField("byte"):
+ return ByteType()
+ elif schema.HasField("short"):
+ return ShortType()
+ elif schema.HasField("integer"):
+ return IntegerType()
+ elif schema.HasField("long"):
+ return LongType()
+ elif schema.HasField("float"):
+ return FloatType()
+ elif schema.HasField("double"):
+ return DoubleType()
+ elif schema.HasField("decimal"):
+ p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
+ s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
+ return DecimalType(precision=p, scale=s)
+ elif schema.HasField("string"):
+ return StringType()
+ elif schema.HasField("char"):
+ return CharType(schema.char.length)
+ elif schema.HasField("var_char"):
+ return VarcharType(schema.var_char.length)
+ elif schema.HasField("date"):
+ return DateType()
+ elif schema.HasField("timestamp"):
+ return TimestampType()
+ elif schema.HasField("day_time_interval"):
+ start: Optional[int] = (
+ schema.day_time_interval.start_field
+ if schema.day_time_interval.HasField("start_field")
+ else None
+ )
+ end: Optional[int] = (
+ schema.day_time_interval.end_field
+ if schema.day_time_interval.HasField("end_field")
+ else None
+ )
+ return DayTimeIntervalType(startField=start, endField=end)
+ elif schema.HasField("array"):
+ return ArrayType(
+ proto_schema_to_pyspark_data_type(schema.array.element_type),
+ schema.array.contains_null,
+ )
+ elif schema.HasField("struct"):
+ fields = [
+ StructField(
+ f.name,
+ proto_schema_to_pyspark_data_type(f.data_type),
+ f.nullable,
+ )
+ for f in schema.struct.fields
+ ]
+ return StructType(fields)
+ elif schema.HasField("map"):
+ return MapType(
+ proto_schema_to_pyspark_data_type(schema.map.key_type),
+ proto_schema_to_pyspark_data_type(schema.map.value_type),
+ schema.map.value_contains_null,
+ )
+ else:
+ raise Exception(f"Unsupported data type {schema}")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9746196dc9..de540c6249 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -5560,11 +5560,11 @@ def decode(col: "ColumnOrName", charset: str) -> Column:
--------
>>> df = spark.createDataFrame([('abcd',)], ['a'])
>>> df.select(decode("a", "UTF-8")).show()
- +----------------------+
- |stringdecode(a, UTF-8)|
- +----------------------+
- | abcd|
- +----------------------+
+ +----------------+
+ |decode(a, UTF-8)|
+ +----------------+
+ | abcd|
+ +----------------+
"""
return _invoke_function("decode", _to_java_column(col), charset)
@@ -8036,7 +8036,7 @@ def sequence(
def from_csv(
col: "ColumnOrName",
- schema: Union[StructType, Column, str],
+ schema: Union[Column, str],
options: Optional[Dict[str, str]] = None,
) -> Column:
"""
diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py
index c51a90ca57..ab62b955b6 100644
--- a/python/pyspark/sql/pandas/utils.py
+++ b/python/pyspark/sql/pandas/utils.py
@@ -73,6 +73,25 @@ def require_minimum_pyarrow_version() -> None:
)
+def require_minimum_grpc_version() -> None:
+ """Raise ImportError if minimum version of grpc is not installed"""
+ minimum_grpc_version = "1.48.1"
+
+ from distutils.version import LooseVersion
+
+ try:
+ import grpc
+ except ImportError as error:
+ raise ImportError(
+ "grpc >= %s must be installed; however, " "it was not found." % minimum_grpc_version
+ ) from error
+ if LooseVersion(grpc.__version__) < LooseVersion(minimum_grpc_version):
+ raise ImportError(
+ "grpc >= %s must be installed; however, "
+ "your version was %s." % (minimum_grpc_version, grpc.__version__)
+ )
+
+
def pyarrow_version_less_than_minimum(minimum_pyarrow_version: str) -> bool:
"""Return False if the installed pyarrow version is less than minimum_pyarrow_version
or if pyarrow is not installed."""
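Like the pandas and pyarrow checks above it, this gate is presumably meant to be called before importing any Connect module; a small sketch of the intended call pattern:

    from pyspark.sql.pandas.utils import require_minimum_grpc_version

    try:
        require_minimum_grpc_version()
    except ImportError as e:
        # Skip Connect-only functionality when grpc is missing or too old.
        print(f"Spark Connect unavailable: {e}")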
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 98150731c2..6dabbaedff 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -14,38 +14,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Any
import unittest
import shutil
import tempfile
-import grpc # type: ignore
-
-from pyspark.sql.connect.column import Column
-from pyspark.testing.sqlutils import have_pandas, SQLTestUtils
-
-if have_pandas:
- import pandas
-
+from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import StructType, StructField, LongType, StringType
+import pyspark.sql.functions
+from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
-if have_pandas:
+if should_test_connect:
+ import grpc
+ import pandas as pd
+ import numpy as np
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.sql.connect.client import ChannelBuilder
+ from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
from pyspark.sql.connect.function_builder import udf
from pyspark.sql.connect.functions import lit, col
- from pyspark.testing.pandasutils import PandasOnSparkTestCase
-else:
- from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore
-from pyspark.sql.dataframe import DataFrame
-import pyspark.sql.functions
-from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.testing.utils import ReusedPySparkTestCase
-
-
-import tempfile
@unittest.skipIf(not should_test_connect, connect_requirement_message)
@@ -53,15 +43,8 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
"""Parent test fixture class for all Spark Connect related
test cases."""
- if have_pandas:
- connect: RemoteSparkSession
- tbl_name: str
- tbl_name_empty: str
- df_text: "DataFrame"
- spark: SparkSession
-
@classmethod
- def setUpClass(cls: Any):
+ def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
cls.hive_available = True
@@ -82,12 +65,12 @@ def setUpClass(cls: Any):
cls.spark_connect_load_test_data()
@classmethod
- def tearDownClass(cls: Any) -> None:
+ def tearDownClass(cls):
cls.spark_connect_clean_up_test_data()
ReusedPySparkTestCase.tearDownClass()
@classmethod
- def spark_connect_load_test_data(cls: Any):
+ def spark_connect_load_test_data(cls):
# Setup Remote Spark Session
cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
@@ -108,7 +91,7 @@ def spark_connect_load_test_data(cls: Any):
empty_df.write.saveAsTable(cls.tbl_name_empty)
@classmethod
- def spark_connect_clean_up_test_data(cls: Any) -> None:
+ def spark_connect_clean_up_test_data(cls):
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name2))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
@@ -121,6 +104,35 @@ def test_simple_read(self):
# Check that the limit is applied
self.assertEqual(len(data.index), 10)
+ def test_json(self):
+ with tempfile.TemporaryDirectory() as d:
+ # Write a DataFrame into a JSON file
+ self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
+ "overwrite"
+ ).format("json").save(d)
+ # Read the JSON file as a DataFrame.
+ self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
+ self.assert_eq(
+ self.connect.read.json(path=d, schema="age INT, name STRING").toPandas(),
+ self.spark.read.json(path=d, schema="age INT, name STRING").toPandas(),
+ )
+ self.assert_eq(
+ self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
+ self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
+ )
+
+ def test_parquet(self):
+ # SPARK-41445: Implement DataFrameReader.parquet
+ with tempfile.TemporaryDirectory() as d:
+ # Write a DataFrame into a Parquet file
+ self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode(
+ "overwrite"
+ ).format("parquet").save(d)
+ # Read the Parquet file as a DataFrame.
+ self.assert_eq(
+ self.connect.read.parquet(d).toPandas(), self.spark.read.parquet(d).toPandas()
+ )
+
def test_join_condition_column_list_columns(self):
left_connect_df = self.connect.read.table(self.tbl_name)
right_connect_df = self.connect.read.table(self.tbl_name2)
@@ -183,7 +195,7 @@ def conv_udf(x) -> str:
def test_with_local_data(self):
"""SPARK-41114: Test creating a dataframe using local data"""
- pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
+ pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
df = self.connect.createDataFrame(pdf)
rows = df.filter(df.a == lit(3)).collect()
self.assertTrue(len(rows) == 1)
@@ -191,10 +203,94 @@ def test_with_local_data(self):
self.assertEqual(rows[0][1], "c")
# Check correct behavior for empty DataFrame
- pdf = pandas.DataFrame({"a": []})
+ pdf = pd.DataFrame({"a": []})
with self.assertRaises(ValueError):
self.connect.createDataFrame(pdf)
+ def test_with_local_ndarray(self):
+ """SPARK-41446: Test creating a dataframe using local list"""
+ data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+
+ sdf = self.spark.createDataFrame(data)
+ cdf = self.connect.createDataFrame(data)
+ self.assertEqual(sdf.schema, cdf.schema)
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
+
+ # TODO: add cases for StructType after 'pyspark_types_to_proto_types' supports StructType
+ for schema in [
+ "struct<col1 int, col2 int, col3 int, col4 int>",
+ "col1 int, col2 int, col3 int, col4 int",
+ "col1 int, col2 long, col3 string, col4 long",
+ "col1 int, col2 string, col3 short, col4 long",
+ ["a", "b", "c", "d"],
+ ("x1", "x2", "x3", "x4"),
+ ]:
+ sdf = self.spark.createDataFrame(data, schema=schema)
+ cdf = self.connect.createDataFrame(data, schema=schema)
+
+ self.assertEqual(sdf.schema, cdf.schema)
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "Length mismatch: Expected axis has 4 elements, new values have 5 elements",
+ ):
+ self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
+
+ with self.assertRaises(grpc.RpcError):
+ self.connect.createDataFrame(
+ data, "col1 magic_type, col2 int, col3 int, col4 int"
+ ).show()
+
+ with self.assertRaises(grpc.RpcError):
+ self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+
+ def test_with_local_list(self):
+ """SPARK-41446: Test creating a dataframe using local list"""
+ data = [[1, 2, 3, 4]]
+
+ sdf = self.spark.createDataFrame(data)
+ cdf = self.connect.createDataFrame(data)
+ self.assertEqual(sdf.schema, cdf.schema)
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
+
+ for schema in [
+ "struct",
+ "col1 int, col2 int, col3 int, col4 int",
+ "col1 int, col2 long, col3 string, col4 long",
+ "col1 int, col2 string, col3 short, col4 long",
+ ["a", "b", "c", "d"],
+ ("x1", "x2", "x3", "x4"),
+ ]:
+ sdf = self.spark.createDataFrame(data, schema=schema)
+ cdf = self.connect.createDataFrame(data, schema=schema)
+
+ self.assertEqual(sdf.schema, cdf.schema)
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "Length mismatch: Expected axis has 4 elements, new values have 5 elements",
+ ):
+ self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
+
+ with self.assertRaises(grpc.RpcError):
+ self.connect.createDataFrame(
+ data, "col1 magic_type, col2 int, col3 int, col4 int"
+ ).show()
+
+ with self.assertRaises(grpc.RpcError):
+ self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+
+ def test_with_atom_type(self):
+ for data in [[(1), (2), (3)], [1, 2, 3]]:
+ for schema in ["long", "int", "short"]:
+ sdf = self.spark.createDataFrame(data, schema=schema)
+ cdf = self.connect.createDataFrame(data, schema=schema)
+
+ self.assertEqual(sdf.schema, cdf.schema)
+ self.assert_eq(sdf.toPandas(), cdf.toPandas())
+
def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
@@ -687,6 +783,29 @@ def test_replace(self):
"""Cannot resolve column name "x" among (a, b, c)""", str(context.exception)
)
+ def test_unpivot(self):
+ self.assert_eq(
+ self.connect.read.table(self.tbl_name)
+ .filter("id > 3")
+ .unpivot(["id"], ["name"], "variable", "value")
+ .toPandas(),
+ self.spark.read.table(self.tbl_name)
+ .filter("id > 3")
+ .unpivot(["id"], ["name"], "variable", "value")
+ .toPandas(),
+ )
+
+ self.assert_eq(
+ self.connect.read.table(self.tbl_name)
+ .filter("id > 3")
+ .unpivot("id", None, "variable", "value")
+ .toPandas(),
+ self.spark.read.table(self.tbl_name)
+ .filter("id > 3")
+ .unpivot("id", None, "variable", "value")
+ .toPandas(),
+ )
+
def test_with_columns(self):
# SPARK-41256: test withColumn(s).
self.assert_eq(
@@ -955,7 +1074,7 @@ def test_metadata(self):
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
try:
- import xmlrunner # type: ignore
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 106ab609bf..e670123199 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -14,13 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.testing.sqlutils import have_pandas
+from pyspark.sql.types import (
+ ByteType,
+ ShortType,
+ IntegerType,
+ FloatType,
+ DayTimeIntervalType,
+ StringType,
+ DoubleType,
+ LongType,
+ DecimalType,
+ BinaryType,
+ BooleanType,
+)
+from pyspark.testing.connectutils import should_test_connect
-if have_pandas:
+if should_test_connect:
+ import pandas as pd
from pyspark.sql.connect.functions import lit
- import pandas
class SparkConnectTests(SparkConnectSQLTestCase):
@@ -74,11 +87,116 @@ def test_columns(self):
def test_simple_binary_expressions(self):
"""Test complex expression"""
df = self.connect.read.table(self.tbl_name)
- pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
- self.assertEqual(len(pd.index), 4)
+ pdf = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
+ self.assertEqual(len(pdf.index), 4)
+
+ res = pd.DataFrame(data={"id": [0, 30, 60, 90]})
+ self.assert_(pdf.equals(res), f"{pdf.to_string()} != {res.to_string()}")
+
+ def test_literal_integers(self):
+ cdf = self.connect.range(0, 1)
+ sdf = self.spark.range(0, 1)
+
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+ from pyspark.sql.connect.column import JVM_INT_MIN, JVM_INT_MAX, JVM_LONG_MIN, JVM_LONG_MAX
+
+ cdf1 = cdf.select(
+ CF.lit(0),
+ CF.lit(1),
+ CF.lit(-1),
+ CF.lit(JVM_INT_MAX),
+ CF.lit(JVM_INT_MIN),
+ CF.lit(JVM_INT_MAX + 1),
+ CF.lit(JVM_INT_MIN - 1),
+ CF.lit(JVM_LONG_MAX),
+ CF.lit(JVM_LONG_MIN),
+ CF.lit(JVM_LONG_MAX - 1),
+ CF.lit(JVM_LONG_MIN + 1),
+ )
+
+ sdf1 = sdf.select(
+ SF.lit(0),
+ SF.lit(1),
+ SF.lit(-1),
+ SF.lit(JVM_INT_MAX),
+ SF.lit(JVM_INT_MIN),
+ SF.lit(JVM_INT_MAX + 1),
+ SF.lit(JVM_INT_MIN - 1),
+ SF.lit(JVM_LONG_MAX),
+ SF.lit(JVM_LONG_MIN),
+ SF.lit(JVM_LONG_MAX - 1),
+ SF.lit(JVM_LONG_MIN + 1),
+ )
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assert_eq(cdf1.toPandas(), sdf1.toPandas())
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "integer 9223372036854775808 out of bounds",
+ ):
+ cdf.select(CF.lit(JVM_LONG_MAX + 1)).show()
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "integer -9223372036854775809 out of bounds",
+ ):
+ cdf.select(CF.lit(JVM_LONG_MIN - 1)).show()
+
+ def test_cast(self):
+ # SPARK-41412: test basic Column.cast
+ df = self.connect.read.table(self.tbl_name)
+ df2 = self.spark.read.table(self.tbl_name)
+
+ self.assert_eq(
+ df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
+ )
+
+ # Test if the arguments can be passed properly.
+ # Do not need to check individual behaviour for the ANSI mode thoroughly.
+ with self.sql_conf({"spark.sql.ansi.enabled": False}):
+ for x in [
+ StringType(),
+ BinaryType(),
+ ShortType(),
+ IntegerType(),
+ LongType(),
+ FloatType(),
+ DoubleType(),
+ ByteType(),
+ DecimalType(10, 2),
+ BooleanType(),
+ DayTimeIntervalType(),
+ ]:
+ self.assert_eq(
+ df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
+ )
+
+ def test_unsupported_functions(self):
+ # SPARK-41225: Disable unsupported functions.
+ c = self.connect.range(1).id
+ for f in (
+ "otherwise",
+ "over",
+ "isin",
+ "when",
+ "getItem",
+ "astype",
+ "between",
+ "getField",
+ "withField",
+ "dropFields",
+ ):
+ with self.assertRaises(NotImplementedError):
+ getattr(c, f)()
+
+ with self.assertRaises(NotImplementedError):
+ c["a"]
- res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
- self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
+ with self.assertRaises(TypeError):
+ for x in c:
+ pass
if __name__ == "__main__":
@@ -86,7 +204,7 @@ def test_simple_binary_expressions(self):
from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401
try:
- import xmlrunner # type: ignore
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 09e47657eb..d74473e725 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -15,22 +15,24 @@
# limitations under the License.
#
import uuid
-from typing import cast
import unittest
import decimal
import datetime
-from pyspark.testing.connectutils import PlanOnlyTestFixture
-from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+from pyspark.testing.connectutils import (
+ PlanOnlyTestFixture,
+ should_test_connect,
+ connect_requirement_message,
+)
-if have_pandas:
+if should_test_connect:
from pyspark.sql.connect.proto import Expression as ProtoExpression
import pyspark.sql.connect.plan as p
from pyspark.sql.connect.column import Column
import pyspark.sql.connect.functions as fun
-@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
def test_simple_column_expressions(self):
df = self.connect.with_plan(p.Read("table"))
@@ -68,7 +70,7 @@ def test_map_literal(self):
map_lit_p = map_lit.to_plan(None)
self.assertEqual(2, len(map_lit_p.literal.map.pairs))
self.assertEqual("this", map_lit_p.literal.map.pairs[0].key.string)
- self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.long)
+ self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.integer)
val = {"this": fun.lit("is"), 12: [12, 32, 43]}
map_lit = fun.lit(val)
@@ -89,7 +91,10 @@ def test_column_literals(self):
self.assertIsNotNone(fun.lit(10).to_plan(None))
plan = fun.lit(10).to_plan(None)
- self.assertIs(plan.literal.long, 10)
+ self.assertIs(plan.literal.integer, 10)
+
+ plan = fun.lit(1 << 33).to_plan(None)
+ self.assertEqual(plan.literal.long, 1 << 33)
def test_numeric_literal_types(self):
int_lit = fun.lit(10)
@@ -167,13 +172,13 @@ def test_tuple_to_literal(self):
p2 = fun.lit(t2).to_plan(None)
self.assertIsNotNone(p2)
self.assertTrue(p2.literal.HasField("struct"))
- self.assertEqual(p2.literal.struct.fields[0].long, 1)
+ self.assertEqual(p2.literal.struct.fields[0].integer, 1)
self.assertEqual(p2.literal.struct.fields[1].string, "xyz")
p3 = fun.lit(t3).to_plan(None)
self.assertIsNotNone(p3)
self.assertTrue(p3.literal.HasField("struct"))
- self.assertEqual(p3.literal.struct.fields[0].long, 1)
+ self.assertEqual(p3.literal.struct.fields[0].integer, 1)
self.assertEqual(p3.literal.struct.fields[1].string, "abc")
self.assertEqual(p3.literal.struct.fields[2].struct.fields[0].double, 3.5)
self.assertEqual(p3.literal.struct.fields[2].struct.fields[1].boolean, True)
@@ -207,7 +212,7 @@ def test_column_expressions(self):
lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, ProtoExpression)
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
- self.assertEqual(lit_fun.literal.long, 10)
+ self.assertEqual(lit_fun.literal.integer, 10)
mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, ProtoExpression)
@@ -228,7 +233,7 @@ def test_column_expressions(self):
from pyspark.sql.tests.connect.test_connect_column_expressions import * # noqa: F401
try:
- import xmlrunner # type: ignore
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index ee3a927708..ee5d2d49d9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -14,22 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Any
import unittest
import tempfile
-from pyspark.testing.sqlutils import have_pandas, SQLTestUtils
-
from pyspark.sql import SparkSession
-
-if have_pandas:
- from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
- from pyspark.testing.pandasutils import PandasOnSparkTestCase
-else:
- from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore
-from pyspark.sql.dataframe import DataFrame
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+if should_test_connect:
+ import grpc
+ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
@unittest.skipIf(not should_test_connect, connect_requirement_message)
@@ -37,15 +33,8 @@ class SparkConnectFuncTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQL
"""Parent test fixture class for all Spark Connect related
test cases."""
- if have_pandas:
- connect: RemoteSparkSession
- tbl_name: str
- tbl_name_empty: str
- df_text: "DataFrame"
- spark: SparkSession
-
@classmethod
- def setUpClass(cls: Any):
+ def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
cls.hive_available = True
@@ -55,7 +44,7 @@ def setUpClass(cls: Any):
cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
@classmethod
- def tearDownClass(cls: Any) -> None:
+ def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
@@ -63,6 +52,24 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
+ def compare_by_show(self, df1, df2):
+ from pyspark.sql.dataframe import DataFrame as SDF
+ from pyspark.sql.connect.dataframe import DataFrame as CDF
+
+ assert isinstance(df1, (SDF, CDF))
+ if isinstance(df1, SDF):
+ str1 = df1._jdf.showString(20, 20, False)
+ else:
+ str1 = df1._show_string(20, 20, False)
+
+ assert isinstance(df2, (SDF, CDF))
+ if isinstance(df2, SDF):
+ str2 = df2._jdf.showString(20, 20, False)
+ else:
+ str2 = df2._show_string(20, 20, False)
+
+ self.assertEqual(str1, str2)
+
def test_normal_functions(self):
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
@@ -428,6 +435,513 @@ def test_aggregation_functions(self):
.toPandas(),
)
+ def test_collection_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'),
+ (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'),
+ (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL)
+ AS tab(a, b, c, d, e, f)
+ """
+ # +---------+------------+------------+---+---+----+
+ # | a| b| c| d| e| f|
+ # +---------+------------+------------+---+---+----+
+ # | [a, ab]| [1, 2, 3]|[1, null, 3]| 1| 2| a|
+ # |[x, null]| null| [1, 3]| 3| 4| x|
+ # | null|[-1, -2, -3]| []| 5| 6|null|
+ # +---------+------------+------------+---+---+----+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ for cfunc, sfunc in [
+ (CF.array_distinct, SF.array_distinct),
+ (CF.array_max, SF.array_max),
+ (CF.array_min, SF.array_min),
+ (CF.reverse, SF.reverse),
+ (CF.size, SF.size),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(),
+ sdf.select(sfunc("a"), sfunc(sdf.b)).toPandas(),
+ )
+
+ for cfunc, sfunc in [
+ (CF.array_except, SF.array_except),
+ (CF.array_intersect, SF.array_intersect),
+ (CF.array_union, SF.array_union),
+ (CF.arrays_overlap, SF.arrays_overlap),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc("b", cdf.c)).toPandas(),
+ sdf.select(sfunc("b", sdf.c)).toPandas(),
+ )
+
+ for cfunc, sfunc in [
+ (CF.array_position, SF.array_position),
+ (CF.array_remove, SF.array_remove),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc(cdf.a, "ab")).toPandas(),
+ sdf.select(sfunc(sdf.a, "ab")).toPandas(),
+ )
+
+ # test array
+ self.assert_eq(
+ cdf.select(CF.array(cdf.d, "e")).toPandas(),
+ sdf.select(SF.array(sdf.d, "e")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.array(cdf.d, "e", CF.lit(99))).toPandas(),
+ sdf.select(SF.array(sdf.d, "e", SF.lit(99))).toPandas(),
+ )
+
+ # test array_contains
+ self.assert_eq(
+ cdf.select(CF.array_contains(cdf.a, "ab")).toPandas(),
+ sdf.select(SF.array_contains(sdf.a, "ab")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.array_contains(cdf.a, cdf.f)).toPandas(),
+ sdf.select(SF.array_contains(sdf.a, sdf.f)).toPandas(),
+ )
+
+ # test array_join
+ self.assert_eq(
+ cdf.select(
+ CF.array_join(cdf.a, ","), CF.array_join("b", ":"), CF.array_join("c", "~")
+ ).toPandas(),
+ sdf.select(
+ SF.array_join(sdf.a, ","), SF.array_join("b", ":"), SF.array_join("c", "~")
+ ).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(
+ CF.array_join(cdf.a, ",", "_null_"),
+ CF.array_join("b", ":", ".null."),
+ CF.array_join("c", "~", "NULL"),
+ ).toPandas(),
+ sdf.select(
+ SF.array_join(sdf.a, ",", "_null_"),
+ SF.array_join("b", ":", ".null."),
+ SF.array_join("c", "~", "NULL"),
+ ).toPandas(),
+ )
+
+ # test array_repeat
+ self.assert_eq(
+ cdf.select(CF.array_repeat(cdf.f, "d")).toPandas(),
+ sdf.select(SF.array_repeat(sdf.f, "d")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.array_repeat("f", cdf.d)).toPandas(),
+ sdf.select(SF.array_repeat("f", sdf.d)).toPandas(),
+ )
+ # TODO: Make Literal carry a DataType
+ # Cannot resolve "array_repeat(f, 3)" due to data type mismatch:
+ # Parameter 2 requires the "INT" type, however "3" has the type "BIGINT".
+ # self.assert_eq(
+ # cdf.select(CF.array_repeat("f", 3)).toPandas(),
+ # sdf.select(SF.array_repeat("f", 3)).toPandas(),
+ # )
+
+ # test arrays_zip
+ # TODO: Make toPandas support complex nested types like Array
+ # DataFrame.iloc[:, 0] (column name="arrays_zip(b, c)") values are different (66.66667 %)
+ # [index]: [0, 1, 2]
+ # [left]: [[{'b': 1, 'c': 1.0}, {'b': 2, 'c': None}, {'b': 3, 'c': 3.0}], None,
+ # [{'b': -1, 'c': None}, {'b': -2, 'c': None}, {'b': -3, 'c': None}]]
+ # [right]: [[(1, 1), (2, None), (3, 3)], None, [(-1, None), (-2, None), (-3, None)]]
+ self.compare_by_show(
+ cdf.select(CF.arrays_zip(cdf.b, "c")),
+ sdf.select(SF.arrays_zip(sdf.b, "c")),
+ )
+
+ # test concat
+ self.assert_eq(
+ cdf.select(CF.concat("d", cdf.e, CF.lit(-1))).toPandas(),
+ sdf.select(SF.concat("d", sdf.e, SF.lit(-1))).toPandas(),
+ )
+
+ # test create_map
+ self.compare_by_show(
+ cdf.select(CF.create_map(cdf.d, cdf.e)), sdf.select(SF.create_map(sdf.d, sdf.e))
+ )
+ self.compare_by_show(
+ cdf.select(CF.create_map(cdf.d, "e", "e", CF.lit(1))),
+ sdf.select(SF.create_map(sdf.d, "e", "e", SF.lit(1))),
+ )
+
+ # test element_at
+ self.assert_eq(
+ cdf.select(CF.element_at("a", 1)).toPandas(),
+ sdf.select(SF.element_at("a", 1)).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.element_at(cdf.a, 1)).toPandas(),
+ sdf.select(SF.element_at(sdf.a, 1)).toPandas(),
+ )
+
+ # test get
+ self.assert_eq(
+ cdf.select(CF.get("a", 1)).toPandas(),
+ sdf.select(SF.get("a", 1)).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.get(cdf.a, 1)).toPandas(),
+ sdf.select(SF.get(sdf.a, 1)).toPandas(),
+ )
+
+ # test shuffle
+ # Cannot compare the values due to the random permutation
+ self.assertEqual(
+ cdf.select(CF.shuffle(cdf.a), CF.shuffle("b")).count(),
+ sdf.select(SF.shuffle(sdf.a), SF.shuffle("b")).count(),
+ )
+
+ # test slice
+ self.assert_eq(
+ cdf.select(CF.slice(cdf.a, 1, 2), CF.slice("c", 2, 3)).toPandas(),
+ sdf.select(SF.slice(sdf.a, 1, 2), SF.slice("c", 2, 3)).toPandas(),
+ )
+
+ # test sort_array
+ self.assert_eq(
+ cdf.select(CF.sort_array(cdf.a, True), CF.sort_array("c", False)).toPandas(),
+ sdf.select(SF.sort_array(sdf.a, True), SF.sort_array("c", False)).toPandas(),
+ )
+
+ # test struct
+ self.compare_by_show(
+ cdf.select(CF.struct(cdf.a, "d", "e", cdf.f)),
+ sdf.select(SF.struct(sdf.a, "d", "e", sdf.f)),
+ )
+
+ def test_map_collection_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ (MAP('a', 'ab'), MAP('x', 'ab'), MAP(1, 2, 3, 4), 1, 'a', ARRAY(1, 2), ARRAY('X', 'Y')),
+ (MAP('x', 'yz'), MAP('c', NULL), NULL, 2, 'x', ARRAY(3, 4), ARRAY('A', 'B')),
+ (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4), -3, 'c', NULL, ARRAY('Z'))
+ AS tab(a, b, c, e, f, g, h)
+ """
+ # +---------+-----------+----------------------+---+---+------+------+
+ # | a| b| c| e| f| g| h|
+ # +---------+-----------+----------------------+---+---+------+------+
+ # |{a -> ab}| {x -> ab}| {1 -> 2, 3 -> 4}| 1| a|[1, 2]|[X, Y]|
+ # |{x -> yz}|{c -> null}| null| 2| x|[3, 4]|[A, B]|
+ # |{c -> de}| null|{-1 -> null, -3 -> -4}| -3| c| null| [Z]|
+ # +---------+-----------+----------------------+---+---+------+------+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test map_concat
+ self.compare_by_show(
+ cdf.select(CF.map_concat(cdf.a, "b")),
+ sdf.select(SF.map_concat(sdf.a, "b")),
+ )
+
+ # test map_contains_key
+ self.compare_by_show(
+ cdf.select(CF.map_contains_key(cdf.a, "a"), CF.map_contains_key("c", 3)),
+ sdf.select(SF.map_contains_key(sdf.a, "a"), SF.map_contains_key("c", 3)),
+ )
+
+ # test map_entries
+ self.compare_by_show(
+ cdf.select(CF.map_entries(cdf.a), CF.map_entries("b")),
+ sdf.select(SF.map_entries(sdf.a), SF.map_entries("b")),
+ )
+
+ # test map_from_arrays
+ self.compare_by_show(
+ cdf.select(CF.map_from_arrays(cdf.g, "h")),
+ sdf.select(SF.map_from_arrays(sdf.g, "h")),
+ )
+
+ # test map_keys and map_values
+ self.compare_by_show(
+ cdf.select(CF.map_keys(cdf.a), CF.map_values("b")),
+ sdf.select(SF.map_keys(sdf.a), SF.map_values("b")),
+ )
+
+ # test size
+ self.assert_eq(
+ cdf.select(CF.size(cdf.a), CF.size("c")).toPandas(),
+ sdf.select(SF.size(sdf.a), SF.size("c")).toPandas(),
+ )
+
+ def test_generator_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3),
+ MAP(1, 2, 3, 4), 1, FLOAT(2.0), 3),
+ (ARRAY('x', NULL), NULL, ARRAY(1, 3),
+ NULL, 3, FLOAT(4.0), 5),
+ (NULL, ARRAY(-1, -2, -3), Array(),
+ MAP(-1, NULL, -3, -4), 7, FLOAT('NAN'), 9)
+ AS tab(a, b, c, d, e, f, g)
+ """
+ # +---------+------------+------------+----------------------+---+---+---+
+ # | a| b| c| d| e| f| g|
+ # +---------+------------+------------+----------------------+---+---+---+
+ # | [a, ab]| [1, 2, 3]|[1, null, 3]| {1 -> 2, 3 -> 4}| 1|2.0| 3|
+ # |[x, null]| null| [1, 3]| null| 3|4.0| 5|
+ # | null|[-1, -2, -3]| []|{-1 -> null, -3 -> -4}| 7|NaN| 9|
+ # +---------+------------+------------+----------------------+---+---+---+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test explode with arrays
+ self.assert_eq(
+ cdf.select(CF.explode(cdf.a), CF.col("b")).toPandas(),
+ sdf.select(SF.explode(sdf.a), SF.col("b")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.explode("a"), "b").toPandas(),
+ sdf.select(SF.explode("a"), "b").toPandas(),
+ )
+ # test explode with maps
+ self.assert_eq(
+ cdf.select(CF.explode(cdf.d), CF.col("c")).toPandas(),
+ sdf.select(SF.explode(sdf.d), SF.col("c")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.explode("d"), "c").toPandas(),
+ sdf.select(SF.explode("d"), "c").toPandas(),
+ )
+
+ # test explode_outer with arrays
+ self.assert_eq(
+ cdf.select(CF.explode_outer(cdf.a), CF.col("b")).toPandas(),
+ sdf.select(SF.explode_outer(sdf.a), SF.col("b")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.explode_outer("a"), "b").toPandas(),
+ sdf.select(SF.explode_outer("a"), "b").toPandas(),
+ )
+ # test explode_outer with maps
+ self.assert_eq(
+ cdf.select(CF.explode_outer(cdf.d), CF.col("c")).toPandas(),
+ sdf.select(SF.explode_outer(sdf.d), SF.col("c")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.explode_outer("d"), "c").toPandas(),
+ sdf.select(SF.explode_outer("d"), "c").toPandas(),
+ )
+
+ # test flatten
+ self.assert_eq(
+ cdf.select(CF.flatten(CF.array("b", cdf.c)), CF.col("b")).toPandas(),
+ sdf.select(SF.flatten(SF.array("b", sdf.c)), SF.col("b")).toPandas(),
+ )
+
+ # test inline
+ self.assert_eq(
+ cdf.select(CF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X"))
+ .select(CF.inline("X"))
+ .toPandas(),
+ sdf.select(SF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X"))
+ .select(SF.inline("X"))
+ .toPandas(),
+ )
+
+ # test inline_outer
+ self.assert_eq(
+ cdf.select(CF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X"))
+ .select(CF.inline_outer("X"))
+ .toPandas(),
+ sdf.select(SF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X"))
+ .select(SF.inline_outer("X"))
+ .toPandas(),
+ )
+
+ # test posexplode with arrays
+ self.assert_eq(
+ cdf.select(CF.posexplode(cdf.a), CF.col("b")).toPandas(),
+ sdf.select(SF.posexplode(sdf.a), SF.col("b")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.posexplode("a"), "b").toPandas(),
+ sdf.select(SF.posexplode("a"), "b").toPandas(),
+ )
+ # test posexplode with maps
+ self.assert_eq(
+ cdf.select(CF.posexplode(cdf.d), CF.col("c")).toPandas(),
+ sdf.select(SF.posexplode(sdf.d), SF.col("c")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.posexplode("d"), "c").toPandas(),
+ sdf.select(SF.posexplode("d"), "c").toPandas(),
+ )
+
+ # test posexplode_outer with arrays
+ self.assert_eq(
+ cdf.select(CF.posexplode_outer(cdf.a), CF.col("b")).toPandas(),
+ sdf.select(SF.posexplode_outer(sdf.a), SF.col("b")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.posexplode_outer("a"), "b").toPandas(),
+ sdf.select(SF.posexplode_outer("a"), "b").toPandas(),
+ )
+ # test posexplode_outer with maps
+ self.assert_eq(
+ cdf.select(CF.posexplode_outer(cdf.d), CF.col("c")).toPandas(),
+ sdf.select(SF.posexplode_outer(sdf.d), SF.col("c")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.posexplode_outer("d"), "c").toPandas(),
+ sdf.select(SF.posexplode_outer("d"), "c").toPandas(),
+ )
+
+ def test_csv_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ ('1,2,3', 'a,b,5.0'),
+ ('3,4,5', 'x,y,6.0')
+ AS tab(a, b)
+ """
+ # +-----+-------+
+ # | a| b|
+ # +-----+-------+
+ # |1,2,3|a,b,5.0|
+ # |3,4,5|x,y,6.0|
+ # +-----+-------+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test from_csv
+ self.compare_by_show(
+ cdf.select(
+ CF.from_csv(cdf.a, "a INT, b INT, c INT"),
+ CF.from_csv("b", "x STRING, y STRING, z DOUBLE"),
+ ),
+ sdf.select(
+ SF.from_csv(sdf.a, "a INT, b INT, c INT"),
+ SF.from_csv("b", "x STRING, y STRING, z DOUBLE"),
+ ),
+ )
+ self.compare_by_show(
+ cdf.select(
+ CF.from_csv(cdf.a, CF.lit("a INT, b INT, c INT")),
+ CF.from_csv("b", CF.lit("x STRING, y STRING, z DOUBLE")),
+ ),
+ sdf.select(
+ SF.from_csv(sdf.a, SF.lit("a INT, b INT, c INT")),
+ SF.from_csv("b", SF.lit("x STRING, y STRING, z DOUBLE")),
+ ),
+ )
+
+ # test schema_of_csv
+ self.assert_eq(
+ cdf.select(CF.schema_of_csv(CF.lit('{"a": 0}'))).toPandas(),
+ sdf.select(SF.schema_of_csv(SF.lit('{"a": 0}'))).toPandas(),
+ )
+
+ # test to_csv
+ self.compare_by_show(
+ cdf.select(CF.to_csv(CF.struct(CF.lit("a"), CF.lit("b")))),
+ sdf.select(SF.to_csv(SF.struct(SF.lit("a"), SF.lit("b")))),
+ )
+
+ def test_json_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ ('{"a": 1}', '[1, 2, 3]', '{"f1": "value1", "f2": "value2"}'),
+ ('{"a": 0}', '[4, 5, 6]', '{"f1": "value12"}')
+ AS tab(a, b, c)
+ """
+ # +--------+---------+--------------------------------+
+ # | a| b| c|
+ # +--------+---------+--------------------------------+
+ # |{"a": 1}|[1, 2, 3]|{"f1": "value1", "f2": "value2"}|
+ # |{"a": 0}|[4, 5, 6]| {"f1": "value12"}|
+ # +--------+---------+--------------------------------+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test from_json
+ for schema in [
+ "a INT",
+ "MAP",
+ # StructType([StructField("a", IntegerType())]),
+ # ArrayType(StructType([StructField("a", IntegerType())])),
+ ]:
+ self.compare_by_show(
+ cdf.select(CF.from_json(cdf.a, schema)),
+ sdf.select(SF.from_json(sdf.a, schema)),
+ )
+ self.compare_by_show(
+ cdf.select(CF.from_json("a", schema)),
+ sdf.select(SF.from_json("a", schema)),
+ )
+
+ for schema in [
+ "ARRAY",
+ # ArrayType(IntegerType()),
+ ]:
+ self.compare_by_show(
+ cdf.select(CF.from_json(cdf.b, schema)),
+ sdf.select(SF.from_json(sdf.b, schema)),
+ )
+ self.compare_by_show(
+ cdf.select(CF.from_json("b", schema)),
+ sdf.select(SF.from_json("b", schema)),
+ )
+
+ # test get_json_object
+ self.assert_eq(
+ cdf.select(
+ CF.get_json_object("c", "$.f1"),
+ CF.get_json_object(cdf.c, "$.f2"),
+ ).toPandas(),
+ sdf.select(
+ SF.get_json_object("c", "$.f1"),
+ SF.get_json_object(sdf.c, "$.f2"),
+ ).toPandas(),
+ )
+
+ # test json_tuple
+ self.assert_eq(
+ cdf.select(CF.json_tuple("c", "f1", "f2")).toPandas(),
+ sdf.select(SF.json_tuple("c", "f1", "f2")).toPandas(),
+ )
+ self.assert_eq(
+ cdf.select(CF.json_tuple(cdf.c, "f1", "f2")).toPandas(),
+ sdf.select(SF.json_tuple(sdf.c, "f1", "f2")).toPandas(),
+ )
+
+ # test schema_of_json
+ self.assert_eq(
+ cdf.select(CF.schema_of_json(CF.lit('{"a": 0}'))).toPandas(),
+ sdf.select(SF.schema_of_json(SF.lit('{"a": 0}'))).toPandas(),
+ )
+
+ # test to_json
+ self.compare_by_show(
+ cdf.select(CF.to_json(CF.struct(CF.lit("a"), CF.lit("b")))),
+ sdf.select(SF.to_json(SF.struct(SF.lit("a"), SF.lit("b")))),
+ )
+
def test_string_functions(self):
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
@@ -456,6 +970,7 @@ def test_string_functions(self):
(CF.ltrim, SF.ltrim),
(CF.rtrim, SF.rtrim),
(CF.trim, SF.trim),
+ (CF.reverse, SF.reverse),
]:
self.assert_eq(
cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(),
@@ -467,28 +982,229 @@ def test_string_functions(self):
sdf.select(SF.concat_ws("-", sdf.a, "c")).toPandas(),
)
- # Disable the test for "decode" because of inconsistent column names,
- # as shown below
- #
- # >>> sdf.select(SF.decode("c", "UTF-8")).toPandas()
- # stringdecode(c, UTF-8)
- # 0 None
- # 1 ab
- # >>> cdf.select(CF.decode("c", "UTF-8")).toPandas()
- # decode(c, UTF-8)
- # 0 None
- # 1 ab
- #
- # self.assert_eq(
- # cdf.select(CF.decode("c", "UTF-8")).toPandas(),
- # sdf.select(SF.decode("c", "UTF-8")).toPandas(),
- # )
+ self.assert_eq(
+ cdf.select(CF.decode("c", "UTF-8")).toPandas(),
+ sdf.select(SF.decode("c", "UTF-8")).toPandas(),
+ )
self.assert_eq(
cdf.select(CF.encode("c", "UTF-8")).toPandas(),
sdf.select(SF.encode("c", "UTF-8")).toPandas(),
)
+ # TODO(SPARK-41283): Compare toPandas results for the test cases where dtype differences are marked
+ def test_date_ts_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT * FROM VALUES
+ ('1997/02/28 10:30:00', '2023/03/01 06:00:00', 'JST', 1428476400, 2020, 12, 6),
+ ('2000/01/01 04:30:05', '2020/05/01 12:15:00', 'PST', 1403892395, 2022, 12, 6)
+ AS tab(ts1, ts2, tz, seconds, Y, M, D)
+ """
+ # +-------------------+-------------------+---+----------+----+---+---+
+ # | ts1| ts2| tz| seconds| Y| M| D|
+ # +-------------------+-------------------+---+----------+----+---+---+
+ # |1997/02/28 10:30:00|2023/03/01 06:00:00|JST|1428476400|2020| 12| 6|
+ # |2000/01/01 04:30:05|2020/05/01 12:15:00|PST|1403892395|2022| 12| 6|
+ # +-------------------+-------------------+---+----------+----+---+---+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # With no parameters
+ for cfunc, sfunc in [
+ (CF.current_date, SF.current_date),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc()).toPandas(),
+ sdf.select(sfunc()).toPandas(),
+ )
+
+ # current_timestamp
+ # [left]: datetime64[ns, America/Los_Angeles]
+ # [right]: datetime64[ns]
+ # TODO: compare the return values after resolving the dtype difference
+ self.assertEqual(
+ cdf.select(CF.current_timestamp()).count(),
+ sdf.select(SF.current_timestamp()).count(),
+ )
+
+ # localtimestamp
+ s_pdf0 = sdf.select(SF.localtimestamp()).toPandas()
+ c_pdf = cdf.select(CF.localtimestamp()).toPandas()
+ s_pdf1 = sdf.select(SF.localtimestamp()).toPandas()
+ self.assert_eq(s_pdf0 < c_pdf, c_pdf < s_pdf1)
+
+ # With only column parameter
+ for cfunc, sfunc in [
+ (CF.year, SF.year),
+ (CF.quarter, SF.quarter),
+ (CF.month, SF.month),
+ (CF.dayofweek, SF.dayofweek),
+ (CF.dayofmonth, SF.dayofmonth),
+ (CF.dayofyear, SF.dayofyear),
+ (CF.hour, SF.hour),
+ (CF.minute, SF.minute),
+ (CF.second, SF.second),
+ (CF.weekofyear, SF.weekofyear),
+ (CF.last_day, SF.last_day),
+ (CF.unix_timestamp, SF.unix_timestamp),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc(cdf.ts1)).toPandas(),
+ sdf.select(sfunc(sdf.ts1)).toPandas(),
+ )
+
+ # With format parameter
+ for cfunc, sfunc in [
+ (CF.date_format, SF.date_format),
+ (CF.to_date, SF.to_date),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc(cdf.ts1, format="yyyy-MM-dd")).toPandas(),
+ sdf.select(sfunc(sdf.ts1, format="yyyy-MM-dd")).toPandas(),
+ )
+ self.compare_by_show(
+ # [left]: datetime64[ns, America/Los_Angeles]
+ # [right]: datetime64[ns]
+ cdf.select(CF.to_timestamp(cdf.ts1, format="yyyy-MM-dd")),
+ sdf.select(SF.to_timestamp(sdf.ts1, format="yyyy-MM-dd")),
+ )
+
+ # With tz parameter
+ for cfunc, sfunc in [
+ (CF.from_utc_timestamp, SF.from_utc_timestamp),
+ (CF.to_utc_timestamp, SF.to_utc_timestamp),
+ # [left]: datetime64[ns, America/Los_Angeles]
+ # [right]: datetime64[ns]
+ ]:
+ self.compare_by_show(
+ cdf.select(cfunc(cdf.ts1, tz=cdf.tz)),
+ sdf.select(sfunc(sdf.ts1, tz=sdf.tz)),
+ )
+
+ # With numeric parameter
+ for cfunc, sfunc in [
+ (CF.date_add, SF.date_add),
+ (CF.date_sub, SF.date_sub),
+ (CF.add_months, SF.add_months),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc(cdf.ts1, cdf.D)).toPandas(),
+ sdf.select(sfunc(sdf.ts1, sdf.D)).toPandas(),
+ )
+
+ # With another timestamp as parameter
+ for cfunc, sfunc in [
+ (CF.datediff, SF.datediff),
+ (CF.months_between, SF.months_between),
+ ]:
+ self.assert_eq(
+ cdf.select(cfunc(cdf.ts1, cdf.ts2)).toPandas(),
+ sdf.select(sfunc(sdf.ts1, sdf.ts2)).toPandas(),
+ )
+
+ # With seconds parameter
+ self.compare_by_show(
+ # [left]: datetime64[ns, America/Los_Angeles]
+ # [right]: datetime64[ns]
+ cdf.select(CF.timestamp_seconds(cdf.seconds)),
+ sdf.select(SF.timestamp_seconds(sdf.seconds)),
+ )
+
+ # make_date
+ self.assert_eq(
+ cdf.select(CF.make_date(cdf.Y, cdf.M, cdf.D)).toPandas(),
+ sdf.select(SF.make_date(sdf.Y, sdf.M, sdf.D)).toPandas(),
+ )
+
+ # date_trunc
+ self.compare_by_show(
+ # [left]: datetime64[ns, America/Los_Angeles]
+ # [right]: datetime64[ns]
+ cdf.select(CF.date_trunc("day", cdf.ts1)),
+ sdf.select(SF.date_trunc("day", sdf.ts1)),
+ )
+
+ # trunc
+ self.assert_eq(
+ cdf.select(CF.trunc(cdf.ts1, "year")).toPandas(),
+ sdf.select(SF.trunc(sdf.ts1, "year")).toPandas(),
+ )
+
+ # next_day
+ self.assert_eq(
+ cdf.select(CF.next_day(cdf.ts1, "Mon")).toPandas(),
+ sdf.select(SF.next_day(sdf.ts1, "Mon")).toPandas(),
+ )
+
+ def test_misc_functions(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT a, b, c, BINARY(c) as d FROM VALUES
+ (0, float("NAN"), 'x'), (1, NULL, 'y'), (1, 2.1, 'z'), (0, 0.5, NULL)
+ AS tab(a, b, c)
+ """
+ # +---+----+----+----+
+ # | a| b| c| d|
+ # +---+----+----+----+
+ # | 0| NaN| x|[78]|
+ # | 1|null| y|[79]|
+ # | 1| 2.1| z|[7A]|
+ # | 0| 0.5|null|null|
+ # +---+----+----+----+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # test assert_true
+ with self.assertRaises(grpc.RpcError):
+ cdf.select(CF.assert_true(cdf.a > 0, "a should be positive!")).show()
+
+ # test raise_error
+ with self.assertRaises(grpc.RpcError):
+ cdf.select(CF.raise_error("a should be positive!")).show()
+
+ # test crc32
+ self.assert_eq(
+ cdf.select(CF.crc32(cdf.d)).toPandas(),
+ sdf.select(SF.crc32(sdf.d)).toPandas(),
+ )
+
+ # test hash
+ self.assert_eq(
+ cdf.select(CF.hash(cdf.a, "b", cdf.c)).toPandas(),
+ sdf.select(SF.hash(sdf.a, "b", sdf.c)).toPandas(),
+ )
+
+ # test xxhash64
+ self.assert_eq(
+ cdf.select(CF.xxhash64(cdf.a, "b", cdf.c)).toPandas(),
+ sdf.select(SF.xxhash64(sdf.a, "b", sdf.c)).toPandas(),
+ )
+
+ # test md5
+ self.assert_eq(
+ cdf.select(CF.md5(cdf.d), CF.md5("c")).toPandas(),
+ sdf.select(SF.md5(sdf.d), SF.md5("c")).toPandas(),
+ )
+
+ # test sha1
+ self.assert_eq(
+ cdf.select(CF.sha1(cdf.d), CF.sha1("c")).toPandas(),
+ sdf.select(SF.sha1(sdf.d), SF.sha1("c")).toPandas(),
+ )
+
+ # test sha2
+ self.assert_eq(
+ cdf.select(CF.sha2(cdf.c, 256), CF.sha2("d", 512)).toPandas(),
+ sdf.select(SF.sha2(sdf.c, 256), SF.sha2("d", 512)).toPandas(),
+ )
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_function import * # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index b9695eea78..e0cd54195f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import cast
import unittest
-from pyspark.testing.connectutils import PlanOnlyTestFixture
-from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+from pyspark.testing.connectutils import (
+ PlanOnlyTestFixture,
+ should_test_connect,
+ connect_requirement_message,
+)
-if have_pandas:
+if should_test_connect:
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.plan import WriteOperation
@@ -29,7 +31,7 @@
from pyspark.sql.types import StringType
-@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
@@ -168,6 +170,64 @@ def test_replace(self):
self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob")
self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B")
+ def test_unpivot(self):
+ df = self.connect.readTable(table_name=self.tbl_name)
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .unpivot(["id"], ["name"], "variable", "value")
+ ._plan.to_proto(self.connect)
+ )
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+ self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+ self.assertEqual(
+ plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+ )
+ self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+ self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .unpivot(["id"], None, "variable", "value")
+ ._plan.to_proto(self.connect)
+ )
+ self.assertTrue(len(plan.root.unpivot.ids) == 1)
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+ self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+ self.assertTrue(len(plan.root.unpivot.values) == 0)
+ self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+ self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+ def test_melt(self):
+ df = self.connect.readTable(table_name=self.tbl_name)
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .melt(["id"], ["name"], "variable", "value")
+ ._plan.to_proto(self.connect)
+ )
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+ self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+ self.assertEqual(
+ plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+ )
+ self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+ self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .melt(["id"], [], "variable", "value")
+ ._plan.to_proto(self.connect)
+ )
+ self.assertTrue(len(plan.root.unpivot.ids) == 1)
+ self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
+ self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
+ self.assertTrue(len(plan.root.unpivot.values) == 0)
+ self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
+ self.assertEqual(plan.root.unpivot.value_column_name, "value")
+
def test_summary(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
@@ -379,7 +439,7 @@ def test_simple_udf(self):
self.assertIsNotNone(u)
expr = u("ThisCol", "ThatCol", "OtherCol")
self.assertTrue(isinstance(expr, Column))
- self.assertTrue(isinstance(cast(Column, expr)._expr, UserDefinedFunction))
+ self.assertTrue(isinstance(expr._expr, UserDefinedFunction))
u_plan = expr.to_plan(self.connect)
self.assertIsNotNone(u_plan)
diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
index 01d1819fdc..7f8153f7fc 100644
--- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py
+++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py
@@ -14,19 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import cast
import unittest
-from pyspark.testing.connectutils import PlanOnlyTestFixture
-from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
+from pyspark.testing.connectutils import (
+ PlanOnlyTestFixture,
+ should_test_connect,
+ connect_requirement_message,
+)
-if have_pandas:
+if should_test_connect:
from pyspark.sql.connect.functions import col
from pyspark.sql.connect.plan import Read
import pyspark.sql.connect.proto as proto
-@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message))
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectToProtoSuite(PlanOnlyTestFixture):
def test_select_with_columns_and_strings(self):
df = self.connect.with_plan(Read("table"))
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index b3f4c7331d..c3557f4eb5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -407,7 +407,7 @@ def merge_pandas(lft, rgt):
from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 7f27671cfe..0044ae3c72 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -744,7 +744,7 @@ def my_pandas_udf(pdf):
from pyspark.sql.tests.pandas.test_pandas_grouped_map import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index e75148e524..655f0bf151 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -240,7 +240,7 @@ def assert_test():
from pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 7f996ca55a..243cc36c67 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -207,7 +207,7 @@ def func(iterator):
from pyspark.sql.tests.pandas.test_pandas_map import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 077db2971e..d6d861edb3 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -296,7 +296,7 @@ def noop(s: pd.Series) -> pd.Series:
from pyspark.sql.tests.pandas.test_pandas_udf import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index aa844fc5fd..155695f497 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -551,7 +551,7 @@ def mean(x):
from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 6580f839a8..a5b3bfc164 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1333,7 +1333,7 @@ def udf(x):
from pyspark.sql.tests.pandas.test_pandas_udf_scalar import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index 8c77ed4b77..3cdf83e2d0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -238,7 +238,7 @@ def test_scalar_udf_type_hint(self):
df = self.spark.range(10).selectExpr("id", "id as v")
def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series:
- return v + 1 # type: ignore[return-value]
+ return v + 1
plus_one = pandas_udf("long")(plus_one)
actual = df.select(plus_one(df.v).alias("plus_one"))
@@ -360,7 +360,7 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.
from pyspark.sql.tests.pandas.test_pandas_udf_typehints import * # noqa: #401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py
index a6d3bd608d..9b6751564c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py
@@ -241,7 +241,7 @@ def test_scalar_udf_type_hint(self):
df = self.spark.range(10).selectExpr("id", "id as v")
def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series:
- return v + 1 # type: ignore[return-value]
+ return v + 1
plus_one = pandas_udf("long")(plus_one)
actual = df.select(plus_one(df.v).alias("plus_one"))
@@ -367,7 +367,7 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd.
from pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations import * # noqa: #401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 07e10a58d2..596742a23b 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -398,7 +398,7 @@ def test_bounded_mixed(self):
from pyspark.sql.tests.pandas.test_pandas_udf_window import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index f170787ff7..a67e493a7c 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -653,7 +653,7 @@ def test_streaming_write_to_table(self):
from pyspark.sql.tests.streaming.test_streaming import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index de34565254..c6667e2517 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -299,7 +299,7 @@ def onQueryTerminated(self, event):
from pyspark.sql.tests.streaming.test_streaming_listener import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index a4c948fea3..6166cc5dcc 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -35,7 +35,7 @@
@unittest.skipIf(
not have_pandas or not have_pyarrow,
- pandas_requirement_message or pyarrow_requirement_message, # type: ignore[arg-type]
+ pandas_requirement_message or pyarrow_requirement_message,
)
class MapInArrowTests(ReusedSQLTestCase):
@classmethod
@@ -130,7 +130,7 @@ def test_self_join(self):
from pyspark.sql.tests.test_arrow_map import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py
index 24cd67251a..2eccfab72f 100644
--- a/python/pyspark/sql/tests/test_catalog.py
+++ b/python/pyspark/sql/tests/test_catalog.py
@@ -398,7 +398,7 @@ def test_refresh_table(self):
from pyspark.sql.tests.test_catalog import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 2c4730fd81..236fb1b539 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -192,7 +192,7 @@ def test_drop_fields(self):
from pyspark.sql.tests.test_column import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py
index 4ea160818d..a8fa59c036 100644
--- a/python/pyspark/sql/tests/test_conf.py
+++ b/python/pyspark/sql/tests/test_conf.py
@@ -48,7 +48,7 @@ def test_conf(self):
from pyspark.sql.tests.test_conf import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py
index 508a829975..b381833314 100644
--- a/python/pyspark/sql/tests/test_context.py
+++ b/python/pyspark/sql/tests/test_context.py
@@ -193,7 +193,7 @@ def test_get_or_create(self):
from pyspark.sql.tests.test_context import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index 30c1855622..80ab8a3316 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -198,7 +198,7 @@ def test_ignore_column_of_all_nulls(self):
from pyspark.sql.tests.test_datasources import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 55ef012b6d..94cb3c4f1e 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1156,7 +1156,7 @@ def test_map_functions(self):
from pyspark.sql.tests.test_functions import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index 19f1a0148b..19e1228d25 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -41,7 +41,7 @@ def test_aggregator(self):
from pyspark.sql.tests.test_group import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
index d182bafd8b..22a0e92e81 100644
--- a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
+++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py
@@ -60,7 +60,7 @@ def test_pandas(col1):
from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 2e1bdb4424..4aa24fc2be 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -217,7 +217,7 @@ def test_partitioning_functions(self):
from pyspark.sql.tests.test_readwriter import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py
index 4e9d347da9..e8017cfd38 100644
--- a/python/pyspark/sql/tests/test_serde.py
+++ b/python/pyspark/sql/tests/test_serde.py
@@ -145,7 +145,7 @@ def test_bytes_as_binary_type(self):
from pyspark.sql.tests.test_serde import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 80c05e1a3c..dacaff4d2d 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -404,7 +404,7 @@ def test_use_custom_class_for_extensions(self):
from pyspark.sql.tests.test_session import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index b1d2eccea4..bc7aafe5f0 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1405,7 +1405,7 @@ def test_row_without_field_sorting(self):
from pyspark.sql.tests.test_types import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 954fe9f24a..080d88788b 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -81,7 +81,7 @@ def test_get_error_class_state(self):
from pyspark.sql.tests.test_utils import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py
index 1e2c153176..1afcc90b9e 100644
--- a/python/pyspark/streaming/tests/test_context.py
+++ b/python/pyspark/streaming/tests/test_context.py
@@ -176,7 +176,7 @@ def test_await_termination_or_timeout(self):
from pyspark.streaming.tests.test_context import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/streaming/tests/test_dstream.py b/python/pyspark/streaming/tests/test_dstream.py
index a52d08a1b1..d37e64affb 100644
--- a/python/pyspark/streaming/tests/test_dstream.py
+++ b/python/pyspark/streaming/tests/test_dstream.py
@@ -698,7 +698,7 @@ def check_output(n):
from pyspark.streaming.tests.test_dstream import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py
index 7b09f5b8f5..7efd7a7d0c 100644
--- a/python/pyspark/streaming/tests/test_kinesis.py
+++ b/python/pyspark/streaming/tests/test_kinesis.py
@@ -110,7 +110,7 @@ def get_output(_, rdd):
from pyspark.streaming.tests.test_kinesis import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py
index f881b2d201..aeec278b38 100644
--- a/python/pyspark/streaming/tests/test_listener.py
+++ b/python/pyspark/streaming/tests/test_listener.py
@@ -152,7 +152,7 @@ def func(dstream):
from pyspark.streaming.tests.test_listener import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 1979b6eb72..efc118b572 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -14,52 +14,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import typing
import os
-from typing import Any, Dict, Optional
import functools
import unittest
-from pyspark.testing.sqlutils import have_pandas
+from pyspark.testing.sqlutils import (
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
-if have_pandas:
+
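+# grpcio provides the "grpc" module needed by Spark Connect; record the import
+# error message so the Connect tests can be skipped with a helpful reason.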
+grpc_requirement_message = None
+try:
+ import grpc
+except ImportError as e:
+ grpc_requirement_message = str(e)
+have_grpc = grpc_requirement_message is None
+
+connect_not_compiled_message = None
+if have_pandas and have_pyarrow and have_grpc:
from pyspark.sql.connect import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL
from pyspark.testing.utils import search_jar
- from pyspark.sql.connect.plan import LogicalPlan
from pyspark.sql.connect.session import SparkSession
connect_jar = search_jar("connector/connect/server", "spark-connect-assembly-", "spark-connect")
+ existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+ jars_args = "--jars %s" % connect_jar
+ plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
+ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args])
else:
- connect_jar = None
-
-
-if connect_jar is None:
- connect_requirement_message = (
+ connect_not_compiled_message = (
"Skipping all Spark Connect Python tests as the optional Spark Connect project was "
"not compiled into a JAR. To run these tests, you need to build Spark with "
"'build/sbt package' or 'build/mvn package' before running this test."
)
-else:
- existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
- jars_args = "--jars %s" % connect_jar
- plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
- os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args])
- connect_requirement_message = None # type: ignore
-should_test_connect = connect_requirement_message is None and have_pandas
+
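+# Connect tests need pandas, pyarrow, grpcio and a compiled Connect server JAR;
+# the first missing requirement, if any, becomes the unittest skip message.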
+connect_requirement_message = (
+ pandas_requirement_message
+ or pyarrow_requirement_message
+ or grpc_requirement_message
+ or connect_not_compiled_message
+)
+should_test_connect: str = typing.cast(str, connect_requirement_message is None)
class MockRemoteSession:
- def __init__(self) -> None:
- self.hooks: Dict[str, Any] = {}
+ def __init__(self):
+ self.hooks = {}
- def set_hook(self, name: str, hook: Any) -> None:
+ def set_hook(self, name, hook):
self.hooks[name] = hook
- def drop_hook(self, name: str) -> None:
+ def drop_hook(self, name):
self.hooks.pop(name)
- def __getattr__(self, item: str) -> Any:
+ def __getattr__(self, item):
if item not in self.hooks:
raise LookupError(f"{item} is not defined as a method hook in MockRemoteSession")
return functools.partial(self.hooks[item])
@@ -67,43 +81,36 @@ def __getattr__(self, item: str) -> Any:
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyTestFixture(unittest.TestCase):
-
- connect: "MockRemoteSession"
- if have_pandas:
- session: SparkSession
-
@classmethod
- def _read_table(cls, table_name: str) -> "DataFrame":
- return DataFrame.withPlan(Read(table_name), cls.connect) # type: ignore
+ def _read_table(cls, table_name):
+ return DataFrame.withPlan(Read(table_name), cls.connect)
@classmethod
- def _udf_mock(cls, *args, **kwargs) -> str:
+ def _udf_mock(cls, *args, **kwargs):
return "internal_name"
@classmethod
def _session_range(
cls,
- start: int,
- end: int,
- step: int = 1,
- num_partitions: Optional[int] = None,
- ) -> "DataFrame":
- return DataFrame.withPlan(
- Range(start, end, step, num_partitions), cls.connect # type: ignore
- )
+ start,
+ end,
+ step=1,
+ num_partitions=None,
+ ):
+ return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect)
@classmethod
- def _session_sql(cls, query: str) -> "DataFrame":
- return DataFrame.withPlan(SQL(query), cls.connect) # type: ignore
+ def _session_sql(cls, query):
+ return DataFrame.withPlan(SQL(query), cls.connect)
if have_pandas:
@classmethod
- def _with_plan(cls, plan: LogicalPlan) -> "DataFrame":
- return DataFrame.withPlan(plan, cls.connect) # type: ignore
+ def _with_plan(cls, plan):
+ return DataFrame.withPlan(plan, cls.connect)
@classmethod
- def setUpClass(cls: Any) -> None:
+ def setUpClass(cls):
cls.connect = MockRemoteSession()
cls.session = SparkSession.builder.remote().getOrCreate()
cls.tbl_name = "test_connect_plan_only_table_1"
@@ -115,7 +122,7 @@ def setUpClass(cls: Any) -> None:
cls.connect.set_hook("with_plan", cls._with_plan)
@classmethod
- def tearDownClass(cls: Any) -> None:
+ def tearDownClass(cls):
cls.connect.drop_hook("register_udf")
cls.connect.drop_hook("readTable")
cls.connect.drop_hook("range")
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
index ad2f74e8af..6a828f1002 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -22,11 +22,6 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
-import pandas as pd
-from pandas.api.types import is_list_like # type: ignore[attr-defined]
-from pandas.core.dtypes.common import is_numeric_dtype
-from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
-
from pyspark import pandas as ps
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.indexes import Index
@@ -36,7 +31,7 @@
tabulate_requirement_message = None
try:
- from tabulate import tabulate # noqa: F401
+ from tabulate import tabulate
except ImportError as e:
# If tabulate requirement is not satisfied, skip related tests.
tabulate_requirement_message = str(e)
@@ -44,7 +39,7 @@
matplotlib_requirement_message = None
try:
- import matplotlib # noqa: F401
+ import matplotlib
except ImportError as e:
# If matplotlib requirement is not satisfied, skip related tests.
matplotlib_requirement_message = str(e)
@@ -52,7 +47,7 @@
plotly_requirement_message = None
try:
- import plotly # noqa: F401
+ import plotly
except ImportError as e:
# If plotly requirement is not satisfied, skip related tests.
plotly_requirement_message = str(e)
@@ -72,6 +67,10 @@ def convert_str_to_lambda(self, func):
return lambda x: getattr(x, func)()
def assertPandasEqual(self, left, right, check_exact=True):
+ import pandas as pd
+ from pandas.core.dtypes.common import is_numeric_dtype
+ from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
+
if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
try:
if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
@@ -157,6 +156,8 @@ def assertPandasAlmostEqual(self, left, right):
- Compare floats rounding to the number of decimal places, 7 after
dropping missing values (NaN, NaT, None)
"""
+ import pandas as pd
+
if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
msg = (
"DataFrames are not almost equal: "
@@ -217,6 +218,9 @@ def assert_eq(self, left, right, check_exact=True, almost=False):
:param almost: if this is enabled, the comparison is delegated to `unittest`'s
`assertAlmostEqual`. See its documentation for more details.
"""
+ import pandas as pd
+ from pandas.api.types import is_list_like
+
lobj = self._to_pandas(left)
robj = self._to_pandas(right)
if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
index ff9f3e4f16..79b6b4fa91 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -298,7 +298,7 @@ def test_user_configuration(self):
from pyspark.tests.test_appsubmit import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py
index bc4587ffa6..90d3caa736 100644
--- a/python/pyspark/tests/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -188,7 +188,7 @@ def random_bytes(n):
from pyspark.tests.test_broadcast import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py
index 6a7c7a05a9..cc9ff82909 100644
--- a/python/pyspark/tests/test_conf.py
+++ b/python/pyspark/tests/test_conf.py
@@ -36,7 +36,7 @@ def test_memory_conf(self):
from pyspark.tests.test_conf import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 1b63869562..d819656f3b 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -97,7 +97,7 @@ def test_add_py_file(self):
# this job fails due to `userlibrary` not being on the Python path:
# disable logging in log4j temporarily
def func(x):
- from userlibrary import UserClass # type: ignore
+ from userlibrary import UserClass
return UserClass().hello()
@@ -145,7 +145,7 @@ def test_add_egg_file_locally(self):
# To ensure that we're actually testing addPyFile's effects, check that
# this fails due to `userlibrary` not being on the Python path:
def func():
- from userlib import UserClass # type: ignore[import]
+ from userlib import UserClass
UserClass()
@@ -159,7 +159,7 @@ def func():
def test_overwrite_system_module(self):
self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py"))
- import SimpleHTTPServer # type: ignore[import]
+ import SimpleHTTPServer
self.assertEqual("My Server", SimpleHTTPServer.__name__)
@@ -338,7 +338,7 @@ def tearDown(self):
from pyspark.tests.test_context import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py
index d4cb90c4e8..22196e5369 100644
--- a/python/pyspark/tests/test_daemon.py
+++ b/python/pyspark/tests/test_daemon.py
@@ -81,7 +81,7 @@ def test_termination_sigterm(self):
from pyspark.tests.test_daemon import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_install_spark.py b/python/pyspark/tests/test_install_spark.py
index cd1c424a85..6f39a09ae1 100644
--- a/python/pyspark/tests/test_install_spark.py
+++ b/python/pyspark/tests/test_install_spark.py
@@ -142,7 +142,7 @@ def test_checked_versions(self):
from pyspark.tests.test_install_spark import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py
index ce4c6e5dfe..de1c260696 100644
--- a/python/pyspark/tests/test_join.py
+++ b/python/pyspark/tests/test_join.py
@@ -61,7 +61,7 @@ def test_narrow_dependency_in_join(self):
from pyspark.tests.test_join import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
index cdb75e5b6a..7bd7debe6e 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -33,7 +33,7 @@
@unittest.skipIf(not has_memory_profiler, "Must have memory-profiler installed.")
-@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
+@unittest.skipIf(not have_pandas, pandas_requirement_message)
class MemoryProfilerTests(PySparkTestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
@@ -156,7 +156,7 @@ def map(pdfs: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
from pyspark.tests.test_memory_profiler import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py
index 2874e09853..dd291b8a0c 100644
--- a/python/pyspark/tests/test_pin_thread.py
+++ b/python/pyspark/tests/test_pin_thread.py
@@ -171,7 +171,7 @@ def get_outer_local_prop():
from pyspark.tests.test_pin_thread import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py
index 8a078d36b4..1db33b59b8 100644
--- a/python/pyspark/tests/test_profiler.py
+++ b/python/pyspark/tests/test_profiler.py
@@ -155,7 +155,7 @@ def plus_one(v):
from pyspark.tests.test_profiler import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 23e41d6c03..752b5d5599 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -931,7 +931,7 @@ def run_job(job_group, index):
from pyspark.tests.test_rdd import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py
index 18d618e3e1..dd3d2d6b36 100644
--- a/python/pyspark/tests/test_rddbarrier.py
+++ b/python/pyspark/tests/test_rddbarrier.py
@@ -44,7 +44,7 @@ def f(index, iterator):
from pyspark.tests.test_rddbarrier import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_rddsampler.py b/python/pyspark/tests/test_rddsampler.py
index b504c4ab98..b98f2668cd 100644
--- a/python/pyspark/tests/test_rddsampler.py
+++ b/python/pyspark/tests/test_rddsampler.py
@@ -58,7 +58,7 @@ def test_rdd_stratified_sampler_func(self):
from pyspark.tests.test_rddsampler import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py
index d7086c4bce..73f1025635 100644
--- a/python/pyspark/tests/test_readwrite.py
+++ b/python/pyspark/tests/test_readwrite.py
@@ -360,7 +360,7 @@ def test_malformed_RDD(self):
from pyspark.tests.test_readwrite import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py
index 0a89861a26..230723e105 100644
--- a/python/pyspark/tests/test_serializers.py
+++ b/python/pyspark/tests/test_serializers.py
@@ -108,7 +108,7 @@ def __getattr__(self, item):
def test_pickling_file_handles(self):
# to be corrected with SPARK-11160
try:
- import xmlrunner # type: ignore[import] # noqa: F401
+ import xmlrunner # noqa: F401
except ImportError:
ser = CloudPickleSerializer()
out1 = sys.stderr
diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py
index fb11a84f8a..4fb73607a2 100644
--- a/python/pyspark/tests/test_shuffle.py
+++ b/python/pyspark/tests/test_shuffle.py
@@ -259,7 +259,7 @@ def test_external_sort_in_rdd(self):
from pyspark.tests.test_shuffle import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_statcounter.py b/python/pyspark/tests/test_statcounter.py
index b10fe7cd91..747f42e67b 100644
--- a/python/pyspark/tests/test_statcounter.py
+++ b/python/pyspark/tests/test_statcounter.py
@@ -122,7 +122,7 @@ def test_merge_stats_with_self(self):
from pyspark.tests.test_statcounter import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index b90a788ae2..5d410aa57e 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -342,7 +342,7 @@ def tearDown(self):
from pyspark.tests.test_taskcontext import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index 0ba9a5852e..77f06721b1 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -89,7 +89,7 @@ def test_find_spark_home(self):
from pyspark.tests.test_util import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 06ada8f81d..703690bf7f 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -263,7 +263,7 @@ def conf(cls):
from pyspark.tests.test_worker import * # noqa: F401
try:
- import xmlrunner # type: ignore[import]
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
diff --git a/python/setup.py b/python/setup.py
index af102f2308..65db3912ef 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -282,6 +282,7 @@ def run(self):
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
+ 'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Typing :: Typed'],
diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile
index e4d62cf45f..3a5b96dc12 100644
--- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile
+++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile
@@ -25,7 +25,7 @@ USER 0
RUN mkdir ${SPARK_HOME}/R
-# Install R 4.0.4 (http://cloud.r-project.org/bin/linux/debian/)
+# Install R 4.1.2 (http://cloud.r-project.org/bin/linux/debian/)
RUN \
apt-get update && \
apt install -y r-base r-base-dev && \
diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md
index 17c1d117f5..82e4de9bad 100644
--- a/resource-managers/kubernetes/integration-tests/README.md
+++ b/resource-managers/kubernetes/integration-tests/README.md
@@ -27,7 +27,7 @@ To run tests with Hadoop 2.x instead of Hadoop 3.x, use `--hadoop-profile`.
./dev/dev-run-integration-tests.sh --hadoop-profile hadoop-2
-The minimum tested version of Minikube is 1.18.0. The kube-dns addon must be enabled. Minikube should
+The minimum tested version of Minikube is 1.28.0. The kube-dns addon must be enabled. Minikube should
run with a minimum of 4 CPUs and 6G of memory:
minikube start --cpus 4 --memory 6144
@@ -46,7 +46,7 @@ default this is set to `minikube`, the available backends are their prerequisite
### `minikube`
-Uses the local `minikube` cluster, this requires that `minikube` 1.18.0 or greater be installed and that it be allocated
+Uses the local `minikube` cluster, this requires that `minikube` 1.28.0 or greater be installed and that it be allocated
at least 4 CPUs and 6GB memory (some users have reported success with as few as 3 CPUs and 4GB memory). The tests will
check if `minikube` is started and abort early if it isn't currently running.
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala
index 755feb9aca..70a849c37e 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala
@@ -48,9 +48,9 @@ private[spark] object Minikube extends Logging {
versionArrayOpt match {
case Some(Array(x, y, z)) =>
- if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 18, 0))) {
+ if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 28, 0))) {
assert(false, s"Unsupported Minikube version is detected: $minikubeVersionString." +
- "For integration testing Minikube version 1.18.0 or greater is expected.")
+ "For integration testing Minikube version 1.28.0 or greater is expected.")
}
case _ =>
assert(false, s"Unexpected version format detected in `$minikubeVersionString`." +
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 69dd72720a..9815fa6df8 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -822,6 +822,7 @@ private[spark] class ApplicationMaster(
case Shutdown(code) =>
exitCode = code
shutdown = true
+ allocator.setShutdown(true)
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index ee1d10c204..4980d7e184 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -199,6 +199,8 @@ private[yarn] class YarnAllocator(
}
}
+ @volatile private var shutdown = false
+
// The default profile is always present so we need to initialize the datastructures keyed by
// ResourceProfile id to ensure its present if things start running before a request for
// executors could add it. This approach is easier then going and special casing everywhere.
@@ -215,6 +217,8 @@ private[yarn] class YarnAllocator(
initDefaultProfile()
+ def setShutdown(shutdown: Boolean): Unit = this.shutdown = shutdown
+
def getNumExecutorsRunning: Int = synchronized {
runningExecutorsPerResourceProfileId.values.map(_.size).sum
}
@@ -835,6 +839,8 @@ private[yarn] class YarnAllocator(
// now I think its ok as none of the containers are expected to exit.
val exitStatus = completedContainer.getExitStatus
val (exitCausedByApp, containerExitReason) = exitStatus match {
+ case _ if shutdown =>
+ (false, s"Executor for container $containerId exited after Application shutdown.")
case ContainerExitStatus.SUCCESS =>
(false, s"Executor for container $containerId exited because of a YARN event (e.g., " +
"preemption) and not because of an error in the running job.")
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6e6d840604..717c620f5c 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -162,7 +162,7 @@ private[spark] class YarnClientSchedulerBackend(
*/
override def stop(exitCode: Int): Unit = {
assert(client != null, "Attempted to stop this scheduler before starting it!")
- yarnSchedulerEndpoint.handleClientModeDriverStop(exitCode)
+ yarnSchedulerEndpoint.signalDriverStop(exitCode)
if (monitorThread != null) {
monitorThread.stopMonitor()
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
index e70a78d3c4..3728c33228 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala
@@ -35,6 +35,11 @@ private[spark] class YarnClusterSchedulerBackend(
startBindings()
}
+ override def stop(exitCode: Int): Unit = {
+ yarnSchedulerEndpoint.signalDriverStop(exitCode)
+ super.stop()
+ }
+
override def getDriverLogUrls: Option[Map[String, String]] = {
YarnContainerInfoHelper.getLogUrls(sc.hadoopConfiguration, container = None)
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 572c16d9e9..34848a7f3d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -319,7 +319,7 @@ private[spark] abstract class YarnSchedulerBackend(
removeExecutorMessage.foreach { message => driverEndpoint.send(message) }
}
- private[cluster] def handleClientModeDriverStop(exitCode: Int): Unit = {
+ private[cluster] def signalDriverStop(exitCode: Int): Unit = {
amEndpoint match {
case Some(am) =>
am.send(Shutdown(exitCode))
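Read together with the two scheduler-backend hunks above, the stop path is: `stop(exitCode)` calls `signalDriverStop(exitCode)`, which forwards `Shutdown(exitCode)` to the AM endpoint, and the AM then flips the allocator's shutdown flag. A hedged, self-contained sketch of that flow with toy stand-ins for the RPC endpoint and backend (all names here are hypothetical):

object StopPathSketch {
  sealed trait AmMessage
  final case class Shutdown(exitCode: Int) extends AmMessage

  // Stand-in for the AM-side RPC endpoint: on Shutdown it would call
  // allocator.setShutdown(true), as the ApplicationMaster hunk above does.
  class ToyAmEndpoint(handler: AmMessage => Unit) {
    def send(message: AmMessage): Unit = handler(message)
  }

  class ToySchedulerBackend(amEndpoint: Option[ToyAmEndpoint]) {
    private def signalDriverStop(exitCode: Int): Unit = amEndpoint match {
      case Some(am) => am.send(Shutdown(exitCode))
      case None     => println("No AM registered; nothing to signal.")
    }

    // Both client and cluster deploy modes now go through the same helper
    // before tearing the backend down.
    def stop(exitCode: Int): Unit = {
      signalDriverStop(exitCode)
      println("Backend stopped.")
    }
  }

  def main(args: Array[String]): Unit = {
    val am = new ToyAmEndpoint(msg => msg match {
      case Shutdown(code) => println(s"AM got Shutdown($code); allocator.setShutdown(true)")
    })
    new ToySchedulerBackend(Some(am)).stop(exitCode = 0)
  }
}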
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 5a80aa9c61..a5ca382fb4 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -693,6 +693,28 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers {
.updateBlacklist(hosts.slice(10, 11).asJava, Collections.emptyList())
}
+ test("SPARK-39601 YarnAllocator should not count executor failure after shutdown") {
+ val (handler, _) = createAllocator()
+ handler.updateResourceRequests()
+ handler.getNumExecutorsFailed should be(0)
+
+ val failedBeforeShutdown = createContainer("host1")
+ val failedAfterShutdown = createContainer("host2")
+ handler.handleAllocatedContainers(Seq(failedBeforeShutdown, failedAfterShutdown))
+
+ val failedBeforeShutdownStatus = ContainerStatus.newInstance(
+ failedBeforeShutdown.getId, ContainerState.COMPLETE, "Failed", -1)
+ val failedAfterShutdownStatus = ContainerStatus.newInstance(
+ failedAfterShutdown.getId, ContainerState.COMPLETE, "Failed", -1)
+
+ handler.processCompletedContainers(Seq(failedBeforeShutdownStatus))
+ handler.getNumExecutorsFailed should be(1)
+
+ handler.setShutdown(true)
+ handler.processCompletedContainers(Seq(failedAfterShutdownStatus))
+ handler.getNumExecutorsFailed should be(1)
+ }
+
test("SPARK-28577#YarnAllocator.resource.memory should include offHeapSize " +
"when offHeapEnabled is true.") {
val originalOffHeapEnabled = sparkConf.get(MEMORY_OFFHEAP_ENABLED)
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 4e2da27569..f34b5d55e4 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -135,9 +135,9 @@ This file is divided into 3 sections:
-  <check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
+  <check customId="anyfunsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
-    <parameter name="regex">^FunSuite[A-Za-z]*$</parameter>
+    <parameter name="regex">^AnyFunSuite[A-Za-z]*$</parameter>
     <customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index 41adbda7b1..38f52901aa 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -18,7 +18,7 @@ lexer grammar SqlBaseLexer;
@members {
/**
- * When true, parser should throw ParseExcetion for unclosed bracketed comment.
+ * When true, parser should throw ParseException for unclosed bracketed comment.
*/
public boolean has_unclosed_bracketed_comment = false;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java
index 270b750259..5afc869d68 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java
@@ -17,8 +17,8 @@
package org.apache.spark.sql.connector.read;
+import java.util.HashMap;
import java.util.Map;
-import java.util.Optional;
import java.util.OptionalLong;
import org.apache.spark.annotation.Evolving;
@@ -35,7 +35,7 @@
public interface Statistics {
OptionalLong sizeInBytes();
OptionalLong numRows();
- default Optional