diff --git a/externals/kyuubi-trino-engine/pom.xml b/externals/kyuubi-trino-engine/pom.xml index 788abc9f442..07042da5de6 100644 --- a/externals/kyuubi-trino-engine/pom.xml +++ b/externals/kyuubi-trino-engine/pom.xml @@ -64,6 +64,25 @@ test + + org.testcontainers + testcontainers + test + + + + + org.testcontainers + trino + test + + + + io.trino + trino-jdbc + test + + diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala new file mode 100644 index 00000000000..161a1d4cf7d --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Configuration entries specific to the Trino engine.
 */
object TrinoConf {
  // Delegates to KyuubiConf so entries are built the same way as the core configuration.
  private def buildConf(key: String): ConfigBuilder = KyuubiConf.buildConf(key)

  /**
   * Number of threads in the pool the Trino client uses to process result data.
   * Must be positive: the value is handed to a fixed-size thread pool, which cannot be
   * created with zero or a negative number of threads.
   */
  val DATA_PROCESSING_POOL_SIZE: ConfigEntry[Int] =
    buildConf("trino.client.data.processing.pool.size")
      .doc("The size of the thread pool used by the trino client to processing data")
      .version("1.5.0")
      .intConf
      .checkValue(_ > 0, "The data processing pool size must be a positive number.")
      .createWithDefault(3)
}
/**
 * Holds the HTTP client and the current, replaceable Trino client session.
 *
 * @param httpClient the HTTP client handed to Trino statement clients
 * @param clientSession reference to the session in effect; replaced whenever the session
 *                      changes (for example when the schema is switched)
 */
class TrinoContext(
    val httpClient: OkHttpClient,
    val clientSession: AtomicReference[ClientSession]) {

  /** Returns the session currently stored in the reference. */
  def getClientSession: ClientSession = clientSession.get

  /**
   * Replaces the stored session with a copy whose schema is `schema`.
   *
   * The copy is derived and installed with a single atomic update, so a session change
   * made concurrently by another thread is not overwritten by a stale snapshot.
   */
  def setCurrentSchema(schema: String): Unit = {
    clientSession.updateAndGet(new java.util.function.UnaryOperator[ClientSession] {
      override def apply(current: ClientSession): ClientSession =
        ClientSession.builder(current).withSchema(schema).build()
    })
  }

}

object TrinoContext {

