Skip to content

Commit

Permalink
[KYUUBI #1641] Add Trino client
Browse files Browse the repository at this point in the history
<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
Add a Trino client to communicate with the Trino cluster.

### _How was this patch tested?_
- [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [X] [Run test](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #1642 from hddong/add-trino-runner.

Closes #1641

67c2f1f [hongdongdong] fix
3e4f778 [hongdongdong] use testcontainers
3cd4d94 [hongdongdong] [KYUUBI #1641] Add Trino client

Authored-by: hongdongdong <hongdongdong@cmss.chinamobile.com>
Signed-off-by: Kent Yao <yao@apache.org>
  • Loading branch information
hddong authored and yaooqinn committed Jan 7, 2022
1 parent ce7916c commit e886524
Show file tree
Hide file tree
Showing 8 changed files with 499 additions and 0 deletions.
19 changes: 19 additions & 0 deletions externals/kyuubi-trino-engine/pom.xml
Expand Up @@ -64,6 +64,25 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>

<!-- https://mvnrepository.com/artifact/org.testcontainers/trino -->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>trino</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-jdbc</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.trino

import org.apache.kyuubi.config.ConfigBuilder
import org.apache.kyuubi.config.ConfigEntry
import org.apache.kyuubi.config.KyuubiConf

/**
 * Configuration entries specific to the Kyuubi Trino engine.
 */
object TrinoConf {

  /**
   * Size of the thread pool the Trino client uses for processing data.
   * Integer entry with a default of 3, declared for version 1.5.0.
   */
  val DATA_PROCESSING_POOL_SIZE: ConfigEntry[Int] = {
    val builder: ConfigBuilder = KyuubiConf.buildConf("trino.client.data.processing.pool.size")
    builder
      .doc("The size of the thread pool used by the trino client to processing data")
      .version("1.5.0")
      .intConf
      .createWithDefault(3)
  }
}
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.trino

import java.util.concurrent.atomic.AtomicReference

import io.trino.client.ClientSession
import okhttp3.OkHttpClient

/**
 * Bundles an HTTP client with a replaceable Trino client session.
 *
 * @param httpClient    the OkHttp client exposed to users of this context
 * @param clientSession reference holding the current session; it is overwritten whenever the
 *                      session is updated
 */
class TrinoContext(
    val httpClient: OkHttpClient,
    val clientSession: AtomicReference[ClientSession]) {

  /** Returns the session currently held by the reference. */
  def getClientSession: ClientSession = clientSession.get

  /**
   * Stores a copy of the current session whose schema is replaced by `schema`.
   *
   * The current session is read first and the rebuilt one is written afterwards, as two
   * separate operations on the reference.
   */
  def setCurrentSchema(schema: String): Unit = {
    val current = clientSession.get
    val withNewSchema = ClientSession.builder(current).withSchema(schema).build()
    clientSession.set(withNewSchema)
  }

}

object TrinoContext {

  /** Builds a [[TrinoContext]], wrapping `clientSession` in a fresh `AtomicReference`. */
  def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext = {
    val sessionRef = new AtomicReference[ClientSession](clientSession)
    new TrinoContext(httpClient, sessionRef)
  }
}
@@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.trino

import java.util.ArrayList
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.Executors

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration
import scala.concurrent.duration.Duration
import scala.util.control.Breaks._

import com.google.common.base.Verify
import io.trino.client.ClientSession
import io.trino.client.Column
import io.trino.client.StatementClient
import io.trino.client.StatementClientFactory

import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.trino.TrinoConf.DATA_PROCESSING_POOL_SIZE
import org.apache.kyuubi.engine.trino.TrinoStatement._

/**
 * Trino client that communicates with a Trino cluster to run one SQL statement.
 *
 * @param trinoContext supplies the HTTP client and the current client session; its session is
 *                     replaced by [[updateTrinoContext]] after a successful [[execute]]
 * @param kyuubiConf   configuration, read for `DATA_PROCESSING_POOL_SIZE`
 * @param sql          the SQL text handed to the Trino statement client
 */
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {

  // Lazy: the statement client is only created on first use.
  private lazy val trino = StatementClientFactory
    .newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)

  private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)

  // NOTE(review): every instance creates its own fixed thread pool and nothing in this class
  // shuts it down - confirm who owns the pool's lifecycle.
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newFixedThreadPool(dataProcessingPoolSize))

  /** Returns the underlying Trino statement client. */
  def getTrinoClient: StatementClient = trino

  /** Returns the schema of the session currently held by the context. */
  def getCurrentDatabase: String = trinoContext.getClientSession.getSchema

  /**
   * Advances the statement until column metadata is available and returns it.
   *
   * @return the columns reported by the current status info
   * @throws KyuubiSQLException if the statement stops running before any columns are reported,
   *                            either because it failed or because it has no columns
   */
  def getColumns: List[Column] = {
    while (trino.isRunning) {
      val results = trino.currentStatusInfo()
      val columns = results.getColumns()
      if (columns != null) {
        return columns.asScala.toList
      }
      trino.advance()
    }
    Verify.verify(trino.isFinished())
    val finalStatus = trino.finalStatusInfo()
    if (finalStatus.getError == null) {
      throw KyuubiSQLException(s"Query has no columns (#${finalStatus.getId})")
    } else {
      throw KyuubiSQLException(
        s"Query failed (#${finalStatus.getId}): ${finalStatus.getError.getMessage}")
    }
  }

  /**
   * Executes the SQL and returns all result rows.
   *
   * A background task pulls data from the statement client into a bounded queue while the
   * calling thread moves rows from that queue, through a small buffer, into the result.
   *
   * @return every row produced by the statement, in arrival order
   * @throws KyuubiSQLException if fetching fails, the final status reports an error, or any
   *                            other exception occurs while collecting rows
   */
  def execute(): Iterable[List[Any]] = {
    val rowQueue = new ArrayBlockingQueue[List[Any]](MAX_QUEUED_ROWS)

    // Producer: pushes each fetched row onto the queue until the client stops running.
    val dataProcessing = Future[Unit] {
      while (trino.isRunning) {
        val data = trino.currentData().getData()
        if (data != null) {
          data.asScala.map(_.asScala.toList)
            .foreach(row => putOrThrow(rowQueue, row))
        }
        trino.advance()
      }
    }
    // Signal the consumer that no more rows will arrive, whether fetching succeeded or failed.
    dataProcessing.onComplete {
      case _ => putOrThrow(rowQueue, END_TOKEN)
    }

    val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
    var bufferStart = System.nanoTime()
    val result = ArrayBuffer[List[Any]]()
    try {
      breakable {
        while (!dataProcessing.isCompleted) {
          val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
          if (!atEnd) {
            // Flush if the buffer is full or has been filling for at least MAX_BUFFER_TIME.
            // The comparison must use the time elapsed since bufferStart, not the raw
            // nanoTime reading, which has an arbitrary origin.
            val buffered = Duration.fromNanos(System.nanoTime() - bufferStart)
            if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
              buffered.compareTo(MAX_BUFFER_TIME) >= 0) {
              result ++= rowBuffer.asScala
              rowBuffer.clear()
              bufferStart = System.nanoTime()
            }

            val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
            row match {
              case END_TOKEN => break
              case null => // timed out waiting; loop again
              case _ => rowBuffer.add(row)
            }
          }
        }
      }
      if (!rowQueue.isEmpty()) {
        drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
      }
      // Rows still sitting in the buffer (including those just drained) belong to the result;
      // without this final flush they would be dropped.
      result ++= rowBuffer.asScala
      rowBuffer.clear()

      // The onComplete handler ignores the outcome of the background task, so a fetch failure
      // has to be reported here or it would go unnoticed.
      dataProcessing.value match {
        case Some(scala.util.Failure(cause)) => throw KyuubiSQLException(cause)
        case _ =>
      }

      val finalStatus = trino.finalStatusInfo()
      if (finalStatus.getError() != null) {
        val exception = KyuubiSQLException(
          s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
        throw exception
      }

      updateTrinoContext()
    } catch {
      case e: Exception =>
        throw KyuubiSQLException(e)
    }
    result
  }

  /**
   * Applies the session changes reported by the statement client (catalog, schema, path and
   * session properties) to a copy of the current session and stores that copy in the context.
   */
  def updateTrinoContext(): Unit = {
    val session = trinoContext.getClientSession

    var builder = ClientSession.builder(session)
    // update catalog and schema, keeping the current value for whichever one was not set
    if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
      builder = builder
        .withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
        .withSchema(trino.getSetSchema.orElse(session.getSchema))
    }

    // update path if present
    if (trino.getSetPath.isPresent) {
      builder = builder.withPath(trino.getSetPath.get)
    }

    // update session properties if present: additions first, then removals
    if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
      val properties = session.getProperties.asScala.clone()
      properties ++= trino.getSetSessionProperties.asScala
      properties --= trino.getResetSessionProperties.asScala
      builder = builder.withProperties(properties.asJava)
    }

    trinoContext.clientSession.set(builder.build())
  }

  /**
   * Moves rows from `rowQueue` into `buffer` until `buffer` holds `maxBufferSize` rows or the
   * queue is empty.
   *
   * @return true if the last row moved was `endToken` (which is then removed from `buffer`),
   *         false otherwise
   */
  private def drainDetectingEnd(
      rowQueue: ArrayBlockingQueue[List[Any]],
      buffer: ArrayList[List[Any]],
      maxBufferSize: Int,
      endToken: List[Any]): Boolean = {
    val drained = rowQueue.drainTo(buffer, maxBufferSize - buffer.size)
    if (drained > 0 && buffer.get(buffer.size() - 1) == endToken) {
      buffer.remove(buffer.size() - 1)
      true
    } else {
      false
    }
  }

  /**
   * Puts `e` on `rowQueue`, blocking while the queue is full. If the thread is interrupted
   * while waiting, the interrupt flag is restored and a `RuntimeException` is thrown.
   */
  private def putOrThrow(rowQueue: ArrayBlockingQueue[List[Any]], e: List[Any]): Unit = {
    try {
      rowQueue.put(e)
    } catch {
      case interrupted: InterruptedException =>
        Thread.currentThread().interrupt()
        throw new RuntimeException(interrupted)
    }
  }
}

