Skip to content

Commit e886524

Browse files
hddongyaooqinn
authored andcommitted
[KYUUBI #1641] Add Trino client
<!-- Thanks for sending a pull request! Here are some tips for you: 1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html 2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'. 3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'. --> ### _Why are the changes needed?_ <!-- Please clarify why the changes are needed. For instance, 1. If you add a feature, you can talk about the use case of it. 2. If you fix a bug, you can clarify why it is a bug. --> Add trino client to communicate with trino cluster. ### _How was this patch tested?_ - [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [X] [Run test](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before make a pull request Closes #1642 from hddong/add-trino-runner. Closes #1641 67c2f1f [hongdongdong] fix 3e4f778 [hongdongdong] use testcontainers 3cd4d94 [hongdongdong] [KYUUBI #1641] Add Trino client Authored-by: hongdongdong <hongdongdong@cmss.chinamobile.com> Signed-off-by: Kent Yao <yao@apache.org>
1 parent ce7916c commit e886524

File tree

8 files changed

+499
-0
lines changed

8 files changed

+499
-0
lines changed

externals/kyuubi-trino-engine/pom.xml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,25 @@
6464
<scope>test</scope>
6565
</dependency>
6666

67+
<dependency>
68+
<groupId>org.testcontainers</groupId>
69+
<artifactId>testcontainers</artifactId>
70+
<scope>test</scope>
71+
</dependency>
72+
73+
<!-- https://mvnrepository.com/artifact/org.testcontainers/trino -->
74+
<dependency>
75+
<groupId>org.testcontainers</groupId>
76+
<artifactId>trino</artifactId>
77+
<scope>test</scope>
78+
</dependency>
79+
80+
<dependency>
81+
<groupId>io.trino</groupId>
82+
<artifactId>trino-jdbc</artifactId>
83+
<scope>test</scope>
84+
</dependency>
85+
6786
</dependencies>
6887

6988
</project>
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.trino
19+
20+
import org.apache.kyuubi.config.ConfigBuilder
21+
import org.apache.kyuubi.config.ConfigEntry
22+
import org.apache.kyuubi.config.KyuubiConf
23+
24+
object TrinoConf {
25+
private def buildConf(key: String): ConfigBuilder = KyuubiConf.buildConf(key)
26+
27+
val DATA_PROCESSING_POOL_SIZE: ConfigEntry[Int] =
28+
buildConf("trino.client.data.processing.pool.size")
29+
.doc("The size of the thread pool used by the trino client to processing data")
30+
.version("1.5.0")
31+
.intConf
32+
.createWithDefault(3)
33+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.trino
19+
20+
import java.util.concurrent.atomic.AtomicReference
21+
22+
import io.trino.client.ClientSession
23+
import okhttp3.OkHttpClient
24+
25+
class TrinoContext(
26+
val httpClient: OkHttpClient,
27+
val clientSession: AtomicReference[ClientSession]) {
28+
29+
def getClientSession: ClientSession = clientSession.get
30+
31+
def setCurrentSchema(schema: String): Unit = {
32+
clientSession.set(ClientSession.builder(clientSession.get).withSchema(schema).build())
33+
}
34+
35+
}
36+
37+
object TrinoContext {
38+
def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext =
39+
new TrinoContext(httpClient, new AtomicReference(clientSession))
40+
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.trino
19+
20+
import java.util.ArrayList
21+
import java.util.concurrent.ArrayBlockingQueue
22+
import java.util.concurrent.Executors
23+
24+
import scala.collection.JavaConverters._
25+
import scala.collection.mutable.ArrayBuffer
26+
import scala.concurrent.ExecutionContext
27+
import scala.concurrent.Future
28+
import scala.concurrent.duration
29+
import scala.concurrent.duration.Duration
30+
import scala.util.control.Breaks._
31+
32+
import com.google.common.base.Verify
33+
import io.trino.client.ClientSession
34+
import io.trino.client.Column
35+
import io.trino.client.StatementClient
36+
import io.trino.client.StatementClientFactory
37+
38+
import org.apache.kyuubi.KyuubiSQLException
39+
import org.apache.kyuubi.config.KyuubiConf
40+
import org.apache.kyuubi.engine.trino.TrinoConf.DATA_PROCESSING_POOL_SIZE
41+
import org.apache.kyuubi.engine.trino.TrinoStatement._
42+
43+
/**
44+
* Trino client communicate with trino cluster.
45+
*/
46+
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {
47+
48+
private lazy val trino = StatementClientFactory
49+
.newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)
50+
51+
private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)
52+
53+
implicit val ec: ExecutionContext =
54+
ExecutionContext.fromExecutor(Executors.newFixedThreadPool(dataProcessingPoolSize))
55+
56+
def getTrinoClient: StatementClient = trino
57+
58+
def getCurrentDatabase: String = trinoContext.getClientSession.getSchema
59+
60+
def getColumns: List[Column] = {
61+
while (trino.isRunning) {
62+
val results = trino.currentStatusInfo()
63+
val columns = results.getColumns()
64+
if (columns != null) {
65+
return columns.asScala.toList
66+
}
67+
trino.advance()
68+
}
69+
Verify.verify(trino.isFinished())
70+
val finalStatus = trino.finalStatusInfo()
71+
if (finalStatus.getError == null) {
72+
throw KyuubiSQLException(s"Query has no columns (#${finalStatus.getId})")
73+
} else {
74+
throw KyuubiSQLException(
75+
s"Query failed (#${finalStatus.getId}): ${finalStatus.getError.getMessage}")
76+
}
77+
}
78+
79+
/**
80+
* Execute sql and return ResultSet.
81+
*/
82+
def execute(): Iterable[List[Any]] = {
83+
val rowQueue = new ArrayBlockingQueue[List[Any]](MAX_QUEUED_ROWS)
84+
85+
val dataProcessing = Future[Unit] {
86+
while (trino.isRunning) {
87+
val data = trino.currentData().getData()
88+
if (data != null) {
89+
data.asScala.map(_.asScala.toList)
90+
.foreach(e => putOrThrow(rowQueue, e))
91+
}
92+
trino.advance()
93+
}
94+
}
95+
dataProcessing.onComplete {
96+
case _ => putOrThrow(rowQueue, END_TOKEN)
97+
}
98+
99+
val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
100+
var bufferStart = System.nanoTime()
101+
val result = ArrayBuffer[List[Any]]()
102+
try {
103+
breakable {
104+
while (!dataProcessing.isCompleted) {
105+
val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
106+
if (!atEnd) {
107+
// Flush if needed
108+
if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
109+
Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) {
110+
result ++= rowBuffer.asScala
111+
rowBuffer.clear()
112+
bufferStart = System.nanoTime()
113+
}
114+
115+
val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
116+
row match {
117+
case END_TOKEN => break
118+
case null =>
119+
case _ => rowBuffer.add(row)
120+
}
121+
}
122+
}
123+
}
124+
if (!rowQueue.isEmpty()) {
125+
drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
126+
}
127+
val finalStatus = trino.finalStatusInfo()
128+
if (finalStatus.getError() != null) {
129+
val exception = KyuubiSQLException(
130+
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
131+
throw exception
132+
}
133+
134+
updateTrinoContext()
135+
} catch {
136+
case e: Exception =>
137+
throw KyuubiSQLException(e)
138+
}
139+
result
140+
}
141+
142+
def updateTrinoContext(): Unit = {
143+
val session = trinoContext.getClientSession
144+
145+
var builder = ClientSession.builder(session)
146+
// update catalog and schema
147+
if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
148+
builder = builder
149+
.withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
150+
.withSchema(trino.getSetSchema.orElse(session.getSchema))
151+
}
152+
153+
// update path if present
154+
if (trino.getSetPath.isPresent) {
155+
builder = builder.withPath(trino.getSetPath.get)
156+
}
157+
158+
// update session properties if present
159+
if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
160+
val properties = session.getProperties.asScala.clone()
161+
properties ++= trino.getSetSessionProperties.asScala
162+
properties --= trino.getResetSessionProperties.asScala
163+
builder = builder.withProperties(properties.asJava)
164+
}
165+
166+
trinoContext.clientSession.set(builder.build())
167+
}
168+
169+
private def drainDetectingEnd(
170+
rowQueue: ArrayBlockingQueue[List[Any]],
171+
buffer: ArrayList[List[Any]],
172+
maxBufferSize: Int,
173+
endToken: List[Any]): Boolean = {
174+
val drained = rowQueue.drainTo(buffer, maxBufferSize - buffer.size)
175+
if (drained > 0 && buffer.get(buffer.size() - 1) == endToken) {
176+
buffer.remove(buffer.size() - 1);
177+
true
178+
} else {
179+
false
180+
}
181+
}
182+
183+
private def putOrThrow(rowQueue: ArrayBlockingQueue[List[Any]], e: List[Any]): Unit = {
184+
try {
185+
rowQueue.put(e)
186+
} catch {
187+
case e: InterruptedException =>
188+
Thread.currentThread().interrupt()
189+
throw new RuntimeException(e)
190+
}
191+
}
192+
}
193+
194+
object TrinoStatement {
195+
final private val MAX_QUEUED_ROWS = 50000
196+
final private val MAX_BUFFERED_ROWS = 10000
197+
final private val MAX_BUFFER_TIME = Duration(3, duration.SECONDS)
198+
final private val END_TOKEN = List[Any]()
199+
200+
def apply(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String): TrinoStatement = {
201+
new TrinoStatement(trinoContext, kyuubiConf, sql)
202+
}
203+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.kyuubi.engine.trino
19+
20+
class TrinoContextSuite extends WithTrinoContainerServer {
21+
22+
test("set current schema") {
23+
val trinoContext = TrinoContext(httpClient, session)
24+
25+
val trinoStatement = TrinoStatement(trinoContext, kyuubiConf, "select 1")
26+
assert("tiny" === trinoStatement.getCurrentDatabase)
27+
28+
trinoContext.setCurrentSchema("sf1")
29+
val trinoStatement2 = TrinoStatement(trinoContext, kyuubiConf, "select 1")
30+
assert("sf1" === trinoStatement2.getCurrentDatabase)
31+
}
32+
}

0 commit comments

Comments
 (0)