  /** Creates a context that wraps `clientSession` in a fresh atomic reference. */
  def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext =
    new TrinoContext(httpClient, new AtomicReference(clientSession))
}
/**
 * Runs one SQL statement through the Trino client and exposes its columns and rows.
 *
 * @param trinoContext HTTP client and current client session; the session is updated after
 *                     a successful [[execute]]
 * @param kyuubiConf configuration, read for the data processing pool size
 * @param sql the statement text handed to the Trino client
 */
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {

  // Created on first use from the context's HTTP client and the session current at that time.
  private lazy val trino = StatementClientFactory
    .newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)

  private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)

  // One pool is created per statement. Idle workers are allowed to time out so that a
  // finished statement does not keep its threads alive forever.
  private val dataProcessingPool: java.util.concurrent.ExecutorService = {
    val pool = Executors.newFixedThreadPool(dataProcessingPoolSize)
    pool match {
      case tpe: java.util.concurrent.ThreadPoolExecutor =>
        // A positive keep-alive is required before core threads may time out.
        tpe.setKeepAliveTime(60L, java.util.concurrent.TimeUnit.SECONDS)
        tpe.allowCoreThreadTimeOut(true)
      case _ =>
    }
    pool
  }

  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(dataProcessingPool)

  /** Returns the underlying Trino statement client. */
  def getTrinoClient: StatementClient = trino

  /** Returns the schema of the session currently held by the context. */
  def getCurrentDatabase: String = trinoContext.getClientSession.getSchema

  /**
   * Advances the statement until the server reports column metadata and returns it.
   *
   * @throws KyuubiSQLException if the statement stops running without reporting columns,
   *                            either because it failed or because it has none
   */
  def getColumns: List[Column] = {
    while (trino.isRunning) {
      val results = trino.currentStatusInfo()
      val columns = results.getColumns()
      if (columns != null) {
        return columns.asScala.toList
      }
      trino.advance()
    }
    Verify.verify(trino.isFinished())
    val finalStatus = trino.finalStatusInfo()
    if (finalStatus.getError == null) {
      throw KyuubiSQLException(s"Query has no columns (#${finalStatus.getId})")
    } else {
      throw KyuubiSQLException(
        s"Query failed (#${finalStatus.getId}): ${finalStatus.getError.getMessage}")
    }
  }

  /**
   * Executes the statement and returns all of its rows.
   *
   * A background task pulls data pages from the Trino client into a bounded queue; the
   * calling thread drains that queue into the result. Once the statement has finished
   * without error, the context's session is updated from the statement's outcome.
   *
   * @throws KyuubiSQLException if the statement reports an error or anything else fails
   *                            while collecting the result
   */
  def execute(): Iterable[List[Any]] = {
    val rowQueue = new ArrayBlockingQueue[List[Any]](MAX_QUEUED_ROWS)

    val dataProcessing = Future[Unit] {
      while (trino.isRunning) {
        val data = trino.currentData().getData()
        if (data != null) {
          data.asScala.map(_.asScala.toList)
            .foreach(e => putOrThrow(rowQueue, e))
        }
        trino.advance()
      }
    }
    // Always signal the end of the stream, whether the producer succeeded or failed.
    dataProcessing.onComplete {
      case _ => putOrThrow(rowQueue, END_TOKEN)
    }

    val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
    var bufferStart = System.nanoTime()
    val result = ArrayBuffer[List[Any]]()
    try {
      breakable {
        while (!dataProcessing.isCompleted) {
          val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
          if (!atEnd) {
            // Flush when the buffer is full or has been filling for too long. The age is
            // the time elapsed since the last flush, not the raw nanoTime reading.
            val bufferAge = Duration.fromNanos(System.nanoTime() - bufferStart)
            if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
              bufferAge.compareTo(MAX_BUFFER_TIME) >= 0) {
              result ++= rowBuffer.asScala
              rowBuffer.clear()
              bufferStart = System.nanoTime()
            }

            val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
            row match {
              case END_TOKEN => break
              case null =>
              case _ => rowBuffer.add(row)
            }
          }
        }
      }
      if (!rowQueue.isEmpty()) {
        drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
      }
      // Hand over whatever is still buffered; without this the last rows would be lost.
      result ++= rowBuffer.asScala
      rowBuffer.clear()

      val finalStatus = trino.finalStatusInfo()
      if (finalStatus.getError() != null) {
        val exception = KyuubiSQLException(
          s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
        throw exception
      }

      updateTrinoContext()
    } catch {
      case e: Exception =>
        throw KyuubiSQLException(e)
    }
    result
  }

  /**
   * Applies the session changes reported by the finished statement (catalog, schema, path,
   * set and reset session properties) to the context's client session.
   */
  def updateTrinoContext(): Unit = {
    val session = trinoContext.getClientSession

    var builder = ClientSession.builder(session)
    // update catalog and schema
    if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
      builder = builder
        .withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
        .withSchema(trino.getSetSchema.orElse(session.getSchema))
    }

    // update path if present
    if (trino.getSetPath.isPresent) {
      builder = builder.withPath(trino.getSetPath.get)
    }

    // update session properties if present
    if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
      // Work on a copy so the existing session's property map is not modified.
      val properties = session.getProperties.asScala.clone()
      properties ++= trino.getSetSessionProperties.asScala
      properties --= trino.getResetSessionProperties.asScala
      builder = builder.withProperties(properties.asJava)
    }

    trinoContext.clientSession.set(builder.build())
  }

  /**
   * Moves queued rows into `buffer` without exceeding `maxBufferSize` buffered rows.
   *
   * @return true if the last row moved was `endToken`, which is then removed from `buffer`
   */
  private def drainDetectingEnd(
      rowQueue: ArrayBlockingQueue[List[Any]],
      buffer: ArrayList[List[Any]],
      maxBufferSize: Int,
      endToken: List[Any]): Boolean = {
    val drained = rowQueue.drainTo(buffer, maxBufferSize - buffer.size)
    if (drained > 0 && buffer.get(buffer.size() - 1) == endToken) {
      buffer.remove(buffer.size() - 1);
      true
    } else {
      false
    }
  }

  /**
   * Blocks until `e` has been added to `rowQueue`. If the thread is interrupted while
   * waiting, the interrupt flag is restored and a RuntimeException is thrown.
   */
  private def putOrThrow(rowQueue: ArrayBlockingQueue[List[Any]], e: List[Any]): Unit = {
    try {
      rowQueue.put(e)
    } catch {
      case e: InterruptedException =>
        Thread.currentThread().interrupt()
        throw new RuntimeException(e)
    }
  }
}

object TrinoStatement {
  // Capacity of the queue between the producer task and the calling thread.
  final private val MAX_QUEUED_ROWS = 50000
  // Buffered rows are flushed into the result at this size ...
  final private val MAX_BUFFERED_ROWS = 10000
  // ... or once they have been buffered for this long; also used as the poll timeout.
  final private val MAX_BUFFER_TIME = Duration(3, duration.SECONDS)
  // Sentinel row marking the end of the stream. It is compared by equality, so it is
  // indistinguishable from an empty row.
  final private val END_TOKEN = List[Any]()

