[KYUUBI #4018] Execute python code supports asynchronous and query timeout

### _Why are the changes needed?_

Previously, `ExecutePython.runInternal` ran the statement synchronously on the caller's thread and ignored the client's query timeout. This change extracts execution into a new `executePython` method, submits it to the session manager's background pool when `shouldRunAsync` is set, and registers a timeout monitor so that `Statement.setQueryTimeout` takes effect. Interpreter failures (any non-`ok` status) now surface as a `KyuubiSQLException`, and the JDBC `KyuubiStatement` gains an `executePython` helper.
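A minimal client-side sketch of the new behavior, mirroring the test added below; `jdbcUrl` is a placeholder, everything else is introduced by this change:

```scala
import java.util.Properties

import org.apache.kyuubi.jdbc.KyuubiHiveDriver
import org.apache.kyuubi.jdbc.hive.KyuubiStatement

// Connect through the Kyuubi JDBC driver; jdbcUrl is a placeholder.
val connection = new KyuubiHiveDriver().connect(jdbcUrl, new Properties())
val statement = connection.createStatement().asInstanceOf[KyuubiStatement]
// A script running past 5 seconds now fails with SQLTimeoutException
// instead of blocking the statement indefinitely.
statement.setQueryTimeout(5)
val rs = statement.executePython("print('hello')")
// The first column of the result row carries the interpreter output.
while (rs.next()) println(rs.getString(1))
```
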
### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4018 from turboFei/python_async.

2afe397 [fwang12] move ut
0d9d2f1 [fwang12] only OK
46d14f4 [fwang12] add ut
0e3a039 [fwang12] add ut
0f2eba5 [fwang12] add ut
e2718ab [fwang12] async python

Authored-by: fwang12 <fwang12@ebay.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
turboFei authored and pan3793 committed Dec 21, 2022
1 parent 9e7a8f2 commit 3a0f08e
Showing 4 changed files with 94 additions and 11 deletions.
ExecutePython.scala

@@ -21,6 +21,7 @@ import java.io.{BufferedReader, File, FilenameFilter, FileOutputStream, InputStr
 import java.lang.ProcessBuilder.Redirect
 import java.net.URI
 import java.nio.file.{Files, Path, Paths}
+import java.util.concurrent.RejectedExecutionException
 import java.util.concurrent.atomic.AtomicBoolean
 import javax.ws.rs.core.UriBuilder

@@ -34,17 +35,19 @@ import org.apache.spark.api.python.KyuubiPythonGatewayServer
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType
 
-import org.apache.kyuubi.{Logging, Utils}
+import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils}
 import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE}
 import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
-import org.apache.kyuubi.operation.ArrayFetchIterator
+import org.apache.kyuubi.operation.{ArrayFetchIterator, OperationState}
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
 
 class ExecutePython(
     session: Session,
     override val statement: String,
+    override val shouldRunAsync: Boolean,
+    queryTimeout: Long,
     worker: SessionPythonWorker) extends SparkOperation(session) {
 
   private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle)
@@ -64,22 +67,62 @@ class ExecutePython(
 
   override protected def beforeRun(): Unit = {
     OperationLog.setCurrentOperationLog(operationLog)
-    super.beforeRun()
+    setState(OperationState.PENDING)
+    setHasResultSet(true)
   }
 
-  override protected def runInternal(): Unit = withLocalProperties {
+  override protected def afterRun(): Unit = {
+    OperationLog.removeCurrentOperationLog()
+  }
+
+  private def executePython(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
       info(diagnostics)
       val response = worker.runCode(statement)
-      val output = response.map(_.content.getOutput()).getOrElse("")
       val status = response.map(_.content.status).getOrElse("UNKNOWN_STATUS")
-      val ename = response.map(_.content.getEname()).getOrElse("")
-      val evalue = response.map(_.content.getEvalue()).getOrElse("")
-      val traceback = response.map(_.content.getTraceback()).getOrElse(Array.empty)
-      iter =
-        new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, Row(traceback: _*))))
+      if (PythonResponse.OK_STATUS.equalsIgnoreCase(status)) {
+        val output = response.map(_.content.getOutput()).getOrElse("")
+        val ename = response.map(_.content.getEname()).getOrElse("")
+        val evalue = response.map(_.content.getEvalue()).getOrElse("")
+        val traceback = response.map(_.content.getTraceback()).getOrElse(Array.empty)
+        iter =
+          new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, Row(traceback: _*))))
+        setState(OperationState.FINISHED)
+      } else {
+        throw KyuubiSQLException(s"Interpret error:\n$statement\n $response")
+      }
     } catch {
       onError(cancel = true)
+    } finally {
+      shutdownTimeoutMonitor()
     }
   }
 
+  override protected def runInternal(): Unit = withLocalProperties {
+    addTimeoutMonitor(queryTimeout)
+    if (shouldRunAsync) {
+      val asyncOperation = new Runnable {
+        override def run(): Unit = {
+          OperationLog.setCurrentOperationLog(operationLog)
+          executePython()
+        }
+      }
+
+      try {
+        val sparkSQLSessionManager = session.sessionManager
+        val backgroundHandle = sparkSQLSessionManager.submitBackgroundOperation(asyncOperation)
+        setBackgroundHandle(backgroundHandle)
+      } catch {
+        case rejected: RejectedExecutionException =>
+          setState(OperationState.ERROR)
+          val ke =
+            KyuubiSQLException("Error submitting python in background", rejected)
+          setOperationException(ke)
+          throw ke
+      }
+    } else {
+      executePython()
+    }
+  }
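
For context: `addTimeoutMonitor` and `shutdownTimeoutMonitor` are inherited from the operation base class and are not part of this diff. A rough sketch of the assumed pattern (not the actual Kyuubi implementation): schedule a one-shot task that forces a timeout, and cancel it once execution finishes.

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

// Assumed shape of the inherited timeout-monitor helpers.
abstract class TimeoutMonitorSketch {
  @volatile private var scheduler: ScheduledExecutorService = null

  // Schedule a one-shot task that fires once queryTimeout seconds elapse.
  def addTimeoutMonitor(queryTimeout: Long): Unit = {
    if (queryTimeout > 0) {
      scheduler = Executors.newSingleThreadScheduledExecutor()
      val task: Runnable = () => onTimeout()
      scheduler.schedule(task, queryTimeout, TimeUnit.SECONDS)
    }
  }

  // Called from the finally block in executePython above, so a statement
  // that finishes in time is never marked as timed out afterwards.
  def shutdownTimeoutMonitor(): Unit = {
    Option(scheduler).foreach(_.shutdownNow())
  }

  // In the real code: move the operation to TIMEOUT state and cancel it.
  protected def onTimeout(): Unit
}
```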

@@ -321,6 +364,10 @@ case class PythonResponse(
     msg_type: String,
     content: PythonResponseContent)
 
+object PythonResponse {
+  final val OK_STATUS = "ok"
+}
+
 case class PythonResponseContent(
     data: Map[String, String],
     ename: String,
SparkSQLOperationManager.scala

@@ -95,7 +95,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
         val worker = sessionToPythonProcess.getOrElseUpdate(
           session.handle,
           ExecutePython.createSessionPythonWorker(spark, session))
-        new ExecutePython(session, statement, worker)
+        new ExecutePython(session, statement, runAsync, queryTimeout, worker)
       } catch {
         case e: Throwable =>
           spark.conf.set(OPERATION_LANGUAGE.key, OperationLanguages.SQL.toString)
PySparkTests.scala

@@ -19,10 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.io.PrintWriter
 import java.nio.file.Files
+import java.sql.SQLTimeoutException
+import java.util.Properties
 
 import scala.sys.process._
 
 import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+import org.apache.kyuubi.jdbc.KyuubiHiveDriver
+import org.apache.kyuubi.jdbc.hive.{KyuubiSQLException, KyuubiStatement}
 import org.apache.kyuubi.operation.HiveJDBCTestHelper
 import org.apache.kyuubi.tags.PySparkTest
 
@@ -62,6 +66,30 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper
     runPySparkTest(code, output)
   }
 
+  test("executePython support timeout") {
+    val driver = new KyuubiHiveDriver()
+    val connection = driver.connect(getJdbcUrl, new Properties())
+    val statement = connection.createStatement().asInstanceOf[KyuubiStatement]
+    statement.setQueryTimeout(5)
+    try {
+      var code =
+        """
+          |import time
+          |time.sleep(10)
+          |""".stripMargin
+      var e = intercept[SQLTimeoutException] {
+        statement.executePython(code)
+      }.getMessage
+      assert(e.contains("Query timed out"))
+      code = "bad_code"
+      e = intercept[KyuubiSQLException](statement.executePython(code)).getMessage
+      assert(e.contains("Interpret error"))
+    } finally {
+      statement.close()
+      connection.close()
+    }
+  }
+
   private def runPySparkTest(
       pyCode: String,
       output: String): Unit = {
KyuubiStatement.java

@@ -460,6 +460,14 @@ public ResultSet executeScala(String code) throws SQLException {
     return resultSet;
   }
 
+  public ResultSet executePython(String code) throws SQLException {
+    if (!executeWithConfOverlay(
+        code, Collections.singletonMap("kyuubi.operation.language", "PYTHON"))) {
+      throw new KyuubiSQLException("The query did not generate a result set!");
+    }
+    return resultSet;
+  }
+
   public void executeSetCurrentCatalog(String sql, String catalog) throws SQLException {
     if (executeWithConfOverlay(
         sql, Collections.singletonMap("kyuubi.operation.set.current.catalog", catalog))) {
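
Design note: `executePython` deliberately mirrors `executeScala` directly above it. Both delegate to `executeWithConfOverlay`, differing only in the `kyuubi.operation.language` value, which the server uses (see `SparkSQLOperationManager` above) to route the statement to `ExecutePython`.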
