Skip to content

Commit

Permalink
[KYUUBI #1821] Add trino ExecuteStatement
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
Add trino ExecuteStatement

### _How was this patch tested?_
- [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [X] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #1830 from hddong/add-operation.

Closes #1821

067bde7 [hongdongdong] use flag instead breakable
f4d6cbb [hongdongdong] fix
351e2bc [hongdongdong] move context to impl
69d7d9b [hongdongdong] fix wrong func name
9cb757a [hongdongdong] fix
a20f2d0 [hongdongdong] fix time unit
c5072db [hongdongdong] [KYUUBI #1821] Add trino ExecuteStatement

Authored-by: hongdongdong <hongdongdong@cmss.chinamobile.com>
Signed-off-by: hongdongdong <hongdongdong@cmss.chinamobile.com>
  • Loading branch information
hddong committed Feb 10, 2022
1 parent 33eda21 commit b952b7b
Show file tree
Hide file tree
Showing 15 changed files with 926 additions and 63 deletions.
7 changes: 7 additions & 0 deletions externals/kyuubi-trino-engine/pom.xml
Expand Up @@ -82,6 +82,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc-shaded</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand Down
Expand Up @@ -17,16 +17,13 @@

package org.apache.kyuubi.engine.trino

class TrinoContextSuite extends WithTrinoContainerServer {
import org.apache.kyuubi.engine.trino.session.TrinoSessionManager
import org.apache.kyuubi.service.AbstractBackendService
import org.apache.kyuubi.session.SessionManager

test("set current schema") {
withTrinoContainer { trinoContext =>
val trinoStatement = TrinoStatement(trinoContext, kyuubiConf, "select 1")
assert("tiny" === trinoStatement.getCurrentDatabase)
class TrinoBackendService
extends AbstractBackendService("TrinoBackendService") {

override val sessionManager: SessionManager = new TrinoSessionManager()

trinoContext.setCurrentSchema("sf1")
val trinoStatement2 = TrinoStatement(trinoContext, kyuubiConf, "select 1")
assert("sf1" === trinoStatement2.getCurrentDatabase)
}
}
}
Expand Up @@ -17,6 +17,8 @@

package org.apache.kyuubi.engine.trino

import java.time.Duration

import org.apache.kyuubi.config.ConfigBuilder
import org.apache.kyuubi.config.ConfigEntry
import org.apache.kyuubi.config.KyuubiConf
Expand All @@ -30,4 +32,11 @@ object TrinoConf {
.version("1.5.0")
.intConf
.createWithDefault(3)

val CLIENT_REQUEST_TIMEOUT: ConfigEntry[Long] =
buildConf("trino.client.request.timeout")
.doc("Timeout for Trino client request to trino cluster")
.version("1.5.0")
.timeConf
.createWithDefault(Duration.ofMinutes(2).toMillis)
}
Expand Up @@ -22,19 +22,11 @@ import java.util.concurrent.atomic.AtomicReference
import io.trino.client.ClientSession
import okhttp3.OkHttpClient

class TrinoContext(
val httpClient: OkHttpClient,
val clientSession: AtomicReference[ClientSession]) {

def getClientSession: ClientSession = clientSession.get

def setCurrentSchema(schema: String): Unit = {
clientSession.set(ClientSession.builder(clientSession.get).withSchema(schema).build())
}

}
case class TrinoContext(
httpClient: OkHttpClient,
clientSession: AtomicReference[ClientSession])

object TrinoContext {
def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext =
new TrinoContext(httpClient, new AtomicReference(clientSession))
TrinoContext(httpClient, new AtomicReference(clientSession))
}
Expand Up @@ -17,19 +17,74 @@

package org.apache.kyuubi.engine.trino

import java.util.concurrent.CountDownLatch

import org.apache.kyuubi.Logging
import org.apache.kyuubi.Utils.TRINO_ENGINE_SHUTDOWN_PRIORITY
import org.apache.kyuubi.Utils.addShutdownHook
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.trino.TrinoSqlEngine.countDownLatch
import org.apache.kyuubi.engine.trino.TrinoSqlEngine.currentEngine
import org.apache.kyuubi.ha.HighAvailabilityConf.HA_ZK_CONN_RETRY_POLICY
import org.apache.kyuubi.ha.client.RetryPolicies
import org.apache.kyuubi.service.Serverable
import org.apache.kyuubi.util.SignalRegister

case class TrinoSqlEngine()
extends Serverable("TrinoSQLEngine") {

override val backendService = new TrinoBackendService()

override val frontendServices = Seq(new TrinoTBinaryFrontendService(this))

override def start(): Unit = {
super.start()
// Start engine self-terminating checker after all services are ready and it can be reached by
// all servers in engine spaces.
backendService.sessionManager.startTerminatingChecker(() => {
assert(currentEngine.isDefined)
currentEngine.get.stop()
})
}

override protected def stopServer(): Unit = {
countDownLatch.countDown()
}
}

object TrinoSqlEngine extends Logging {
private val countDownLatch = new CountDownLatch(1)

val kyuubiConf: KyuubiConf = KyuubiConf()

var currentEngine: Option[TrinoSqlEngine] = None

def startEngine(): Unit = {
currentEngine = Some(new TrinoSqlEngine())
currentEngine.foreach { engine =>
engine.initialize(kyuubiConf)
engine.start()
addShutdownHook(() => engine.stop(), TRINO_ENGINE_SHUTDOWN_PRIORITY + 1)
}
}

def main(args: Array[String]): Unit = {
SignalRegister.registerLogger(logger)

// TODO start engine
warn("Trino engine under development...")
info(kyuubiConf.getAll)
try {
kyuubiConf.setIfMissing(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
kyuubiConf.setIfMissing(HA_ZK_CONN_RETRY_POLICY, RetryPolicies.N_TIME.toString)

startEngine()
// blocking main thread
countDownLatch.await()
} catch {
case t: Throwable if currentEngine.isDefined =>
currentEngine.foreach { engine =>
error(t)
engine.stop()
}
case t: Throwable => error("Create Trino Engine Failed", t)
}
}
}
Expand Up @@ -27,7 +27,6 @@ import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration
import scala.concurrent.duration.Duration
import scala.util.control.Breaks._

import com.google.common.base.Verify
import io.trino.client.ClientSession
Expand All @@ -46,7 +45,7 @@ import org.apache.kyuubi.engine.trino.TrinoStatement._
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {

private lazy val trino = StatementClientFactory
.newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)
.newStatementClient(trinoContext.httpClient, trinoContext.clientSession.get, sql)

private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)

Expand All @@ -55,7 +54,7 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St

def getTrinoClient: StatementClient = trino

def getCurrentDatabase: String = trinoContext.getClientSession.getSchema
def getCurrentDatabase: String = trinoContext.clientSession.get.getSchema

def getColumns: List[Column] = {
while (trino.isRunning) {
Expand Down Expand Up @@ -99,50 +98,44 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St
val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
var bufferStart = System.nanoTime()
val result = ArrayBuffer[List[Any]]()
try {
breakable {
while (!dataProcessing.isCompleted) {
val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
if (!atEnd) {
// Flush if needed
if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) {
result ++= rowBuffer.asScala
rowBuffer.clear()
bufferStart = System.nanoTime()
}

val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
row match {
case END_TOKEN => break
case null =>
case _ => rowBuffer.add(row)
}
}

var getDataEnd = false
while (!dataProcessing.isCompleted && !getDataEnd) {
val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
if (!atEnd) {
// Flush if needed
if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) {
result ++= rowBuffer.asScala
rowBuffer.clear()
bufferStart = System.nanoTime()
}
}
if (!rowQueue.isEmpty()) {
drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
}
result ++= rowBuffer.asScala

val finalStatus = trino.finalStatusInfo()
if (finalStatus.getError() != null) {
val exception = KyuubiSQLException(
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
throw exception
val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
row match {
case END_TOKEN => getDataEnd = true
case null =>
case _ => rowBuffer.add(row)
}
}
}
if (!rowQueue.isEmpty()) {
drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
}
result ++= rowBuffer.asScala

updateTrinoContext()
} catch {
case e: Exception =>
throw KyuubiSQLException(e)
val finalStatus = trino.finalStatusInfo()
if (finalStatus.getError() != null) {
throw KyuubiSQLException(
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
}
updateTrinoContext()

result
}

def updateTrinoContext(): Unit = {
val session = trinoContext.getClientSession
val session = trinoContext.clientSession.get

var builder = ClientSession.builder(session)
// update catalog and schema
Expand Down
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.trino.operation

import java.util.concurrent.RejectedExecutionException

import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Logging
import org.apache.kyuubi.engine.trino.TrinoStatement
import org.apache.kyuubi.operation.ArrayFetchIterator
import org.apache.kyuubi.operation.IterableFetchIterator
import org.apache.kyuubi.operation.OperationState
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session

class ExecuteStatement(
session: Session,
override val statement: String,
override val shouldRunAsync: Boolean,
incrementalCollect: Boolean)
extends TrinoOperation(OperationType.EXECUTE_STATEMENT, session) with Logging {

private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle)
override def getOperationLog: Option[OperationLog] = Option(operationLog)

override protected def beforeRun(): Unit = {
OperationLog.setCurrentOperationLog(operationLog)
setState(OperationState.PENDING)
setHasResultSet(true)
}

override protected def afterRun(): Unit = {
OperationLog.removeCurrentOperationLog()
}

override protected def runInternal(): Unit = {
val trinoStatement = TrinoStatement(trinoContext, session.sessionManager.getConf, statement)
trino = trinoStatement.getTrinoClient
if (shouldRunAsync) {
val asyncOperation = new Runnable {
override def run(): Unit = {
OperationLog.setCurrentOperationLog(operationLog)
executeStatement(trinoStatement)
}
}

try {
val trinoSessionManager = session.sessionManager
val backgroundHandle = trinoSessionManager.submitBackgroundOperation(asyncOperation)
setBackgroundHandle(backgroundHandle)
} catch {
case rejected: RejectedExecutionException =>
setState(OperationState.ERROR)
val ke =
KyuubiSQLException("Error submitting query in background, query rejected", rejected)
setOperationException(ke)
throw ke
}
} else {
executeStatement(trinoStatement)
}
}

private def executeStatement(trinoStatement: TrinoStatement): Unit = {
setState(OperationState.RUNNING)
try {
schema = trinoStatement.getColumns
val resultSet = trinoStatement.execute()
iter =
if (incrementalCollect) {
info("Execute in incremental collect mode")
new IterableFetchIterator(resultSet)
} else {
info("Execute in full collect mode")
new ArrayFetchIterator(resultSet.toArray)
}
setState(OperationState.FINISHED)
} catch {
onError(cancel = true)
}
}
}

0 comments on commit b952b7b

Please sign in to comment.