/**
 * Companion of the `TrinoStatement` class: buffering limits used while collecting rows, plus
 * a factory method.
 */
object TrinoStatement {
  // Capacity of the queue between the fetching task and the collecting thread.
  private final val MAX_QUEUED_ROWS = 50000
  // Number of rows the intermediate buffer may hold before it is flushed.
  private final val MAX_BUFFERED_ROWS = 10000
  // Timeout for a single poll on the queue, and the buffer's time-based flush threshold.
  private final val MAX_BUFFER_TIME = Duration(3, duration.SECONDS)
  // Sentinel row placed on the queue to mark the end of the data.
  private final val END_TOKEN = List.empty[Any]

  /** Creates a statement for `sql` bound to the given context and configuration. */
  def apply(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String): TrinoStatement =
    new TrinoStatement(trinoContext, kyuubiConf, sql)
}
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.trino

/**
 * Checks that a schema change made through `TrinoContext` is visible to statements created
 * afterwards.
 */
class TrinoContextSuite extends WithTrinoContainerServer {

  test("set current schema") {
    val context = TrinoContext(httpClient, session)

    // Each call builds a fresh statement, so it reads whatever session the context holds now.
    def currentDatabase: String =
      TrinoStatement(context, kyuubiConf, "select 1").getCurrentDatabase

    assert("tiny" === currentDatabase)

    context.setCurrentSchema("sf1")
    assert("sf1" === currentDatabase)
  }
}

0 comments on commit e886524

Please sign in to comment.