diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08ebf8b10fefc..64eeca64d827c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -18,19 +18,24 @@
 package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.net.{InetAddress, Socket}
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.SparkException
+import scala.io.Source
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
+import org.apache.spark.sql.execution.{LogicalRDD, QueryExecution, RDDScanExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1586,6 +1591,30 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-34726: Fix collectToPython timeouts") {
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        // Longer than 15s in `PythonServer.setupOneConnectionServer`
+        Thread.sleep(20 * 1000)
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
+
+    // Mimic Python side
+    val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
+    val authHelper = new SocketAuthHelper(new SparkConf()) {
+      override val secret: String = secretToPython
+    }
+    authHelper.authToServer(socket)
+    Source.fromInputStream(socket.getInputStream)
+
+    spark.listenerManager.unregister(listener)
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)
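
Note on the timing the test relies on: the 15-second figure in the test comment refers to the accept timeout of the one-connection server behind `collectToPython()`. What follows is a minimal illustrative sketch of that pattern only, not Spark's actual `PythonServer.setupOneConnectionServer`; the names `OneConnectionServerSketch` and `serveOnce` are invented for this example. It shows why a `QueryExecutionListener` that blocks for 20 seconds could starve the client past the accept timeout before the fix.

// Sketch only: a server that serves exactly one local connection and gives up
// if no client connects within 15 seconds. All names here are hypothetical.
import java.net.{InetAddress, ServerSocket, Socket}

object OneConnectionServerSketch {
  // Binds to an ephemeral localhost port, serves one client on a daemon
  // thread, and returns the port the client should connect to.
  def serveOnce(handle: Socket => Unit): Int = {
    val server = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
    // accept() throws SocketTimeoutException if no client shows up in 15s --
    // the window a slow listener could exceed before the fix.
    server.setSoTimeout(15000)
    new Thread("serve-one-connection") {
      setDaemon(true)
      override def run(): Unit = {
        try {
          val sock = server.accept()
          try handle(sock) finally sock.close()
        } finally {
          server.close()
        }
      }
    }.start()
    server.getLocalPort
  }
}

Under this reading, the test needs no assertion: if `collectToPython()` still blocked on the slow listener before serving the connection, `authToServer` on the mimicked Python side would fail once the 15s accept window elapsed, failing the test with an exception.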