|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.kyuubi.engine.trino |
| 19 | + |
| 20 | +import java.util.ArrayList |
| 21 | +import java.util.concurrent.ArrayBlockingQueue |
| 22 | +import java.util.concurrent.Executors |
| 23 | + |
| 24 | +import scala.collection.JavaConverters._ |
| 25 | +import scala.collection.mutable.ArrayBuffer |
| 26 | +import scala.concurrent.ExecutionContext |
| 27 | +import scala.concurrent.Future |
| 28 | +import scala.concurrent.duration |
| 29 | +import scala.concurrent.duration.Duration |
| 30 | +import scala.util.control.Breaks._ |
| 31 | + |
| 32 | +import com.google.common.base.Verify |
| 33 | +import io.trino.client.ClientSession |
| 34 | +import io.trino.client.Column |
| 35 | +import io.trino.client.StatementClient |
| 36 | +import io.trino.client.StatementClientFactory |
| 37 | + |
| 38 | +import org.apache.kyuubi.KyuubiSQLException |
| 39 | +import org.apache.kyuubi.config.KyuubiConf |
| 40 | +import org.apache.kyuubi.engine.trino.TrinoConf.DATA_PROCESSING_POOL_SIZE |
| 41 | +import org.apache.kyuubi.engine.trino.TrinoStatement._ |
| 42 | + |
/**
 * Client-side handle for a single SQL statement that is submitted to a Trino cluster.
 *
 * The underlying [[StatementClient]] is created lazily, on first use, from the HTTP client and
 * the client session held by `trinoContext`.
 *
 * @param trinoContext supplies the HTTP client and the current (mutable) client session
 * @param kyuubiConf   configuration; only `DATA_PROCESSING_POOL_SIZE` is read here
 * @param sql          the statement text handed to the Trino client
 */
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {

  private lazy val trino = StatementClientFactory
    .newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)

  private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)

  // Execution context used by the background fetch in `execute()`.
  // NOTE(review): a fixed thread pool is created per instance and nothing in this class shuts
  // it down -- confirm who owns its lifecycle.
  implicit val ec: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newFixedThreadPool(dataProcessingPoolSize))

  /** Returns the underlying Trino statement client (created on first access). */
  def getTrinoClient: StatementClient = trino

  /** Returns the schema of the client session currently held by the context. */
  def getCurrentDatabase: String = trinoContext.getClientSession.getSchema

  /**
   * Advances the statement until column metadata is available and returns it.
   *
   * @return the columns reported by the current status info
   * @throws KyuubiSQLException if the statement stops running without ever reporting columns,
   *                            either because it failed or because it produced none
   */
  def getColumns: List[Column] = {
    while (trino.isRunning) {
      val results = trino.currentStatusInfo()
      val columns = results.getColumns()
      if (columns != null) {
        return columns.asScala.toList
      }
      trino.advance()
    }
    // The client is no longer running; it is expected to be in the finished state here.
    Verify.verify(trino.isFinished())
    val finalStatus = trino.finalStatusInfo()
    if (finalStatus.getError == null) {
      throw KyuubiSQLException(s"Query has no columns (#${finalStatus.getId})")
    } else {
      throw KyuubiSQLException(
        s"Query failed (#${finalStatus.getId}): ${finalStatus.getError.getMessage}")
    }
  }

  /**
   * Executes the statement and returns all of its rows.
   *
   * A background task pulls data pages from the Trino client and pushes each row onto a bounded
   * queue; when that task completes (successfully or not) `END_TOKEN` is enqueued. The calling
   * thread moves rows from the queue into a small buffer and flushes the buffer into the result
   * when it is full or has been filling for at least `MAX_BUFFER_TIME`.
   *
   * After all rows are collected, the final status is checked and, on success, the session
   * changes requested by the statement are applied via `updateTrinoContext()`.
   *
   * @return every row produced by the statement, each row as a list of column values
   * @throws KyuubiSQLException if the query finished with an error or if any exception is
   *                            raised while collecting rows
   */
  def execute(): Iterable[List[Any]] = {
    val rowQueue = new ArrayBlockingQueue[List[Any]](MAX_QUEUED_ROWS)

    // Producer: runs on `ec`, blocks on `put` when the queue is full.
    val dataProcessing = Future[Unit] {
      while (trino.isRunning) {
        val data = trino.currentData().getData()
        if (data != null) {
          data.asScala.map(_.asScala.toList)
            .foreach(row => putOrThrow(rowQueue, row))
        }
        trino.advance()
      }
    }
    // Always signal the end of the stream, even if the producer failed.
    dataProcessing.onComplete {
      case _ => putOrThrow(rowQueue, END_TOKEN)
    }

    val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
    var bufferStart = System.nanoTime()
    val result = ArrayBuffer[List[Any]]()
    try {
      breakable {
        while (!dataProcessing.isCompleted) {
          val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
          if (!atEnd) {
            // Flush if the buffer is full or has been filling for too long. The age must be
            // measured as elapsed time; `System.nanoTime()` values are only meaningful as
            // differences.
            val bufferAgeNanos = System.nanoTime() - bufferStart
            if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
              bufferAgeNanos >= MAX_BUFFER_TIME.toNanos) {
              result ++= rowBuffer.asScala
              rowBuffer.clear()
              bufferStart = System.nanoTime()
            }

            val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
            row match {
              case END_TOKEN => break
              case null => // timed out waiting for a row; loop and re-check
              case _ => rowBuffer.add(row)
            }
          }
        }
      }
      // The producer may have completed with rows still queued; pick them up.
      if (!rowQueue.isEmpty()) {
        drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
      }
      // Flush whatever is still buffered, otherwise the tail of the result set is lost.
      result ++= rowBuffer.asScala
      rowBuffer.clear()

      val finalStatus = trino.finalStatusInfo()
      if (finalStatus.getError() != null) {
        val exception = KyuubiSQLException(
          s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
        throw exception
      }

      updateTrinoContext()
    } catch {
      case e: Exception =>
        throw KyuubiSQLException(e)
    }
    result
  }

  /**
   * Rebuilds the client session stored in `trinoContext` from the changes reported by the
   * statement client: catalog/schema, path, and set/reset session properties. Values that the
   * statement did not change are carried over from the current session.
   */
  def updateTrinoContext(): Unit = {
    val session = trinoContext.getClientSession

    var builder = ClientSession.builder(session)
    // update catalog and schema
    if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
      builder = builder
        .withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
        .withSchema(trino.getSetSchema.orElse(session.getSchema))
    }

    // update path if present
    if (trino.getSetPath.isPresent) {
      builder = builder.withPath(trino.getSetPath.get)
    }

    // update session properties if present
    if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
      // Work on a copy so the current session's property map is not modified.
      val properties = session.getProperties.asScala.clone()
      properties ++= trino.getSetSessionProperties.asScala
      properties --= trino.getResetSessionProperties.asScala
      builder = builder.withProperties(properties.asJava)
    }

    trinoContext.clientSession.set(builder.build())
  }

  /**
   * Moves queued rows into `buffer` without blocking, up to the remaining capacity implied by
   * `maxBufferSize`, and reports whether the end marker was reached.
   *
   * @param rowQueue      queue to drain from
   * @param buffer        destination; drained rows are appended
   * @param maxBufferSize upper bound on the size of `buffer` after draining
   * @param endToken      marker row that signals the end of the stream
   * @return `true` if the last drained element equals `endToken` (it is removed from `buffer`),
   *         `false` otherwise
   */
  private def drainDetectingEnd(
      rowQueue: ArrayBlockingQueue[List[Any]],
      buffer: ArrayList[List[Any]],
      maxBufferSize: Int,
      endToken: List[Any]): Boolean = {
    val drained = rowQueue.drainTo(buffer, maxBufferSize - buffer.size)
    if (drained > 0 && buffer.get(buffer.size() - 1) == endToken) {
      buffer.remove(buffer.size() - 1)
      true
    } else {
      false
    }
  }

  /**
   * Blocking `put` onto the queue. If the thread is interrupted while waiting, the interrupt
   * flag is restored and the interruption is rethrown as a `RuntimeException`.
   */
  private def putOrThrow(rowQueue: ArrayBlockingQueue[List[Any]], e: List[Any]): Unit = {
    try {
      rowQueue.put(e)
    } catch {
      case ie: InterruptedException =>
        Thread.currentThread().interrupt()
        throw new RuntimeException(ie)
    }
  }
}
| 193 | + |
/**
 * Companion of [[TrinoStatement]]: tuning constants for row collection and a factory method.
 */
object TrinoStatement {
  // Capacity of the queue between the fetching task and the collecting thread.
  final private val MAX_QUEUED_ROWS = 50000
  // Number of rows gathered in the intermediate buffer before it is flushed.
  final private val MAX_BUFFERED_ROWS = 10000
  // Longest time rows are held in the buffer; also used as the queue poll timeout.
  final private val MAX_BUFFER_TIME = Duration(3, duration.SECONDS)
  // Marker row enqueued once fetching has completed.
  final private val END_TOKEN = List.empty[Any]

  /** Creates a [[TrinoStatement]] for `sql` using the given context and configuration. */
  def apply(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String): TrinoStatement =
    new TrinoStatement(trinoContext, kyuubiConf, sql)
}
0 commit comments