  /** Creates a statement for `sql` bound to the given context and configuration. */
  def apply(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String): TrinoStatement = {
    new TrinoStatement(trinoContext, kyuubiConf, sql)
  }
}
/**
 * Checks that a schema change made on a TrinoContext is seen by statements created
 * afterwards from the same context.
 */
class TrinoContextSuite extends WithTrinoContainerServer {

  test("set current schema") {
    val context = TrinoContext(httpClient, session)

    // Builds a fresh statement on the shared context and reports the schema it sees.
    def currentDatabase: String =
      TrinoStatement(context, kyuubiConf, "select 1").getCurrentDatabase

    assert("tiny" === currentDatabase)

    context.setCurrentSchema("sf1")
    assert("sf1" === currentDatabase)
  }
}
/**
 * Exercises TrinoStatement against the containerised Trino server: plain queries,
 * session updates caused by a statement, and error reporting.
 */
class TrinoStatementSuite extends WithTrinoContainerServer {

  // Every statement gets its own context built from the shared HTTP client and session.
  private def statementFor(sql: String): TrinoStatement =
    TrinoStatement(TrinoContext(httpClient, session), kyuubiConf, sql)

  test("test query") {
    val selectOne = statementFor("select 1")
    val selectColumns = selectOne.getColumns
    val selectRows = selectOne.execute()

    assert(selectColumns.size === 1)
    assert(selectColumns(0).getName === "_col0")

    assert(selectRows.toIterator.hasNext)
    assert(selectRows.toIterator.next() === List(1))

    val showSchemas = statementFor("show schemas")
    val showColumns = showSchemas.getColumns
    val showRows = showSchemas.execute()

    assert(showColumns.size === 1)
    assert(showRows.toIterator.hasNext)
  }

  test("test update session") {
    val selectOne = statementFor("select 1")
    val columns = selectOne.getColumns

    assert(columns.size === 1)
    assert(columns(0).getName === "_col0")
    assert(this.schema === selectOne.getCurrentDatabase)

    val useSchema = statementFor("use sf1")
    useSchema.execute()

    assert("sf1" === useSchema.getCurrentDatabase)
  }

  test("test exception") {
    val useMissingSchema = statementFor("use kyuubi")
    val failure = intercept[KyuubiSQLException](useMissingSchema.execute())
    assert(failure.getMessage.contains("Schema does not exist: tpch.kyuubi"))
  }
}
/**
 * Test fixture that starts a Trino container before the suite runs and stops it afterwards,
 * and offers a client session and HTTP client pointing at that container.
 */
trait WithTrinoContainerServer extends KyuubiFunSuite {

  final val IMAGE_VERSION = 363
  final val DOCKER_IMAGE_NAME = s"trinodb/trino:${IMAGE_VERSION}"

  val trino = new TrinoContainer(DOCKER_IMAGE_NAME)
  val kyuubiConf: KyuubiConf = KyuubiConf()

  protected val catalog = "tpch"
  protected val schema = "tiny"

  // The container is started before the parent's own setup runs.
  override def beforeAll(): Unit = {
    trino.start()
    super.beforeAll()
  }

  // The container is stopped before the parent's own teardown runs.
  override def afterAll(): Unit = {
    trino.stop()
    super.afterAll()
  }

  // The container reports a JDBC URL; the client session is given its http form.
  lazy val connectionUrl = trino.getJdbcUrl.replace("jdbc:trino", "http")

  lazy val session = {
    def noStringEntries = Map.empty[String, String].asJava

    val serverUri = URI.create(connectionUrl)
    val noStrings = Set.empty[String].asJava
    val noRoles = Map.empty[String, ClientSelectedRole].asJava
    val twoMinutes = new Duration(2, TimeUnit.MINUTES)

    new ClientSession(
      serverUri,
      "kyuubi_test",
      Optional.empty(),
      "kyuubi",
      Optional.empty(),
      noStrings,
      null,
      catalog,
      schema,
      null,
      ZoneId.systemDefault(),
      Locale.getDefault,
      noStringEntries,
      noStringEntries,
      noStringEntries,
      noRoles,
      noStringEntries,
      null,
      twoMinutes,
      true)
  }

  lazy val httpClient = new OkHttpClient.Builder().build()
}
org.testcontainers + trino + ${testcontainers.version} + + + + javax.activation + activation + ${activation.version} + + io.fabric8 kubernetes-client