From ce4eeccf5aa812f0ccd02967a6612bbed476bac5 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 23 Apr 2026 00:14:33 +0800 Subject: [PATCH 1/2] [KYUUBI #7379][2b/4] Data Agent Engine: agent runtime, middleware stack, OpenAI provider, and live E2E tests This PR delivers the runtime layer of the Data Agent Engine on top of the tool system and data source plumbing from 2a/4: - ReactAgent: ReAct-style loop with streaming LLM responses, per-step tool dispatch, and AgentRunContext tracking token usage, iterations, and session. - Middleware stack (AgentMiddleware + ReactAgent.Builder): * LoggingMiddleware -- structured per-step/LLM/tool/finish logs with MDC. * ApprovalMiddleware -- CompletableFuture-based resolve for DESTRUCTIVE tools; modes NORMAL / STRICT / AUTO_APPROVE. * CompactionMiddleware -- token-threshold-triggered history summarization with KEEP_RECENT_TURNS=4, emits a Compaction AgentEvent so clients can observe the mechanism firing. * ToolResultOffloadMiddleware -- spills large tool outputs to disk and surfaces `read_tool_output` / `grep_tool_output` companion tools for the LLM to re-query truncated previews. - OpenAiProvider: single shared ReactAgent, per-session ConversationMemory, streaming chat completions, Hikari-pooled JDBC data source; reads model and thresholds from KyuubiConf. - ExecuteStatement (Scala): encodes all AgentEvents (including compaction and approval_request) as SSE JSON rows streamed through the JDBC reply column. - KyuubiConf: new keys for LLM provider/api-url/model/api-key, approval mode, compaction trigger tokens, offload root/thresholds, max iterations, etc. - Tests: * Unit tests for runtime, middlewares, offload store, and event shapes. * Live tests gated on DATA_AGENT_LLM_API_KEY covering full LLM round-trips: ReactAgentLiveTest (offload+grep, approval approve/deny), DataAgentE2ESuite and DataAgentApprovalE2ESuite (JDBC layer), DataAgentCompactionE2ESuite (JDBC-observable compaction event + post-compaction recovery), CompactionMiddlewareLiveTest. * Compatibility verified against qwen3.6-plus, glm-5, and kimi-k2.5 via per-call `model=` logging in ReactAgent. --- docs/configuration/settings.md | 1 + externals/kyuubi-data-agent-engine/pom.xml | 43 +- .../dataagent/datasource/JdbcDialect.java | 6 + .../{ => dialect}/GenericDialect.java | 4 +- .../{ => dialect}/MysqlDialect.java | 6 +- .../{ => dialect}/SparkDialect.java | 6 +- .../{ => dialect}/SqliteDialect.java | 6 +- .../{ => dialect}/TrinoDialect.java | 6 +- .../provider/ProviderRunRequest.java | 26 +- .../provider/openai/OpenAiProvider.java | 176 +++++ .../dataagent/runtime/AgentInvocation.java | 72 +++ .../dataagent/runtime/AgentRunContext.java | 112 ++++ .../dataagent/runtime/ApprovalMode.java | 28 + .../dataagent/runtime/ConversationMemory.java | 200 ++++++ .../engine/dataagent/runtime/ReactAgent.java | 606 ++++++++++++++++++ .../dataagent/runtime/ToolOutputStore.java | 242 +++++++ .../dataagent/runtime/event/Compaction.java | 69 ++ .../dataagent/runtime/event/EventType.java | 5 + .../runtime/middleware/AgentMiddleware.java | 152 +++++ .../middleware/ApprovalMiddleware.java | 152 +++++ .../middleware/CompactionMiddleware.java | 409 ++++++++++++ .../runtime/middleware/LoggingMiddleware.java | 160 +++++ .../ToolResultOffloadMiddleware.java | 191 ++++++ .../engine/dataagent/tool/AgentTool.java | 5 +- .../engine/dataagent/tool/ToolContext.java | 40 ++ .../engine/dataagent/tool/ToolRegistry.java | 149 +++-- .../tool/output/GrepToolOutputArgs.java | 39 ++ .../tool/output/GrepToolOutputTool.java | 68 ++ .../tool/output/ReadToolOutputArgs.java | 37 ++ .../tool/output/ReadToolOutputTool.java | 71 ++ .../tool/sql/RunMutationQueryTool.java | 3 +- .../tool/sql/RunSelectQueryTool.java | 3 +- .../tool/sql/SqlReadOnlyChecker.java | 4 +- .../engine/dataagent/util/ConfUtils.java | 62 ++ .../operation/ExecuteStatement.scala | 13 +- .../dataagent/datasource/JdbcDialectTest.java | 1 + .../engine/dataagent/mysql/DialectTest.java | 5 +- .../provider/mock/MockLlmProvider.java | 185 ++++++ .../runtime/ConversationMemoryTest.java | 47 ++ .../dataagent/runtime/ReactAgentLiveTest.java | 568 ++++++++++++++++ .../runtime/ToolOutputStoreTest.java | 116 ++++ .../dataagent/runtime/event/EventTest.java | 3 +- .../middleware/ApprovalMiddlewareTest.java | 294 +++++++++ .../CompactionMiddlewareLiveTest.java | 100 +++ .../middleware/CompactionMiddlewareTest.java | 322 ++++++++++ .../ToolResultOffloadMiddlewareTest.java | 141 ++++ .../tool/ToolRegistryThreadSafetyTest.java | 6 +- .../engine/dataagent/tool/ToolTest.java | 4 +- .../tool/sql/RunMutationQueryToolTest.java | 17 +- .../tool/sql/RunSelectQueryToolTest.java | 45 +- .../DataAgentCompactionE2ESuite.scala | 195 ++++++ .../operation/DataAgentE2ESuite.scala | 92 ++- .../org/apache/kyuubi/config/KyuubiConf.scala | 15 + 53 files changed, 5187 insertions(+), 141 deletions(-) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/GenericDialect.java (92%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/MysqlDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/SparkDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/SqliteDialect.java (85%) rename externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/{ => dialect}/TrinoDialect.java (85%) create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java create mode 100644 externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java create mode 100644 externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index 3e34882b28f..a2910abfca6 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -144,6 +144,7 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.engine.chat.provider | ECHO | The provider for the Chat engine. Candidates: | string | 1.8.0 | | kyuubi.engine.connection.url.use.hostname | true | (deprecated) When true, the engine registers with hostname to zookeeper. When Spark runs on K8s with cluster mode, set to false to ensure that server can connect to engine | boolean | 1.3.0 | | kyuubi.engine.data.agent.approval.mode | NORMAL | Default approval mode for tool execution in the Data Agent engine. Candidates: | string | 1.12.0 | +| kyuubi.engine.data.agent.compaction.trigger.tokens | 128000 | The prompt-token threshold above which the Data Agent's compaction middleware summarizes older conversation history into a compact message. The check is made each turn as real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; when this predicted prompt size reaches the configured value, older messages are replaced by a single summary message while the most recent exchanges are kept verbatim. Set to a very large value (e.g., 9223372036854775807) to effectively disable compaction. | long | 1.12.0 | | kyuubi.engine.data.agent.extra.classpath | <undefined> | The extra classpath for the Data Agent engine, for configuring the location of the LLM SDK and etc. | string | 1.12.0 | | kyuubi.engine.data.agent.java.options | <undefined> | The extra Java options for the Data Agent engine | string | 1.12.0 | | kyuubi.engine.data.agent.jdbc.url | <undefined> | The JDBC URL for the Data Agent engine to connect to the target database. If not set, the Data Agent will connect back to Kyuubi server via ZooKeeper service discovery. | string | 1.12.0 | diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index c34d049360c..74da5005784 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -50,19 +50,48 @@ ${project.version} + com.openai openai-java + ${openai.sdk.version} + com.github.victools jsonschema-generator + ${victools.jsonschema.version} - com.github.victools jsonschema-module-jackson + ${victools.jsonschema.version} + + + + + org.xerial + sqlite-jdbc + ${sqlite.version} + + + + + com.mysql + mysql-connector-j + + + + + io.trino + trino-jdbc + + + + + com.zaxxer + HikariCP @@ -74,24 +103,12 @@ test - - org.xerial - sqlite-jdbc - test - - org.testcontainers testcontainers-mysql test - - com.mysql - mysql-connector-j - test - - junit junit diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java index c3be1dad61a..c771ad222aa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java @@ -17,6 +17,12 @@ package org.apache.kyuubi.engine.dataagent.datasource; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SparkDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SqliteDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.TrinoDialect; + /** * SQL dialect abstraction for datasource-specific SQL generation. * diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java similarity index 92% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java index 3ea22ed54e3..d8c4512de03 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** * Fallback dialect for JDBC subprotocols that have no dedicated implementation. Carries the diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java index 98747ffa30c..350789a6a87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** MySQL dialect. Uses backtick quoting for identifiers. */ public final class MysqlDialect implements JdbcDialect { - static final MysqlDialect INSTANCE = new MysqlDialect(); + public static final MysqlDialect INSTANCE = new MysqlDialect(); private MysqlDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java index 3adb43fa398..34e20034bfb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Spark SQL dialect. Uses backtick quoting for identifiers. */ public final class SparkDialect implements JdbcDialect { - static final SparkDialect INSTANCE = new SparkDialect(); + public static final SparkDialect INSTANCE = new SparkDialect(); private SparkDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java index a53255a9c67..eb98ca8edfa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** SQLite dialect. Uses double-quote quoting for identifiers. */ public final class SqliteDialect implements JdbcDialect { - static final SqliteDialect INSTANCE = new SqliteDialect(); + public static final SqliteDialect INSTANCE = new SqliteDialect(); private SqliteDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java index edacf2f87e2..75fbd4bb242 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Trino SQL dialect. Uses double-quote quoting for identifiers. */ public final class TrinoDialect implements JdbcDialect { - static final TrinoDialect INSTANCE = new TrinoDialect(); + public static final TrinoDialect INSTANCE = new TrinoDialect(); private TrinoDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java index f4e40b2fae8..26ad8be77fb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java @@ -17,13 +17,23 @@ package org.apache.kyuubi.engine.dataagent.provider; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * User-facing request parameters for a provider-level agent invocation. Only contains fields from * the caller (question, model override, etc.). Adding new per-request options does not require * changing the {@link DataAgentProvider} interface. + * + *

The approval mode is accepted as a raw string (natural for config-driven callers) and parsed + * into {@link ApprovalMode} by {@link #getApprovalMode()}. Unrecognised values fall back to {@link + * ApprovalMode#NORMAL} with a warning. */ public class ProviderRunRequest { + private static final Logger LOG = LoggerFactory.getLogger(ProviderRunRequest.class); + private final String question; private String modelName; private String approvalMode; @@ -45,8 +55,20 @@ public ProviderRunRequest modelName(String modelName) { return this; } - public String getApprovalMode() { - return approvalMode; + /** + * Resolved approval mode. Returns {@link ApprovalMode#NORMAL} when the caller did not set one or + * supplied an unknown value. + */ + public ApprovalMode getApprovalMode() { + if (approvalMode == null || approvalMode.isEmpty()) { + return ApprovalMode.NORMAL; + } + try { + return ApprovalMode.valueOf(approvalMode.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn("Unknown approval mode '{}', using default NORMAL", approvalMode); + return ApprovalMode.NORMAL; + } } public ProviderRunRequest approvalMode(String approvalMode) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java new file mode 100644 index 00000000000..bcd647b9326 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.provider.openai; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.zaxxer.hikari.HikariDataSource; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.KyuubiReservedKeys; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.AgentInvocation; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.ReactAgent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An OpenAI-compatible provider that wires up the full ReactAgent with streaming LLM, tools, and + * middleware pipeline. Uses the official OpenAI Java SDK. + * + *

The ReactAgent, DataSource, and ToolRegistry are shared across all sessions within this engine + * instance. Each session only maintains its own {@link ConversationMemory}. This works because each + * engine is bound to one user + one datasource, so all sessions within the engine naturally share + * the same data connection. + */ +public class OpenAiProvider implements DataAgentProvider { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class); + + private final ReactAgent agent; + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + private final OpenAIClient client; + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + public OpenAiProvider(KyuubiConf conf) { + String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY()); + String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL()); + String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL()); + + int maxIterations = ConfUtils.intConf(conf, KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS()); + long compactionTriggerTokens = + ConfUtils.longConf(conf, KyuubiConf.ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS()); + int queryTimeoutSeconds = + (int) ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_QUERY_TIMEOUT()); + long toolCallTimeoutSeconds = + ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_TOOL_CALL_TIMEOUT()); + + this.client = + OpenAIOkHttpClient.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .maxRetries(3) + .timeout(Duration.ofSeconds(180)) + .build(); + + this.toolRegistry = new ToolRegistry(toolCallTimeoutSeconds); + + SystemPromptBuilder promptBuilder = SystemPromptBuilder.create(); + this.dataSource = attachJdbcDataSource(conf, toolRegistry, promptBuilder, queryTimeoutSeconds); + + this.agent = + ReactAgent.builder() + .client(client) + .modelName(modelName) + .toolRegistry(toolRegistry) + .addMiddleware(new ToolResultOffloadMiddleware()) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new CompactionMiddleware(client, modelName, compactionTriggerTokens)) + .addMiddleware(new ApprovalMiddleware()) + .maxIterations(maxIterations) + .systemPrompt(promptBuilder.build()) + .build(); + } + + /** + * Register JDBC-backed SQL tools if a JDBC URL is configured. Returns the created {@link + * DataSource} so the provider can close it on shutdown, or {@code null} when no JDBC is wired. + */ + private static DataSource attachJdbcDataSource( + KyuubiConf conf, + ToolRegistry registry, + SystemPromptBuilder promptBuilder, + int queryTimeoutSeconds) { + String jdbcUrl = ConfUtils.optionalString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + if (jdbcUrl == null) { + return null; + } + LOG.info("Data Agent JDBC URL configured ({})", jdbcUrl.replaceAll("//.*@", "//@")); + + String sessionUser = + ConfUtils.optionalString(conf, KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY()); + + DataSource ds = DataSourceFactory.create(jdbcUrl, sessionUser); + registry.register(new RunSelectQueryTool(ds, queryTimeoutSeconds)); + registry.register(new RunMutationQueryTool(ds, queryTimeoutSeconds)); + promptBuilder.datasource(JdbcDialect.fromUrl(jdbcUrl).datasourceName()); + return ds; + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new ConversationMemory()); + LOG.info("Opened Data Agent session {} for user {}", sessionId, user); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + ConversationMemory memory = sessions.get(sessionId); + if (memory == null) { + throw new IllegalStateException("No open Data Agent session for id=" + sessionId); + } + + AgentInvocation invocation = + new AgentInvocation(request.getQuestion()) + .modelName(request.getModelName()) + .approvalMode(request.getApprovalMode()) + .sessionId(sessionId); + agent.run(invocation, memory, onEvent); + } + + @Override + public boolean resolveApproval(String requestId, boolean approved) { + return agent.resolveApproval(requestId, approved); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + agent.closeSession(sessionId); + LOG.info("Closed Data Agent session {}", sessionId); + } + + @Override + public void stop() { + agent.stop(); + toolRegistry.close(); + if (dataSource instanceof HikariDataSource) { + ((HikariDataSource) dataSource).close(); + LOG.info("Closed Data Agent connection pool"); + } + client.close(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java new file mode 100644 index 00000000000..0695d556058 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.Objects; + +/** + * User-facing request parameters for a single agent invocation. Only contains fields that come from + * the caller (question, model override, etc.). Framework-level concerns like memory and event + * consumer are separate method parameters. + * + *

Adding new per-request options (e.g. temperature, maxTokens) does not require changing the + * {@code ReactAgent.run()} signature. + */ +public class AgentInvocation { + + private final String userInput; + private String modelName; + private ApprovalMode approvalMode = ApprovalMode.NORMAL; + private String sessionId; + + public AgentInvocation(String userInput) { + this.userInput = Objects.requireNonNull(userInput, "userInput must not be null"); + } + + public String getUserInput() { + return userInput; + } + + public String getModelName() { + return modelName; + } + + public AgentInvocation modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public AgentInvocation approvalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + return this; + } + + public String getSessionId() { + return sessionId; + } + + /** Upstream session id, propagated into {@link AgentRunContext#getSessionId()}. */ + public AgentInvocation sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java new file mode 100644 index 00000000000..e7c92df8033 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; + +/** + * Mutable context passed through the middleware pipeline and agent loop. Tracks the current state + * of agent execution including iteration count, token usage, and custom middleware state. + */ +public class AgentRunContext { + + private final ConversationMemory memory; + private final String sessionId; + private Consumer eventEmitter; + private int iteration; + private long promptTokens; + private long completionTokens; + private long totalTokens; + private ApprovalMode approvalMode; + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode) { + this(memory, approvalMode, null); + } + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode, String sessionId) { + this.memory = memory; + this.iteration = 0; + this.approvalMode = approvalMode; + this.sessionId = sessionId; + } + + public ConversationMemory getMemory() { + return memory; + } + + /** + * The upstream session identifier this run belongs to. Threaded down from {@code + * DataAgentProvider.run(sessionId, ...)}. May be {@code null} in unit tests that do not exercise + * session-scoped middleware. + */ + public String getSessionId() { + return sessionId; + } + + public int getIteration() { + return iteration; + } + + public void setIteration(int iteration) { + this.iteration = iteration; + } + + public long getPromptTokens() { + return promptTokens; + } + + public long getCompletionTokens() { + return completionTokens; + } + + public long getTotalTokens() { + return totalTokens; + } + + /** + * Record one LLM call's usage. Updates both the per-run counters on this context and the + * session-level cumulative on the underlying {@link ConversationMemory}, so middlewares that need + * a session-wide picture can read it directly from memory without keeping their own bookkeeping. + */ + public void addTokenUsage(long prompt, long completion, long total) { + this.promptTokens += prompt; + this.completionTokens += completion; + this.totalTokens += total; + memory.addCumulativeTokens(prompt, completion, total); + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public void setApprovalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + } + + public void setEventEmitter(Consumer eventEmitter) { + this.eventEmitter = eventEmitter; + } + + /** Emit an event through the agent's event pipeline. Available for use by middlewares. */ + public void emit(AgentEvent event) { + if (eventEmitter != null) { + eventEmitter.accept(event); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java new file mode 100644 index 00000000000..57bc20bc2bb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +/** Approval modes for tool execution in the Data Agent engine. */ +public enum ApprovalMode { + /** All tools require explicit user approval. */ + STRICT, + /** Only non-readonly tools require approval. */ + NORMAL, + /** All tools are auto-approved. */ + AUTO_APPROVE +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java new file mode 100644 index 00000000000..0bfae26ec27 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Manages conversation history for a Data Agent session. Ensures tool result messages are never + * orphaned from their corresponding AI messages. + * + *

Each instance is session-scoped and accessed sequentially within a single ReAct loop — no + * synchronization is needed. Cross-session concurrency is handled by the provider's session map. + */ +public class ConversationMemory { + + /** + * System prompt prepended to every LLM call built from this memory. Rebuilt by the provider on + * each invocation (datasource/tool metadata can change between turns), so it lives outside the + * {@link #messages} list rather than being inserted as the first entry. May be {@code null} until + * the first {@link #setSystemPrompt} call — {@link #buildLlmMessages()} will simply omit the + * system slot in that case. + */ + private String systemPrompt; + + /** + * The raw content of the most recent user message added via {@link #addUserMessage}. Cached + * separately from {@link #messages} so middleware and callers can recover the current turn's + * question after compaction rewrites history. Not used by the LLM call path itself. + */ + private String lastUserInput; + + /** + * The ordered conversation history: user / assistant / tool messages, in the order the LLM will + * see them. The system prompt is intentionally NOT stored here (see {@link #systemPrompt}). + * Mutated in place by {@link #replaceHistory} during compaction; otherwise append-only. + */ + private final List messages = new ArrayList<>(); + + /** + * Session-level running total of {@code prompt_tokens} reported by every LLM call on this + * conversation (across ReAct turns). Intended for billing, quota, and observability — not used by + * any runtime decision. Updated via {@link #addCumulativeTokens}. + */ + private long cumulativePromptTokens; + + /** + * Session-level running total of {@code completion_tokens}. See {@link #cumulativePromptTokens}. + */ + private long cumulativeCompletionTokens; + + /** + * Session-level running total of {@code total_tokens} (prompt + completion as reported by the + * provider — not necessarily the sum of the two counters above, since providers may count + * cached/reasoning tokens differently). See {@link #cumulativePromptTokens}. + */ + private long cumulativeTotalTokens; + + /** + * The {@code total_tokens} reported by the single most recent LLM call, or {@code 0} if no call + * has completed yet. Distinct from the cumulative counters: this is a snapshot, overwritten every + * call. Used by {@link + * org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware} to estimate the + * next prompt size (the last response becomes part of the next prompt, so the next call's prompt + * is at least {@code lastTotalTokens}). Persists across ReAct turns until the next call + * overwrites it. + */ + private long lastTotalTokens; + + public ConversationMemory() {} + + public String getSystemPrompt() { + return systemPrompt; + } + + public void setSystemPrompt(String prompt) { + this.systemPrompt = prompt; + } + + public void addUserMessage(String content) { + this.lastUserInput = content; + messages.add( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(content).build())); + } + + public String getLastUserInput() { + return lastUserInput; + } + + public void addAssistantMessage(ChatCompletionAssistantMessageParam message) { + messages.add(ChatCompletionMessageParam.ofAssistant(message)); + } + + public void addToolResult(String toolCallId, String content) { + messages.add( + ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder() + .toolCallId(toolCallId) + .content(content) + .build())); + } + + /** + * Build the full message list for LLM API invocation: [system prompt] + conversation history. + * + *

No windowing is applied — callers are responsible for managing context length (e.g. via a + * token-based truncation strategy). + * + * @see #getHistory() for history-only access without system prompt + */ + public List buildLlmMessages() { + List result = new ArrayList<>(messages.size() + 1); + if (systemPrompt != null) { + result.add( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())); + } + result.addAll(messages); + return Collections.unmodifiableList(result); + } + + /** + * Returns the conversation history (user, assistant, tool messages) without the system prompt. + * Useful for middleware that needs to inspect or compact history. + */ + public List getHistory() { + return Collections.unmodifiableList(new ArrayList<>(messages)); + } + + /** + * Replace the conversation history with a compacted version. Useful for context-length management + * strategies (e.g., summarizing older messages). + * + *

Also clears {@link #lastTotalTokens}: the prior snapshot referred to a prompt whose bulk we + * just discarded, so it no longer describes anything in memory. Leaving it stale would keep the + * compaction trigger armed until the next successful LLM call overwrites it — fine on the happy + * path, but if that call fails the next turn would re-enter compaction against already-compacted + * history. Zeroing means "unknown, wait for the next real usage report". Cumulative totals are + * intentionally preserved (session-level accounting, must not regress on internal compaction). + */ + public void replaceHistory(List compacted) { + messages.clear(); + messages.addAll(compacted); + this.lastTotalTokens = 0; + } + + public void clear() { + messages.clear(); + } + + public int size() { + return messages.size(); + } + + public long getCumulativePromptTokens() { + return cumulativePromptTokens; + } + + public long getCumulativeCompletionTokens() { + return cumulativeCompletionTokens; + } + + public long getCumulativeTotalTokens() { + return cumulativeTotalTokens; + } + + public long getLastTotalTokens() { + return lastTotalTokens; + } + + /** Add one LLM call's usage to the session cumulative. Intended for {@link AgentRunContext}. */ + public void addCumulativeTokens(long prompt, long completion, long total) { + this.cumulativePromptTokens += prompt; + this.cumulativeCompletionTokens += completion; + this.cumulativeTotalTokens += total; + this.lastTotalTokens = total; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java new file mode 100644 index 00000000000..520cd963ef8 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.core.http.StreamResponse; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionChunk; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionStreamOptions; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * ReAct (Reasoning + Acting) agent loop using the OpenAI official Java SDK. Iterates through LLM + * reasoning, tool execution, and result verification until the agent produces a final answer or + * hits the iteration limit. + * + *

Emits {@link AgentEvent}s via the provided consumer for real-time token-level streaming. + */ +public class ReactAgent { + + private static final Logger LOG = LoggerFactory.getLogger(ReactAgent.class); + private static final ObjectMapper JSON = new ObjectMapper(); + + private final OpenAIClient client; + private final String defaultModelName; + private final ToolRegistry toolRegistry; + private final List middlewares; + private final ApprovalMiddleware approvalMiddleware; + private final int maxIterations; + private final String systemPrompt; + + public ReactAgent( + OpenAIClient client, + String modelName, + ToolRegistry toolRegistry, + List middlewares, + int maxIterations, + String systemPrompt) { + this.client = client; + this.defaultModelName = modelName; + this.toolRegistry = toolRegistry; + this.middlewares = middlewares; + this.approvalMiddleware = findApprovalMiddleware(middlewares); + this.maxIterations = maxIterations; + this.systemPrompt = systemPrompt; + } + + private static ApprovalMiddleware findApprovalMiddleware(List middlewares) { + for (AgentMiddleware mw : middlewares) { + if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw; + } + return null; + } + + /** Resolve a pending approval request. Returns false if no pending request matches. */ + public boolean resolveApproval(String requestId, boolean approved) { + if (approvalMiddleware == null) return false; + return approvalMiddleware.resolve(requestId, approved); + } + + /** Fan out session-close to every middleware. Errors in one middleware don't block others. */ + public void closeSession(String sessionId) { + for (AgentMiddleware mw : middlewares) { + try { + mw.onSessionClose(sessionId); + } catch (Exception e) { + LOG.warn("Middleware onSessionClose error", e); + } + } + } + + /** Fan out engine-stop to every middleware. Errors in one middleware don't block others. */ + public void stop() { + for (AgentMiddleware mw : middlewares) { + try { + mw.onStop(); + } catch (Exception e) { + LOG.warn("Middleware onStop error", e); + } + } + } + + /** + * Run the ReAct loop for the given request. + * + * @param request user-facing parameters (question, model override, etc.) + * @param memory the conversation memory (may contain prior context) + * @param eventConsumer callback for each agent event (token-level streaming) + */ + public void run( + AgentInvocation request, ConversationMemory memory, Consumer eventConsumer) { + String userInput = request.getUserInput(); + ApprovalMode approvalMode = request.getApprovalMode(); + String modelNameOverride = request.getModelName(); + + String effectiveModel = + (modelNameOverride != null && !modelNameOverride.isEmpty()) + ? modelNameOverride + : defaultModelName; + + // System prompt is immutable for the lifetime of this agent — only set it on first run + // to avoid redundant overwrites on multi-turn conversations. + if (memory.getSystemPrompt() == null) { + memory.setSystemPrompt(systemPrompt); + } + memory.addUserMessage(userInput); + + AgentRunContext ctx = new AgentRunContext(memory, approvalMode, request.getSessionId()); + ctx.setEventEmitter(event -> emit(ctx, event, eventConsumer)); + dispatchAgentStart(ctx); + emit(ctx, new AgentStart(), eventConsumer); + + try { + for (int step = 1; step <= maxIterations; step++) { + ctx.setIteration(step); + emit(ctx, new StepStart(step), eventConsumer); + + List messages = + resolveMessagesForCall(ctx, memory.buildLlmMessages(), eventConsumer); + if (messages == null) { + // Middleware asked us to skip — AgentError + AgentFinish have already been emitted. + return; + } + + StreamResult result = streamLlmResponse(ctx, messages, effectiveModel, eventConsumer); + if (result.isEmpty()) { + emit(ctx, new AgentError("LLM returned empty response"), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + if (!result.content.isEmpty()) { + emit(ctx, new ContentComplete(result.content), eventConsumer); + } + ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result); + memory.addAssistantMessage(assistantMsg); + dispatchAfterLlmCall(ctx, assistantMsg); + + if (result.toolCalls == null || result.toolCalls.isEmpty()) { + // No tool calls — agent is done. + emit(ctx, new StepEnd(step), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + executeToolCalls(ctx, memory, result.toolCalls, eventConsumer); + emit(ctx, new StepEnd(step), eventConsumer); + } + + emit( + ctx, new AgentError("Reached maximum iterations (" + maxIterations + ")"), eventConsumer); + emitFinish(ctx, eventConsumer); + + } catch (Exception e) { + LOG.error("Agent run failed", e); + emit( + ctx, new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage()), eventConsumer); + emitFinish(ctx, eventConsumer); + } finally { + dispatchAgentFinish(ctx); + } + } + + /** + * Run {@code beforeLlmCall} middleware against {@code messages}. Returns the messages to send, + * possibly rewritten by middleware, or {@code null} if middleware aborted the call (in which case + * this method has already emitted the terminal events). + */ + private List resolveMessagesForCall( + AgentRunContext ctx, + List messages, + Consumer eventConsumer) { + AgentMiddleware.LlmCallAction action = dispatchBeforeLlmCall(ctx, messages); + if (action instanceof AgentMiddleware.LlmSkip) { + String reason = ((AgentMiddleware.LlmSkip) action).reason(); + emit(ctx, new AgentError("LLM call skipped by middleware: " + reason), eventConsumer); + emitFinish(ctx, eventConsumer); + return null; + } + if (action instanceof AgentMiddleware.LlmModifyMessages) { + return ((AgentMiddleware.LlmModifyMessages) action).messages(); + } + return messages; + } + + private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) { + ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder(); + if (!result.content.isEmpty()) { + b.content(result.content); + } + if (result.toolCalls != null && !result.toolCalls.isEmpty()) { + b.toolCalls(result.toolCalls); + } + return b.build(); + } + + /** + * Execute the assistant's tool calls in 3 phases: + * + *

    + *
  1. Serial: run {@code beforeToolCall} middleware, emit {@link ToolCall} events, and collect + * the calls that survived approval. + *
  2. Concurrent: fan out to {@link ToolRegistry#submitTool}, which always returns a future + * that completes normally — timeouts and execution errors surface as error strings. + *
  3. Serial: join futures in order, run {@code afterToolCall}, and record results to memory. + *
+ */ + private void executeToolCalls( + AgentRunContext ctx, + ConversationMemory memory, + List toolCalls, + Consumer eventConsumer) { + List approved = new ArrayList<>(); + for (ChatCompletionMessageToolCall toolCall : toolCalls) { + ChatCompletionMessageFunctionToolCall fnCall = toolCall.asFunction(); + String toolName = fnCall.function().name(); + Map toolArgs; + try { + toolArgs = parseToolArgs(fnCall.function().arguments()); + } catch (IllegalArgumentException e) { + // Malformed JSON from the LLM: record an error tool_result (preserves the + // assistant/tool_result pairing the next API call needs) and let the loop self-correct. + String err = "Tool call failed: " + e.getMessage(); + memory.addToolResult(fnCall.id(), err); + emit(ctx, new ToolResult(fnCall.id(), toolName, err, true), eventConsumer); + continue; + } + + AgentMiddleware.ToolCallDenial denial = + dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs); + if (denial != null) { + String denied = "Tool call denied: " + denial.reason(); + memory.addToolResult(fnCall.id(), denied); + emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); + continue; + } + + emit(ctx, new ToolCall(fnCall.id(), toolName, toolArgs), eventConsumer); + approved.add(new ToolCallEntry(fnCall, toolName, toolArgs)); + } + + ToolContext toolCtx = new ToolContext(ctx.getSessionId()); + List> futures = new ArrayList<>(approved.size()); + for (ToolCallEntry entry : approved) { + futures.add( + toolRegistry.submitTool(entry.toolName, entry.fnCall.function().arguments(), toolCtx)); + } + + for (int i = 0; i < approved.size(); i++) { + ToolCallEntry entry = approved.get(i); + String output = futures.get(i).join(); + String modified = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, output); + if (modified != null) { + output = modified; + } + memory.addToolResult(entry.fnCall.id(), output); + emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer); + } + } + + /** Result of a streaming LLM call, assembled from chunks. */ + private static class StreamResult { + final String content; + final List toolCalls; + + StreamResult(String content, List toolCalls) { + this.content = content; + this.toolCalls = toolCalls; + } + + boolean isEmpty() { + return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty()); + } + } + + /** Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. */ + private static class ToolCallEntry { + final ChatCompletionMessageFunctionToolCall fnCall; + final String toolName; + final Map toolArgs; + + ToolCallEntry( + ChatCompletionMessageFunctionToolCall fnCall, + String toolName, + Map toolArgs) { + this.fnCall = fnCall; + this.toolName = toolName; + this.toolArgs = toolArgs; + } + } + + /** + * Stream LLM response, emitting ContentDelta for each text chunk. Assembles tool calls directly + * from streamed chunks — no non-streaming fallback. Exceptions propagate to the top-level handler + * in {@link #run}. + */ + private StreamResult streamLlmResponse( + AgentRunContext ctx, + List messages, + String effectiveModel, + Consumer eventConsumer) { + ChatCompletionCreateParams.Builder paramsBuilder = + ChatCompletionCreateParams.builder() + .model(effectiveModel) + .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build()); + for (ChatCompletionMessageParam msg : messages) { + paramsBuilder.addMessage(msg); + } + toolRegistry.addToolsTo(paramsBuilder); + + LOG.info("LLM request: model={}", effectiveModel); + StreamAccumulator acc = new StreamAccumulator(); + try (StreamResponse stream = + client.chat().completions().createStreaming(paramsBuilder.build())) { + stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc, eventConsumer)); + } + return new StreamResult(acc.content.toString(), acc.buildToolCalls()); + } + + /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */ + private void consumeChunk( + AgentRunContext ctx, + ChatCompletionChunk chunk, + StreamAccumulator acc, + Consumer eventConsumer) { + if (!acc.serverModelLogged) { + LOG.info("LLM response: server-echoed model={}", chunk.model()); + acc.serverModelLogged = true; + } + chunk + .usage() + .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens())); + + for (ChatCompletionChunk.Choice c : chunk.choices()) { + c.delta() + .content() + .ifPresent( + text -> { + acc.content.append(text); + emit(ctx, new ContentDelta(text), eventConsumer); + }); + c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas); + } + } + + /** + * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's + * {@code index} because provider SDKs may deliver a single logical call across multiple chunks + * and only surface the {@code id}/{@code name} on the first one. + */ + private static final class StreamAccumulator { + final StringBuilder content = new StringBuilder(); + final Map toolCallIds = new HashMap<>(); + final Map toolCallNames = new HashMap<>(); + final Map toolCallArgs = new HashMap<>(); + boolean serverModelLogged = false; + + void mergeToolCallDeltas(List deltas) { + for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) { + int idx = (int) tc.index(); + tc.id().ifPresent(id -> toolCallIds.put(idx, id)); + tc.function() + .ifPresent( + fn -> { + fn.name().ifPresent(name -> toolCallNames.put(idx, name)); + fn.arguments() + .ifPresent( + args -> + toolCallArgs + .computeIfAbsent(idx, k -> new StringBuilder()) + .append(args)); + }); + } + } + + /** + * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty + * list) if no tool calls were seen, matching the existing {@link StreamResult} contract. + */ + List buildToolCalls() { + if (toolCallIds.isEmpty()) return null; + List out = new ArrayList<>(toolCallIds.size()); + for (Map.Entry e : toolCallIds.entrySet()) { + int idx = e.getKey(); + String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue(); + String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}"; + out.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(id) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolCallNames.getOrDefault(idx, "")) + .arguments(args) + .build()) + .build())); + } + return out; + } + + /** + * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible + * providers omit it). The id has to be stable within a turn and unique across turns so the + * assistant/tool_result pairing downstream holds. + */ + private static String synthId() { + return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24); + } + } + + private static Map parseToolArgs(String json) { + if (json == null || json.isEmpty()) { + return new HashMap<>(); + } + try { + return JSON.readValue(json, new TypeReference>() {}); + } catch (java.io.IOException e) { + throw new IllegalArgumentException("Malformed tool-call arguments JSON: " + json, e); + } + } + + // --- Middleware dispatch methods --- + // + // Middlewares are internal framework code. If one throws, the agent run fails via the + // top-level catch in run() — we do not wrap individual dispatch calls in try/catch. + + private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) { + emit( + ctx, + new AgentFinish( + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()), + eventConsumer); + } + + private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) { + AgentEvent filtered = event; + for (AgentMiddleware mw : middlewares) { + filtered = mw.onEvent(ctx, filtered); + if (filtered == null) return; + } + consumer.accept(filtered); + } + + private void dispatchAgentStart(AgentRunContext ctx) { + for (AgentMiddleware mw : middlewares) { + mw.onAgentStart(ctx); + } + } + + private void dispatchAgentFinish(AgentRunContext ctx) { + // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup + // gets a chance to run; otherwise we'd leak session state in later middlewares. + for (int i = middlewares.size() - 1; i >= 0; i--) { + try { + middlewares.get(i).onAgentFinish(ctx); + } catch (Exception e) { + LOG.warn("Middleware onAgentFinish error", e); + } + } + } + + private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall( + AgentRunContext ctx, List messages) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages); + if (action != null) return action; + } + return null; + } + + private void dispatchAfterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + for (int i = middlewares.size() - 1; i >= 0; i--) { + middlewares.get(i).afterLlmCall(ctx, response); + } + } + + private AgentMiddleware.ToolCallDenial dispatchBeforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.ToolCallDenial denial = + mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs); + if (denial != null) return denial; + } + return null; + } + + private String dispatchAfterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + String modified = null; + for (int i = middlewares.size() - 1; i >= 0; i--) { + String mwResult = + middlewares + .get(i) + .afterToolCall(ctx, toolName, toolArgs, modified != null ? modified : result); + if (mwResult != null) { + modified = mwResult; + } + } + return modified; + } + + // --- Builder --- + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private OpenAIClient client; + private String modelName; + private ToolRegistry toolRegistry = new ToolRegistry(ToolRegistry.DEFAULT_TIMEOUT_SECONDS); + private final List middlewares = new ArrayList<>(); + private int maxIterations = 20; + private String systemPrompt; + + public Builder client(OpenAIClient client) { + this.client = client; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder toolRegistry(ToolRegistry toolRegistry) { + this.toolRegistry = toolRegistry; + return this; + } + + public Builder addMiddleware(AgentMiddleware middleware) { + this.middlewares.add(middleware); + return this; + } + + public Builder maxIterations(int maxIterations) { + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations); + } + this.maxIterations = maxIterations; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public ReactAgent build() { + if (client == null) throw new IllegalStateException("client is required"); + if (modelName == null) throw new IllegalStateException("modelName is required"); + if (toolRegistry == null) throw new IllegalStateException("toolRegistry is required"); + for (AgentMiddleware mw : middlewares) { + mw.onRegister(toolRegistry); + } + return new ReactAgent( + client, modelName, toolRegistry, middlewares, maxIterations, systemPrompt); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java new file mode 100644 index 00000000000..d4a8a97e97d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Per-session temp-file store for large tool outputs that the input gate has offloaded from + * conversation history. + * + *

Layout: {@code //tool_.txt}, where {@code } is a + * per-engine randomly-named temp directory (see {@link #create()}). UTF-8 text. + * + *

Isolation: every {@code read}/{@code grep} call requires the current session id and + * validates that the caller-supplied path resolves under {@code /}, not merely + * under {@code }. A session cannot read another session's offloaded output even if it somehow + * obtained the absolute path. The per-engine random root additionally isolates different engine + * processes sharing the same host. + * + *

Failures (traversal, missing file, session-id mismatch, IO) are reported as error strings + * rather than thrown — the tool must keep the agent loop alive. + */ +public class ToolOutputStore implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(ToolOutputStore.class); + private static final String ROOT_PREFIX = "kyuubi-data-agent-"; + + private final Path root; + + /** + * Create a store backed by a fresh, per-engine random temp directory under {@code + * java.io.tmpdir}. + */ + public static ToolOutputStore create() { + try { + return new ToolOutputStore(Files.createTempDirectory(ROOT_PREFIX)); + } catch (IOException e) { + throw new IllegalStateException("Failed to create ToolOutputStore temp root", e); + } + } + + private ToolOutputStore(Path root) { + try { + Files.createDirectories(root); + this.root = root.toRealPath(); + } catch (IOException e) { + throw new IllegalStateException("Failed to initialize ToolOutputStore root: " + root, e); + } + } + + public Path getRoot() { + return root; + } + + /** Write {@code content} to {@code //tool_.txt}. */ + public Path write(String sessionId, String toolCallId, String content) throws IOException { + Path dir = root.resolve(safeSegment(sessionId)); + Files.createDirectories(dir); + Path file = dir.resolve("tool_" + safeSegment(toolCallId) + ".txt"); + Files.write(file, content.getBytes(StandardCharsets.UTF_8)); + return file; + } + + /** + * Read a line window. Returns a human-readable block including a 1-based {@code [lines X-Y of Z + * total]} header, or an error string on traversal / IO failure / cross-session access. + */ + public String read(String sessionId, String pathStr, long offset, int limit) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (offset < 0) offset = 0; + if (limit <= 0) limit = 1; + + List taken = new ArrayList<>(limit); + long totalLines = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + if (totalLines >= offset && taken.size() < limit) { + taken.add(line); + } + totalLines++; + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + + long fromLine = offset + 1; // 1-based + long toLineExclusive = Math.min(offset + limit, totalLines); + StringBuilder sb = new StringBuilder(); + sb.append("[lines ") + .append(fromLine) + .append("-") + .append(toLineExclusive) + .append(" of ") + .append(totalLines) + .append(" total]\n"); + for (String line : taken) { + sb.append(line).append('\n'); + } + return sb.toString(); + } + + /** + * Stream-grep the file. Returns at most {@code maxMatches} matches as {@code lineNo:content}, one + * per line; or an error string on traversal / regex / IO failure / cross-session access. + */ + public String grep(String sessionId, String pathStr, String patternStr, int maxMatches) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (patternStr == null || patternStr.isEmpty()) { + return "Error: 'pattern' parameter is required."; + } + if (maxMatches <= 0) maxMatches = 50; + + Pattern pattern; + try { + pattern = Pattern.compile(patternStr); + } catch (PatternSyntaxException e) { + return "Error: invalid regex pattern: " + e.getMessage(); + } + + StringBuilder sb = new StringBuilder(); + int matches = 0; + long lineNo = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + lineNo++; + if (pattern.matcher(line).find()) { + sb.append(lineNo).append(':').append(line).append('\n'); + matches++; + if (matches >= maxMatches) break; + } + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + if (matches == 0) { + return "[no matches for pattern: " + patternStr + "]"; + } + return "[" + matches + " match" + (matches == 1 ? "" : "es") + "]\n" + sb; + } + + /** Recursively delete the session's subtree. Safe to call on missing sessions. */ + public void cleanupSession(String sessionId) { + if (sessionId == null) return; + Path dir = root.resolve(safeSegment(sessionId)); + deleteTree(dir); + } + + /** Delete everything below (and including) the root. Idempotent; safe to call multiple times. */ + @Override + public void close() { + deleteTree(root); + } + + private static void deleteTree(Path dir) { + if (!Files.exists(dir)) return; + try (Stream stream = Files.walk(dir)) { + stream.sorted(Comparator.reverseOrder()).forEach(ToolOutputStore::deleteQuietly); + } catch (IOException e) { + LOG.warn("Failed to clean up dir {}", dir, e); + } + } + + private static void deleteQuietly(Path p) { + try { + Files.deleteIfExists(p); + } catch (IOException e) { + LOG.debug("Failed to delete {}", p, e); + } + } + + /** + * Resolve {@code pathStr} and return it only if (a) it exists as a regular file and (b) the real + * path is under {@code /}. Returns null on any violation — including a null or + * empty session id, since without one we cannot scope the check. + */ + private Path validatePath(String sessionId, String pathStr) { + if (pathStr == null || pathStr.isEmpty()) return null; + if (sessionId == null || sessionId.isEmpty()) return null; + Path sessionRoot = root.resolve(safeSegment(sessionId)); + try { + Path real = Paths.get(pathStr).toRealPath(); + if (!real.startsWith(sessionRoot)) return null; + if (!Files.isRegularFile(real)) return null; + return real; + } catch (IOException | SecurityException e) { + return null; + } + } + + /** Strip anything that could escape a single path segment. */ + private static String safeSegment(String raw) { + if (raw == null || raw.isEmpty()) return "_"; + StringBuilder sb = new StringBuilder(raw.length()); + for (int i = 0; i < raw.length(); i++) { + char c = raw.charAt(i); + if (Character.isLetterOrDigit(c) || c == '-' || c == '_' || c == '.') { + sb.append(c); + } else { + sb.append('_'); + } + } + return sb.toString(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java new file mode 100644 index 00000000000..26eb6024e98 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.event; + +/** + * Emitted by {@code CompactionMiddleware} after it has summarized a prefix of the conversation + * history and replaced it in memory. Purely observational — the LLM call that immediately follows + * uses the already-compacted history, so consumers just see this as a side-channel notice that + * compaction happened. The summary text itself is intentionally not included: it can be large and + * would bloat the event stream; operators who need it can read the middleware log. + */ +public final class Compaction extends AgentEvent { + private final int summarizedCount; + private final int keptCount; + private final long triggerTokens; + private final long observedTokens; + + public Compaction(int summarizedCount, int keptCount, long triggerTokens, long observedTokens) { + super(EventType.COMPACTION); + this.summarizedCount = summarizedCount; + this.keptCount = keptCount; + this.triggerTokens = triggerTokens; + this.observedTokens = observedTokens; + } + + public int summarizedCount() { + return summarizedCount; + } + + public int keptCount() { + return keptCount; + } + + public long triggerTokens() { + return triggerTokens; + } + + public long observedTokens() { + return observedTokens; + } + + @Override + public String toString() { + return "Compaction{summarized=" + + summarizedCount + + ", kept=" + + keptCount + + ", triggerTokens=" + + triggerTokens + + ", observedTokens=" + + observedTokens + + "}"; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java index 937422e2bf5..d58e5de2ee7 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java @@ -21,6 +21,8 @@ * Enumerates the types of events emitted by the ReAct agent loop. Each value maps to a * corresponding {@link AgentEvent} subclass and carries an SSE event name used for wire * serialization. + * + * @see org.apache.kyuubi.engine.dataagent.runtime.ReactAgent */ public enum EventType { @@ -51,6 +53,9 @@ public enum EventType { /** The agent requires user approval before executing a tool. */ APPROVAL_REQUEST("approval_request"), + /** The conversation history was compacted by the compaction middleware. */ + COMPACTION("compaction"), + /** The agent has finished its analysis. */ AGENT_FINISH("agent_finish"); diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java new file mode 100644 index 00000000000..c934bb0882a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; + +/** + * Middleware interface for the Data Agent ReAct loop. Middlewares are executed in onion-model + * order: before_* hooks run first-to-last, after_* hooks run last-to-first. + * + *

All hooks have default no-op implementations. Override only what you need. + */ +public interface AgentMiddleware { + + /** + * Called once when the middleware is wired into the agent. Register companion tools that are part + * of the middleware's contract, or capture a reference to the registry for later use. Dispatched + * by {@code ReactAgent.Builder.build} before the agent accepts any requests. + */ + default void onRegister(ToolRegistry registry) {} + + /** Called when the agent starts processing a user query. Runs first-to-last. */ + default void onAgentStart(AgentRunContext ctx) {} + + /** Called when the agent finishes. Runs last-to-first (cleanup order). */ + default void onAgentFinish(AgentRunContext ctx) {} + + /** + * Called before each LLM invocation. Return non-null to skip or modify the LLM call. Runs + * first-to-last. + * + * @return {@code null} to proceed normally, {@link LlmSkip} to abort, or {@link + * LlmModifyMessages} to replace the message list for this call. + */ + default LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + return null; + } + + /** Called after each LLM invocation. Runs last-to-first. */ + default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {} + + /** Called before each tool execution. Return non-null to deny the call. Runs first-to-last. */ + default ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + return null; + } + + /** + * Called after each tool execution. Runs last-to-first. + * + *

Returns {@code String} (not {@code void}) so that middlewares can intercept and transform + * the tool result before it is fed back to the LLM — e.g. for data masking, output truncation, or + * injecting metadata. Return {@code null} to keep the original result unchanged; return a + * non-null value to replace it. + */ + default String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + return null; + } + + /** + * Called for every event before it is emitted. Return null to suppress the event. Runs + * first-to-last. + */ + default AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + return event; + } + + /** + * Called when a session is closed. Clean up per-session state (scratch files, pending tasks, + * counters). Idempotent. Dispatched by {@code ReactAgent.closeSession}. + */ + default void onSessionClose(String sessionId) {} + + /** + * Called when the engine is stopping. Release global resources and unblock any threads still + * waiting on this middleware. Dispatched by {@code ReactAgent.stop}. + */ + default void onStop() {} + + /** + * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmSkip} to abort the LLM + * call, {@link LlmModifyMessages} to replace the message list for this call. + */ + abstract class LlmCallAction { + private LlmCallAction() {} + } + + /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */ + class LlmSkip extends LlmCallAction { + private final String reason; + + public LlmSkip(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } + + /** + * Returned from {@code beforeLlmCall} to replace the message list for this LLM invocation. The + * agent loop continues normally with the modified messages. + */ + class LlmModifyMessages extends LlmCallAction { + private final List messages; + + public LlmModifyMessages(List messages) { + this.messages = messages; + } + + public List messages() { + return messages; + } + } + + /** Returned from {@code beforeToolCall} to deny a tool call. Non-null means denied. */ + class ToolCallDenial { + private final String reason; + + public ToolCallDenial(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java new file mode 100644 index 00000000000..92d25b47b9d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that enforces human-in-the-loop approval for tool calls based on the {@link + * ApprovalMode} and the tool's {@link ToolRiskLevel}. + * + *

When approval is required, an {@link ApprovalRequest} event is emitted to the client via + * {@link AgentRunContext#emit}, and the agent thread blocks until the client responds via {@link + * #resolve} or the timeout expires. + */ +public class ApprovalMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ApprovalMiddleware.class); + + private static final long DEFAULT_TIMEOUT_SECONDS = 300; // 5 minutes + + private final long timeoutSeconds; + private final ConcurrentHashMap> pending = + new ConcurrentHashMap<>(); + private ToolRegistry toolRegistry; + + public ApprovalMiddleware() { + this(DEFAULT_TIMEOUT_SECONDS); + } + + public ApprovalMiddleware(long timeoutSeconds) { + this.timeoutSeconds = timeoutSeconds; + } + + @Override + public void onRegister(ToolRegistry registry) { + this.toolRegistry = registry; + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName); + + if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) { + return null; + } + + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + pending.put(requestId, future); + + ctx.emit(new ApprovalRequest(requestId, toolCallId, toolName, toolArgs, riskLevel)); + LOG.info("Approval requested for tool '{}' (requestId={})", toolName, requestId); + + try { + boolean approved = future.get(timeoutSeconds, TimeUnit.SECONDS); + if (!approved) { + LOG.info("Tool '{}' denied by user (requestId={})", toolName, requestId); + return new ToolCallDenial("User denied execution of " + toolName); + } + LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId); + return null; + } catch (TimeoutException e) { + // Complete the future so that a late resolve() call is a harmless no-op + // instead of completing a dangling future. + future.completeExceptionally(e); + LOG.warn("Approval timed out for tool '{}' (requestId={})", toolName, requestId); + return new ToolCallDenial("Approval timed out after " + timeoutSeconds + "s for " + toolName); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return new ToolCallDenial("Approval interrupted for " + toolName); + } catch (Exception e) { + LOG.error("Unexpected error waiting for approval", e); + return new ToolCallDenial("Approval error: " + e.getMessage()); + } finally { + pending.remove(requestId); + } + } + + /** + * Resolve a pending approval request. Called by the external approval channel (e.g. a Kyuubi + * operation or REST endpoint). + * + * @param requestId the request ID from the {@link ApprovalRequest} event + * @param approved true to approve, false to deny + * @return true if the request was found and resolved, false if not found (already timed out or + * invalid ID) + */ + public boolean resolve(String requestId, boolean approved) { + CompletableFuture future = pending.get(requestId); + if (future != null) { + return future.complete(approved); + } + LOG.warn("No pending approval found for requestId={}", requestId); + return false; + } + + /** + * Cancel all pending approval requests to unblock any waiting agent threads. Invoked as part of + * engine shutdown via {@code ReactAgent.stop}. + */ + @Override + public void onStop() { + InterruptedException ex = new InterruptedException("Session closed"); + pending.forEachKey( + Long.MAX_VALUE, + key -> { + CompletableFuture future = pending.remove(key); + if (future != null) { + future.completeExceptionally(ex); + } + }); + } + + private static boolean shouldAutoApprove(ApprovalMode mode, ToolRiskLevel riskLevel) { + if (mode == ApprovalMode.AUTO_APPROVE) { + return true; + } + if (mode == ApprovalMode.NORMAL && riskLevel == ToolRiskLevel.SAFE) { + return true; + } + // STRICT: all tools require approval + return false; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java new file mode 100644 index 00000000000..acab7ae8b75 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.client.OpenAIClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.Compaction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that compacts conversation history when the prompt grows large. + * + *

Trigger formula: + * + *

+ *   predicted_this_turn_prompt_tokens
+ *       = last_llm_call_total_tokens            // prompt + completion of the previous call
+ *       + estimate_new_tail(messages)           // chars / 4, for content appended after the
+ *                                               // last assistant message (tool results, new user)
+ * 
+ * + * History (already sent to the LLM) must use real token counts — we read them straight off {@link + * ConversationMemory#getLastTotalTokens()}, which the provider updates after every call. We use + * total (prompt + completion) rather than just prompt because the previous call's completion + * — e.g. an assistant {@code tool_call} message — is already appended to history and will be + * tokenized into the next prompt; the tail estimator starts strictly after the last + * assistant message, so without completion we would miss it. Only content appended since that + * assistant message is estimated, which catches spikes before the next API report arrives. + * + *

Post-compaction persistence: the summary + kept tail replace {@link + * ConversationMemory}'s history via {@link ConversationMemory#replaceHistory}. All subsequent turns + * in this session read the compacted history — we do not re-summarize each turn. Because the next + * LLM call uses the compacted messages, its reported {@code prompt_tokens} will be small, naturally + * preventing retriggering. + * + *

Thread safety / shared instance: Instances of this middleware are shared across all + * sessions inside a provider (see {@code OpenAiProvider} javadoc). All per-session state + * (cumulative and last-call totals) lives on {@link ConversationMemory}, so this middleware itself + * is stateless across sessions and requires no per-session cleanup. + * + *

Tool-call pair invariant: The split never separates an assistant message bearing {@code + * tool_calls} from the {@code tool_result} messages that satisfy those calls — an orphan + * tool_result is rejected by the OpenAI API with HTTP 400. + * + *

Failure handling: summarizer failures propagate out of {@code beforeLlmCall} to the + * agent's top-level catch, which surfaces an {@code AgentError} event. We don't silently skip + * compaction — a broken summarizer is a real problem the operator needs to see. + * + *

Disabling: to effectively turn compaction off, construct with a very large {@code + * triggerPromptTokens} (e.g., {@link Long#MAX_VALUE}). + * + *

TODO: support a separate (cheaper) summarization model distinct from the main agent model. + */ +public class CompactionMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(CompactionMiddleware.class); + + /** Number of recent user turns (and their assistant/tool companions) to preserve verbatim. */ + private static final int KEEP_RECENT_TURNS = 4; + + private static final String COMPACTION_SYSTEM_PROMPT = + "SYSTEM OPERATION — this is an automated context compaction step, NOT a user message.\n" + + "\n" + + "You are summarizing a conversation between a user and a Data Agent (a ReAct agent" + + " that executes SQL and analytics tools against Kyuubi). The user has NOT asked you" + + " to summarize anything. Do not address the user. Do not ask questions. Produce only" + + " the summary in the schema below.\n" + + "\n" + + "Goal: produce a dense, structured summary the agent can resume from without losing" + + " critical context. Preserve concrete details verbatim — file paths, table names," + + " schema definitions, SQL snippets, column names, error messages.\n" + + "\n" + + "Output EXACTLY these 8 sections, in order, as markdown headers:\n" + + "\n" + + "1. ## User Intent\n" + + " The user's original request, restated in full. Preserve the literal phrasing of" + + " the ask. Include follow-up refinements.\n" + + "\n" + + "2. ## Key Concepts\n" + + " Domain terms, data sources, tables, schemas, SQL dialects, and business logic" + + " the agent has been reasoning about.\n" + + "\n" + + "3. ## Files and Code\n" + + " File paths, query text, DDL, or code artifacts referenced. Include verbatim SQL" + + " snippets that produced meaningful results.\n" + + "\n" + + "4. ## Errors and Recoveries\n" + + " Errors encountered (SQL syntax, permission, timeout, tool failures), what was" + + " tried, and what resolved them. Preserve error messages verbatim.\n" + + "\n" + + "5. ## Pending Work\n" + + " Tasks the agent identified but has not completed yet.\n" + + "\n" + + "6. ## Current State\n" + + " Where the agent is right now — what question is open, what data has been" + + " retrieved, what hypothesis is being tested.\n" + + "\n" + + "7. ## Next Step\n" + + " The immediate next action the agent should take when resuming.\n" + + "\n" + + "8. ## Tool Usage Summary\n" + + " Which tools were called, how many times, and notable results.\n" + + "\n" + + "CRITICAL:\n" + + "- DO NOT ask the user about this summary.\n" + + "- DO NOT mention that compaction occurred in any future assistant response.\n" + + "- DO NOT invent details not present in the conversation.\n" + + "- DO NOT output anything outside the 8 sections.\n"; + + private final OpenAIClient client; + private final String summarizerModel; + private final long triggerPromptTokens; + + public CompactionMiddleware( + OpenAIClient client, String summarizerModel, long triggerPromptTokens) { + this.client = client; + this.summarizerModel = summarizerModel; + this.triggerPromptTokens = triggerPromptTokens; + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + ConversationMemory mem = ctx.getMemory(); + // 1) Real token count of the previous LLM call (prompt + completion, i.e. everything through + // the last assistant message, which is now part of history). 0 on the first call. + long lastTotal = mem.getLastTotalTokens(); + // 2) Estimated tokens appended to the tail after the last assistant (tool_results, new user). + long newTailEstimate = estimateTailAfterLastAssistant(messages); + + if (lastTotal + newTailEstimate < triggerPromptTokens) { + return null; + } + + List history = mem.getHistory(); + + // 3) Split history into old (to summarize) and kept (recent tail), never orphaning a + // tool_result. + Split split = computeSplit(history, KEEP_RECENT_TURNS); + if (split.old.isEmpty()) { + return null; + } + + String summary = summarize(mem.getSystemPrompt(), split.old); + + // 4) Build the compacted history and persist into ConversationMemory. + List compacted = new ArrayList<>(1 + split.kept.size()); + compacted.add(wrapSummaryAsUserMessage(summary)); + compacted.addAll(split.kept); + mem.replaceHistory(compacted); + + LOG.info( + "Compacted {} old msgs into 1 summary; kept {} tail msgs (lastTotal={}, newTail~={})", + split.old.size(), + split.kept.size(), + lastTotal, + newTailEstimate); + + ctx.emit( + new Compaction( + split.old.size(), split.kept.size(), triggerPromptTokens, lastTotal + newTailEstimate)); + + return new LlmModifyMessages(mem.buildLlmMessages()); + } + + /** Call the LLM to produce a summary of {@code oldMessages}. Failures propagate. */ + private String summarize(String agentSystemPrompt, List oldMessages) { + String systemPrompt = COMPACTION_SYSTEM_PROMPT; + if (agentSystemPrompt != null && !agentSystemPrompt.isEmpty()) { + systemPrompt = + systemPrompt + + "\n---\nFor context, the agent's own system prompt is:\n" + + agentSystemPrompt; + } + + String rendered = renderAsText(oldMessages); + + ChatCompletionCreateParams params = + ChatCompletionCreateParams.builder() + .model(summarizerModel) + .temperature(0.0) + .addMessage( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())) + .addMessage( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(rendered).build())) + .build(); + + ChatCompletion response = client.chat().completions().create(params); + return response.choices().get(0).message().content().get(); + } + + // ----- helpers ----- + + /** Sum of content characters in messages after the last assistant, using ~4 chars per token. */ + static long estimateTailAfterLastAssistant(List messages) { + int lastAssistantIdx = -1; + for (int i = messages.size() - 1; i >= 0; i--) { + if (messages.get(i).isAssistant()) { + lastAssistantIdx = i; + break; + } + } + long totalChars = 0; + for (int i = lastAssistantIdx + 1; i < messages.size(); i++) { + totalChars += contentCharCount(messages.get(i)); + } + return totalChars / 4; + } + + private static long contentCharCount(ChatCompletionMessageParam msg) { + if (msg.isUser()) { + return msg.asUser().content().text().map(String::length).orElse(0); + } + if (msg.isTool()) { + return msg.asTool().content().text().map(String::length).orElse(0); + } + if (msg.isAssistant()) { + return msg.asAssistant().content().flatMap(c -> c.text()).map(String::length).orElse(0); + } + if (msg.isSystem()) { + return msg.asSystem().content().text().map(String::length).orElse(0); + } + return 0; + } + + /** + * Render a list of messages as plain text for the summarizer's user turn. Tool calls and tool + * results are rendered as tagged text so the summarizer LLM doesn't try to continue them as live + * agent state. + */ + static String renderAsText(List messages) { + StringBuilder sb = new StringBuilder(); + for (ChatCompletionMessageParam msg : messages) { + if (sb.length() > 0) sb.append("\n\n"); + if (msg.isUser()) { + sb.append("USER: ").append(extractUserContent(msg)); + } else if (msg.isAssistant()) { + sb.append("ASSISTANT: ").append(extractAssistantContent(msg)); + msg.asAssistant() + .toolCalls() + .ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) { + sb.append("\n[tool_call: ") + .append(tc.asFunction().function().name()) + .append("(") + .append(tc.asFunction().function().arguments()) + .append(") id=") + .append(tc.asFunction().id()) + .append("]"); + } + } + }); + } else if (msg.isTool()) { + sb.append("[tool_result id=") + .append(msg.asTool().toolCallId()) + .append("]: ") + .append(extractToolContent(msg)); + } else if (msg.isSystem()) { + // system prompt should not appear in oldMessages, but render defensively + sb.append("SYSTEM: ").append(extractSystemContent(msg)); + } + } + return sb.toString(); + } + + private static String extractUserContent(ChatCompletionMessageParam msg) { + return msg.asUser().content().text().orElse("[non-text content]"); + } + + private static String extractAssistantContent(ChatCompletionMessageParam msg) { + return msg.asAssistant().content().map(c -> c.text().orElse("[non-text content]")).orElse(""); + } + + private static String extractToolContent(ChatCompletionMessageParam msg) { + return msg.asTool().content().text().orElse("[non-text content]"); + } + + private static String extractSystemContent(ChatCompletionMessageParam msg) { + return msg.asSystem().content().text().orElse("[non-text content]"); + } + + /** Result of splitting the history into a summarizable prefix and a kept tail. */ + static final class Split { + + final List old; + final List kept; + + Split(List old, List kept) { + this.old = old; + this.kept = kept; + } + } + + /** + * Split the history at a boundary that preserves the last {@code keepRecentTurns} user messages, + * with adjustments so that no assistant-tool_use is separated from its tool_results. + */ + static Split computeSplit(List history, int keepRecentTurns) { + if (history.size() <= 2) { + return new Split(new ArrayList<>(), new ArrayList<>(history)); + } + + // Walk from the tail, count user boundaries. If the history does not contain enough user + // messages to satisfy keepRecentTurns, keep everything (splitIdx = 0); the empty-old check + // in beforeLlmCall will then skip this turn gracefully. + int userBoundariesFound = 0; + int splitIdx = 0; + for (int i = history.size() - 1; i >= 0; i--) { + if (history.get(i).isUser()) { + userBoundariesFound++; + if (userBoundariesFound == keepRecentTurns) { + splitIdx = i; + break; + } + } + } + + // Protect tool-call / tool-result pairing: never split between an assistant that issued + // tool_calls and the tool_results that satisfy them. + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (prev.isTool()) { + splitIdx--; + continue; + } + if (prev.isAssistant()) { + boolean hasToolCalls = prev.asAssistant().toolCalls().map(List::size).orElse(0) > 0; + if (hasToolCalls) { + splitIdx--; + continue; + } + } + break; + } + + // Also guard against the edge case: if kept contains a tool_result whose tool_call id is + // defined only in old, pull that assistant (and its siblings) into kept too. + Set keptCallIds = collectToolCallIds(history.subList(splitIdx, history.size())); + if (!keptCallIds.isEmpty()) { + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (!prev.isAssistant()) break; + List calls = prev.asAssistant().toolCalls().orElse(null); + if (calls == null || calls.isEmpty()) break; + boolean satisfiesKept = false; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && keptCallIds.contains(tc.asFunction().id())) { + satisfiesKept = true; + break; + } + } + if (!satisfiesKept) break; + splitIdx--; + } + } + + List oldPart = new ArrayList<>(history.subList(0, splitIdx)); + List keptPart = + new ArrayList<>(history.subList(splitIdx, history.size())); + return new Split(oldPart, keptPart); + } + + private static Set collectToolCallIds(List slice) { + Set ids = new HashSet<>(); + for (ChatCompletionMessageParam m : slice) { + if (m.isTool()) { + ids.add(m.asTool().toolCallId()); + } + } + return ids; + } + + private static ChatCompletionMessageParam wrapSummaryAsUserMessage(String summary) { + String body = "\n" + summary + "\n"; + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(body).build()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java new file mode 100644 index 00000000000..e0a5c2364eb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +/** + * Logging middleware that prints agent lifecycle events for debugging and observability. + * + *

Picks up {@code operationId} and {@code sessionId} from SLF4J MDC (set by ExecuteStatement) to + * tag every log line with the Kyuubi operation/session context. + * + *

Log structure: + * + *

+ *   [op:abcd1234] START user_input="..."
+ *   [op:abcd1234] Step 1
+ *   [op:abcd1234] LLM call: step=1, messages=3
+ *   [op:abcd1234] LLM response: step=1, content="...(truncated)", tool_calls=1
+ *   [op:abcd1234] Tool call: sql_query {sql=SELECT ...}
+ *   [op:abcd1234] Tool result: sql_query -> "| col1 | col2 |...(truncated)"
+ *   [op:abcd1234] FINISH steps=2, tokens=1234
+ * 
+ */ +public class LoggingMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger("DataAgent"); + + private static final int MAX_PREVIEW_LENGTH = 500; + + private static String prefix() { + String sessionId = MDC.get("sessionId"); + String opId = MDC.get("operationId"); + StringBuilder sb = new StringBuilder(); + if (sessionId != null) { + sb.append("[s:").append(shortId(sessionId)).append("]"); + } + if (opId != null) { + sb.append("[op:").append(shortId(opId)).append("]"); + } + if (sb.length() > 0) { + sb.append(" "); + } + return sb.toString(); + } + + /** + * Take the first segment of a UUID (before the first dash). e.g. "327d8c5b-91ef-..." → "327d8c5b" + */ + private static String shortId(String id) { + int dash = id.indexOf('-'); + return dash > 0 ? id.substring(0, dash) : id; + } + + @Override + public void onAgentStart(AgentRunContext ctx) { + LOG.debug("{}START user_input=\"{}\"", prefix(), truncate(ctx.getMemory().getLastUserInput())); + } + + @Override + public void onAgentFinish(AgentRunContext ctx) { + LOG.info( + "{}FINISH steps={}, prompt_tokens={}, completion_tokens={}, total_tokens={}", + prefix(), + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size()); + return null; + } + + @Override + public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + String content = response.content().map(Object::toString).orElse(""); + int toolCallCount = response.toolCalls().map(List::size).orElse(0); + LOG.info( + "{}LLM response: step={}, content=\"{}\", tool_calls={}, " + + "usage(cumulative): prompt={}, completion={}, total={}", + prefix(), + ctx.getIteration(), + truncate(content), + toolCallCount, + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName); + LOG.debug("{}Tool args: {}", prefix(), toolArgs); + return null; + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result)); + return null; + } + + @Override + public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + switch (event.eventType()) { + case STEP_START: + LOG.info("{}Step {}", prefix(), ((StepStart) event).stepNumber()); + break; + case ERROR: + LOG.error("{}ERROR: {}", prefix(), ((AgentError) event).message()); + break; + case TOOL_RESULT: + ToolResult tr = (ToolResult) event; + if (tr.isError()) { + LOG.warn("{}Tool error: {} -> \"{}\"", prefix(), tr.toolName(), truncate(tr.output())); + } + break; + default: + break; + } + return event; + } + + private static String truncate(String s) { + if (s == null) return ""; + return s.length() <= MAX_PREVIEW_LENGTH ? s : s.substring(0, MAX_PREVIEW_LENGTH) + "..."; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java new file mode 100644 index 00000000000..87aad9f3255 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Input-gate middleware that offloads oversized tool outputs to per-session temp files and replaces + * the in-memory tool result with a small head + tail preview plus a retrieval hint. The cheapest + * and highest-ROI defense against context rot. + * + *

Trigger: {@code result.lines > MAX_LINES} OR {@code result.bytes > MAX_BYTES}, first to + * trip wins. Thresholds are hardcoded — the ReAct loop can compensate for suboptimal defaults by + * calling the retrieval tools more aggressively. + * + *

Exempt tools: {@link ReadToolOutputTool} and {@link GrepToolOutputTool} never go + * through the gate — the agent would otherwise recursively re-offload its own retrieval output. + * + *

Session lifecycle: a monotonic counter per session is used to name temp files. {@link + * #onSessionClose(String)} wipes the counter and the per-session temp dir; call it from the + * provider's {@code close(sessionId)} hook (not from {@link #onAgentFinish}, which fires on every + * turn and would invalidate paths the LLM still needs to reference). + */ +public class ToolResultOffloadMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ToolResultOffloadMiddleware.class); + + static final int MAX_LINES = 500; + static final int MAX_BYTES = 50 * 1024; + static final int PREVIEW_HEAD_LINES = 20; + static final int PREVIEW_TAIL_LINES = 20; + + private static final Set EXEMPT_TOOLS = + new HashSet<>(Arrays.asList(ReadToolOutputTool.NAME, GrepToolOutputTool.NAME)); + + private final ToolOutputStore store = ToolOutputStore.create(); + private final ConcurrentHashMap counters = new ConcurrentHashMap<>(); + + /** + * Register the companion retrieval tools so the LLM can reach back into offloaded files. Paired + * with the preview hint emitted from {@link #afterToolCall}; skipping this registration would + * leave the LLM dangling on file paths it can never read. + */ + @Override + public void onRegister(ToolRegistry registry) { + registry.register(new ReadToolOutputTool(store)); + registry.register(new GrepToolOutputTool(store)); + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + if (result.isEmpty()) return null; + if (EXEMPT_TOOLS.contains(toolName)) return null; + + int bytes = result.getBytes(StandardCharsets.UTF_8).length; + int lines = countLines(result); + if (lines <= MAX_LINES && bytes <= MAX_BYTES) { + return null; + } + + // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload. + // In production the provider always threads it through, so treat null as "skip offload". + String sessionId = ctx.getSessionId(); + if (sessionId == null) return null; + + long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet(); + String toolCallId = toolName + "_" + n; + + Path file; + try { + file = store.write(sessionId, toolCallId, result); + } catch (IOException e) { + LOG.warn( + "Tool output offload failed for tool={} session={}; passing through full output", + toolName, + sessionId, + e); + return null; + } + + LOG.info( + "Offloaded tool={} session={} ({} lines / {} bytes) -> {}", + toolName, + sessionId, + lines, + bytes, + file.getFileName()); + return buildPreview(result, lines, bytes, file); + } + + /** Clean up counter and temp dir for a closed session. Idempotent. */ + @Override + public void onSessionClose(String sessionId) { + if (sessionId == null) return; + counters.remove(sessionId); + store.cleanupSession(sessionId); + } + + /** Engine-wide shutdown: drop the temp root. */ + @Override + public void onStop() { + store.close(); + } + + static int countLines(String s) { + if (s.isEmpty()) return 0; + int count = 1; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) == '\n') count++; + } + // Trailing newline means the last "line" is empty, but still counted. Good enough for + // gating decisions — we're not trying to match `wc -l` exactly. + return count; + } + + static String buildPreview(String full, int lines, int bytes, Path file) { + String[] split = full.split("\n", -1); + int headEnd = Math.min(PREVIEW_HEAD_LINES, split.length); + int tailStart = Math.max(headEnd, split.length - PREVIEW_TAIL_LINES); + + StringBuilder sb = new StringBuilder(); + sb.append("[Tool output truncated: ") + .append(lines) + .append(" lines, ") + .append(humanBytes(bytes)) + .append("]\n") + .append("Saved to: ") + .append(file.toString()) + .append("\n\n--- First ") + .append(headEnd) + .append(" lines ---\n"); + for (int i = 0; i < headEnd; i++) { + sb.append(split[i]).append('\n'); + } + if (tailStart < split.length) { + int tailCount = split.length - tailStart; + sb.append("--- Last ").append(tailCount).append(" lines ---\n"); + for (int i = tailStart; i < split.length; i++) { + sb.append(split[i]).append('\n'); + } + } + sb.append("\nUse ") + .append(ReadToolOutputTool.NAME) + .append("(path, offset, limit) to read windows, or ") + .append(GrepToolOutputTool.NAME) + .append("(path, pattern, max_matches) to search."); + return sb.toString(); + } + + private static String humanBytes(long bytes) { + if (bytes < 1024) return bytes + " B"; + if (bytes < 1024 * 1024) return String.format("%.1f KB", bytes / 1024.0); + return String.format("%.1f MB", bytes / (1024.0 * 1024.0)); + } + + /** Visible for testing. */ + int trackedSessions() { + return counters.size(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java index 297a7c8d74a..b19f6787c87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java @@ -50,7 +50,10 @@ default ToolRiskLevel riskLevel() { * Execute the tool with the given deserialized arguments. * * @param args the deserialized arguments from the LLM's tool call + * @param ctx per-invocation context (session id, etc.); never null — use {@link + * ToolContext#EMPTY} for calls without a session. Tools that are session-agnostic may ignore + * it. * @return the result string to feed back to the LLM */ - String execute(T args); + String execute(T args, ToolContext ctx); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java new file mode 100644 index 00000000000..2625f3bab59 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +/** + * Per-invocation context handed to {@link AgentTool#execute(Object, ToolContext)}. Today it carries + * just the session id so session-scoped tools (e.g. the offloaded tool-output retrievers) can + * restrict their filesystem view; extend here when a tool needs user/approval/etc. + */ +public final class ToolContext { + + /** Sentinel for call sites that have no session to attribute — tests, direct CLI use. */ + public static final ToolContext EMPTY = new ToolContext(null); + + private final String sessionId; + + public ToolContext(String sessionId) { + this.sessionId = sessionId; + } + + /** Upstream session id, or {@code null} when invoked outside a session. */ + public String sessionId() { + return sessionId; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java index a403c66b58d..3a11bab567c 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java @@ -26,15 +26,16 @@ import com.openai.models.chat.completions.ChatCompletionTool; import java.util.LinkedHashMap; import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,15 +69,20 @@ public class ToolRegistry implements AutoCloseable { */ private static final int MAX_POOL_SIZE = 8; + /** Default wall-clock cap for a single tool call, used when no explicit value is configured. */ + public static final long DEFAULT_TIMEOUT_SECONDS = 300; + private final Map> tools = new LinkedHashMap<>(); private volatile Map cachedSpecs; private final long toolCallTimeoutSeconds; private final ExecutorService executor; + private final ScheduledExecutorService timeoutScheduler; /** - * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} call, sourced from - * {@code kyuubi.engine.data.agent.tool.call.timeout}. When the timeout fires, the thread is - * interrupted and a descriptive error is returned to the LLM. + * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} / {@link + * #submitTool} call, sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. When + * the timeout fires, the worker thread is interrupted and a descriptive error is returned to + * the LLM. */ public ToolRegistry(long toolCallTimeoutSeconds) { this.toolCallTimeoutSeconds = toolCallTimeoutSeconds; @@ -93,12 +99,20 @@ public ToolRegistry(long toolCallTimeoutSeconds) { t.setDaemon(true); return t; }); + this.timeoutScheduler = + Executors.newSingleThreadScheduledExecutor( + r -> { + Thread t = new Thread(r, "tool-call-timeout"); + t.setDaemon(true); + return t; + }); } - /** Shut down the worker pool. Idempotent. */ + /** Shut down the worker pool and the timeout scheduler. Idempotent. */ @Override public void close() { executor.shutdownNow(); + timeoutScheduler.shutdownNow(); } /** Register a tool. Keyed by {@link AgentTool#name()}. */ @@ -138,61 +152,100 @@ private synchronized Map ensureSpecs() { } /** - * Execute a tool call: deserialize the JSON args, then delegate to the tool, with a wall-clock - * timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. If the tool does not - * finish within the timeout, the worker thread is interrupted and a descriptive error is returned - * to the LLM so it can react (e.g. simplify the query, retry with LIMIT). + * Synchronous entry point. Blocks until the tool finishes, times out, or the registry rejects the + * submission. Errors are surfaced as strings (never as exceptions) so the LLM can observe and + * react to them. * * @param toolName the function name from the LLM response * @param argsJson the raw JSON arguments string * @return the result string, or an error message */ - @SuppressWarnings("unchecked") public String executeTool(String toolName, String argsJson) { - AgentTool tool; + return submitTool(toolName, argsJson, ToolContext.EMPTY).join(); + } + + /** Synchronous entry point with an explicit {@link ToolContext}. */ + public String executeTool(String toolName, String argsJson, ToolContext ctx) { + return submitTool(toolName, argsJson, ctx).join(); + } + + public CompletableFuture submitTool(String toolName, String argsJson) { + return submitTool(toolName, argsJson, ToolContext.EMPTY); + } + + /** + * Asynchronous entry point. Deserialize args, run the tool on the worker pool, and apply a + * wall-clock timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. The + * returned future is guaranteed to complete normally — timeouts, pool saturation, unknown tool, + * and execution failures are all translated into error strings. Callers can therefore use {@code + * .join()} / {@code .get()} without handling {@link java.util.concurrent.TimeoutException} or + * {@link java.util.concurrent.ExecutionException}. + * + * @param toolName the function name from the LLM response + * @param argsJson the raw JSON arguments string + */ + @SuppressWarnings("unchecked") + public CompletableFuture submitTool(String toolName, String argsJson, ToolContext ctx) { + AgentTool tool; synchronized (this) { - tool = tools.get(toolName); + tool = (AgentTool) tools.get(toolName); } if (tool == null) { - return "Error: unknown tool '" + toolName + "'"; + return CompletableFuture.completedFuture("Error: unknown tool '" + toolName + "'"); } - return executeWithTimeout((AgentTool) tool, argsJson); - } + ToolContext toolCtx = ctx != null ? ctx : ToolContext.EMPTY; - private String executeWithTimeout(AgentTool tool, String argsJson) { - Callable task = - () -> { - T args = JSON.readValue(argsJson, tool.argsType()); - return tool.execute(args); - }; - Future future; + CompletableFuture result = new CompletableFuture<>(); + Future submitted; try { - future = executor.submit(task); + submitted = + executor.submit( + () -> { + try { + Object args = JSON.readValue(argsJson, tool.argsType()); + String out = tool.execute(args, toolCtx); + // When the timeout handler interrupts us, the tool may still unwind cleanly and + // produce a stale return value — don't race the scheduler's timeout message with + // it. Let the timeout path be the single authority for the final result. + if (!Thread.currentThread().isInterrupted()) { + result.complete(out); + } + } catch (Exception e) { + result.complete("Error executing " + toolName + ": " + e.getMessage()); + } + }); } catch (RejectedExecutionException e) { - LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", tool.name(), MAX_POOL_SIZE); - return "Error: tool call '" - + tool.name() - + "' rejected — server is handling too many concurrent tool calls. " - + "Retry in a moment."; - } - try { - return future.get(toolCallTimeoutSeconds, TimeUnit.SECONDS); - } catch (TimeoutException e) { - future.cancel(true); - LOG.warn("Tool call '{}' timed out after {} seconds", tool.name(), toolCallTimeoutSeconds); - return "Error: tool call '" - + tool.name() - + "' timed out after " - + toolCallTimeoutSeconds - + " seconds. " - + "Try simplifying the query or adding filters to reduce execution time."; - } catch (ExecutionException e) { - Throwable cause = e.getCause() != null ? e.getCause() : e; - return "Error executing " + tool.name() + ": " + cause.getMessage(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return "Error: tool call '" + tool.name() + "' was interrupted."; + LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", toolName, MAX_POOL_SIZE); + return CompletableFuture.completedFuture( + "Error: tool call '" + + toolName + + "' rejected — server is handling too many concurrent tool calls. " + + "Retry in a moment."); } + + ScheduledFuture timer = + timeoutScheduler.schedule( + () -> { + if (!result.isDone()) { + // cancel(true) interrupts the worker thread directly — the inner task's + // catch-all will see the interrupt and call result.complete(...), but the + // timeout message below wins because complete() is idempotent on first-winner. + submitted.cancel(true); + LOG.warn( + "Tool call '{}' timed out after {} seconds", toolName, toolCallTimeoutSeconds); + result.complete( + "Error: tool call '" + + toolName + + "' timed out after " + + toolCallTimeoutSeconds + + " seconds. " + + "Try simplifying the query or adding filters to reduce execution time."); + } + }, + toolCallTimeoutSeconds, + TimeUnit.SECONDS); + result.whenComplete((r, e) -> timer.cancel(false)); + return result; } private static ChatCompletionTool buildChatCompletionTool(AgentTool tool) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java new file mode 100644 index 00000000000..b05a199090a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link GrepToolOutputTool}. */ +public class GrepToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Java regex pattern to search for. Matches are returned as ':'.") + public String pattern; + + @JsonProperty("max_matches") + @JsonPropertyDescription("Maximum number of matching lines to return. Defaults to 50.") + public Integer maxMatches; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java new file mode 100644 index 00000000000..f94fb21b315 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Regex-search a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. + */ +public class GrepToolOutputTool implements AgentTool { + + public static final String NAME = "grep_tool_output"; + private static final int DEFAULT_MAX_MATCHES = 50; + + private final ToolOutputStore store; + + public GrepToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Regex-search a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Cheaper than read_tool_output when you know what you're looking for. " + + "Returns matching lines as ':'."; + } + + @Override + public Class argsType() { + return GrepToolOutputArgs.class; + } + + @Override + public String execute(GrepToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: grep_tool_output requires a session context."; + } + int max = args.maxMatches != null ? args.maxMatches : DEFAULT_MAX_MATCHES; + return store.grep(ctx.sessionId(), args.path, args.pattern, max); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java new file mode 100644 index 00000000000..458fbfa4f66 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link ReadToolOutputTool}. */ +public class ReadToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonPropertyDescription("0-based line offset into the file. Defaults to 0.") + public Integer offset; + + @JsonPropertyDescription( + "Number of lines to return starting at 'offset'. Defaults to 200; capped at 500.") + public Integer limit; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java new file mode 100644 index 00000000000..93e837998f3 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Read a line window from a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. + */ +public class ReadToolOutputTool implements AgentTool { + + public static final String NAME = "read_tool_output"; + private static final int DEFAULT_LIMIT = 200; + private static final int MAX_LIMIT = 500; + + private final ToolOutputStore store; + + public ReadToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Read a line window from a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Returns '[lines X-Y of Z total]' header followed by the requested window. " + + "Use when a prior tool's output was truncated and you need to inspect more of it."; + } + + @Override + public Class argsType() { + return ReadToolOutputArgs.class; + } + + @Override + public String execute(ReadToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: read_tool_output requires a session context."; + } + int offset = args.offset != null ? args.offset : 0; + int limit = args.limit != null ? args.limit : DEFAULT_LIMIT; + if (limit > MAX_LIMIT) limit = MAX_LIMIT; + return store.read(ctx.sessionId(), args.path, offset, limit); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java index 06b12f2be72..88838ca477a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds); } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java index 0c57cc049ed..9136b0aa903 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { String sql = args.sql; if (sql == null || sql.trim().isEmpty()) { return "Error: 'sql' parameter is required."; diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java index d2cac9ae518..52b83182add 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java @@ -53,6 +53,7 @@ final class SqlReadOnlyChecker { *
  • {@code USE} — switch session catalog/database; not data-mutating *
  • {@code LIST} — Spark {@code LIST FILE} / {@code LIST JAR} inspection *
  • {@code HELP} — some engines expose interactive help + *
  • {@code PRAGMA} — SQLite schema/metadata inspection * */ private static final Set READ_ONLY_KEYWORDS = @@ -70,7 +71,8 @@ final class SqlReadOnlyChecker { "EXPLAIN", "USE", "LIST", - "HELP"))); + "HELP", + "PRAGMA"))); private SqlReadOnlyChecker() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java new file mode 100644 index 00000000000..f4366094670 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.util; + +import org.apache.kyuubi.config.ConfigEntry; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.OptionalConfigEntry; + +/** Small helpers for reading typed values out of {@link KyuubiConf}. */ +public final class ConfUtils { + + private ConfUtils() {} + + /** Return the string value, or throw if the entry is not set. */ + public static String requireString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + if (opt.isEmpty()) { + throw new IllegalArgumentException(key.key() + " is required"); + } + return opt.get(); + } + + /** Return the string value, or {@code null} if the entry is not set. */ + public static String optionalString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + return opt.isDefined() ? opt.get() : null; + } + + /** Return the value for a raw key, or {@code null} if not set. */ + public static String optionalString(KyuubiConf conf, String key) { + scala.Option opt = conf.getOption(key); + return opt.isDefined() ? opt.get() : null; + } + + public static int intConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).intValue(); + } + + public static long longConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).longValue(); + } + + /** Read a millisecond-valued entry and return it as whole seconds. */ + public static long millisAsSeconds(KyuubiConf conf, ConfigEntry key) { + return longConf(conf, key) / 1000L; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala index 04d0defa2dc..3d902677a0a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala +++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala @@ -25,7 +25,7 @@ import org.slf4j.MDC import org.apache.kyuubi.{KyuubiSQLException, Logging} import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.engine.dataagent.provider.{DataAgentProvider, ProviderRunRequest} -import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} +import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, Compaction, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} import org.apache.kyuubi.operation.OperationState import org.apache.kyuubi.operation.log.OperationLog import org.apache.kyuubi.session.Session @@ -117,7 +117,7 @@ class ExecuteStatement( n.put("type", sseType) n.put("id", toolCall.toolCallId()) n.put("name", toolCall.toolName()) - n.set("args", JSON.valueToTree(toolCall.toolArgs())) + n.set[ObjectNode]("args", JSON.valueToTree(toolCall.toolArgs())) })) case EventType.TOOL_RESULT => val toolResult = event.asInstanceOf[ToolResult] @@ -148,6 +148,15 @@ class ExecuteStatement( n.set("args", JSON.valueToTree(req.toolArgs())) n.put("riskLevel", req.riskLevel().name()) })) + case EventType.COMPACTION => + val c = event.asInstanceOf[Compaction] + incrementalIter.append(Array(toJson { n => + n.put("type", sseType) + n.put("summarized", c.summarizedCount()) + n.put("kept", c.keptCount()) + n.put("triggerTokens", c.triggerTokens()) + n.put("observedTokens", c.observedTokens()) + })) case EventType.AGENT_FINISH => val finish = event.asInstanceOf[AgentFinish] incrementalIter.append(Array(toJson { n => diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java index e728ff871e7..c43942a8f35 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.*; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; import org.junit.Test; public class JdbcDialectTest { diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java index 4e713b9cca3..cc45ebdc7e7 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -20,8 +20,9 @@ import static org.junit.Assert.*; import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; -import org.apache.kyuubi.engine.dataagent.datasource.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs; import org.junit.BeforeClass; @@ -66,7 +67,7 @@ public void testBacktickQuotingWithReservedWord() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1"; - String result = selectTool.execute(args); + String result = selectTool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("value1")); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java new file mode 100644 index 00000000000..2756b3e2087 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.provider.mock; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; + +/** + * A mock LLM provider for testing the full tool-call pipeline without a real LLM. Simulates the + * ReAct loop: extracts SQL from the user question, executes it via SqlQueryTool, and returns the + * result as a formatted answer. + * + *

    Recognizes two patterns: + * + *

      + *
    • Questions containing SQL keywords (SELECT, SHOW, DESCRIBE) — extracts and executes the SQL + *
    • All other questions — returns a canned response without tool calls + *
    + */ +public class MockLlmProvider implements DataAgentProvider { + + private static final Pattern SQL_PATTERN = + Pattern.compile( + "(SELECT\\b.+|SHOW\\b.+|DESCRIBE\\b.+)", Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + /** + * Simple natural-language-to-SQL mappings so tests can use human-readable questions instead of + * raw SQL. Checked before the regex pattern — if a question matches a key (case-insensitive + * prefix), the mapped SQL is executed. + */ + private static final Map NL_TO_SQL = new java.util.LinkedHashMap<>(); + + static { + NL_TO_SQL.put( + "list all employee names and departments", + "SELECT name, department FROM employees ORDER BY id"); + NL_TO_SQL.put( + "how many employees in each department", + "SELECT department, COUNT(*) as cnt FROM employees GROUP BY department"); + NL_TO_SQL.put("count the total number of employees", "SELECT COUNT(*) FROM employees"); + } + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + + public MockLlmProvider(KyuubiConf conf) { + String jdbcUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + this.dataSource = DataSourceFactory.create(jdbcUrl); + this.toolRegistry = new ToolRegistry(30); + this.toolRegistry.register(new RunSelectQueryTool(dataSource, 0)); + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new Object()); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + String question = request.getQuestion(); + onEvent.accept(new AgentStart()); + + // Trigger an error for testing the error path in ExecuteStatement + if (question.trim().equalsIgnoreCase("__error__")) { + throw new RuntimeException("MockLlmProvider simulated failure"); + } + + // First check natural-language mappings, then fall back to SQL pattern extraction + String sql = resolveToSql(question); + if (sql != null) { + runWithToolCall(sql, onEvent); + } else { + runWithoutToolCall(question, onEvent); + } + } + + private void runWithToolCall(String sql, Consumer onEvent) { + // Step 1: LLM "decides" to call sql_query tool + onEvent.accept(new StepStart(1)); + String toolCallId = "mock_call_" + System.nanoTime(); + Map toolArgs = new HashMap<>(); + toolArgs.put("sql", sql); + onEvent.accept(new ToolCall(toolCallId, "run_select_query", toolArgs)); + + // Execute the tool + String toolOutput = + toolRegistry.executeTool("run_select_query", "{\"sql\":\"" + escapeJson(sql) + "\"}"); + onEvent.accept(new ToolResult(toolCallId, "run_select_query", toolOutput, false)); + onEvent.accept(new StepEnd(1)); + + // Step 2: LLM "summarizes" the result + onEvent.accept(new StepStart(2)); + String answer = "Based on the query result:\n\n" + toolOutput; + for (String token : answer.split("(?<=\\n)")) { + onEvent.accept(new ContentDelta(token)); + } + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(2)); + onEvent.accept(new AgentFinish(2, 100, 50, 150)); + } + + private void runWithoutToolCall(String question, Consumer onEvent) { + onEvent.accept(new StepStart(1)); + String answer = "[MockLLM] No SQL detected in: " + question; + onEvent.accept(new ContentDelta(answer)); + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(1)); + onEvent.accept(new AgentFinish(1, 50, 20, 70)); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + } + + @Override + public void stop() { + if (dataSource instanceof com.zaxxer.hikari.HikariDataSource) { + ((com.zaxxer.hikari.HikariDataSource) dataSource).close(); + } + } + + /** + * Resolve a user question to SQL. Checks NL_TO_SQL mappings first, then falls back to regex + * extraction of raw SQL from the input. + */ + private static String resolveToSql(String question) { + String lower = question.toLowerCase().trim(); + for (Map.Entry entry : NL_TO_SQL.entrySet()) { + if (lower.startsWith(entry.getKey())) { + return entry.getValue(); + } + } + Matcher matcher = SQL_PATTERN.matcher(question); + if (matcher.find()) { + return matcher.group(1).trim(); + } + return null; + } + + private static String escapeJson(String s) { + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java new file mode 100644 index 00000000000..d8eb96cfbc1 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; + +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.Collections; +import org.junit.Test; + +public class ConversationMemoryTest { + + @Test + public void testReplaceHistoryClearsLastTotalTokensButKeepsCumulative() { + ConversationMemory memory = new ConversationMemory(); + memory.addCumulativeTokens(100, 50, 150); + memory.addCumulativeTokens(200, 80, 280); + assertEquals(280L, memory.getLastTotalTokens()); + assertEquals(300L, memory.getCumulativePromptTokens()); + assertEquals(430L, memory.getCumulativeTotalTokens()); + + ChatCompletionMessageParam summary = + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content("summary").build()); + memory.replaceHistory(Collections.singletonList(summary)); + + assertEquals("lastTotalTokens reset after compaction", 0L, memory.getLastTotalTokens()); + assertEquals("cumulative totals preserved", 300L, memory.getCumulativePromptTokens()); + assertEquals("cumulative totals preserved", 430L, memory.getCumulativeTotalTokens()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java new file mode 100644 index 00000000000..75553f46998 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java @@ -0,0 +1,568 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import java.io.File; +import java.sql.Connection; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.runtime.event.*; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.sqlite.SQLiteDataSource; + +/** + * Live integration test with a real LLM and real SQLite database. Exercises the full ReAct loop: + * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_LLM_API_KEY and + * DATA_AGENT_LLM_API_URL environment variables. Works with any OpenAI-compatible LLM service. + */ +public class ReactAgentLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = + System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "gpt-4o"); + + private static final String SYSTEM_PROMPT = + SystemPromptBuilder.create().datasource("sqlite").build(); + + private final List tempFiles = new ArrayList<>(); + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @After + public void tearDown() { + tempFiles.forEach(File::delete); + } + + @Test + public void testPlainTextStreamingWithoutTools() { + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(new ToolRegistry(30)) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(3) + .systemPrompt("You are a helpful assistant. Answer concisely in 1-2 sentences.") + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run(new AgentInvocation("What is Apache Kyuubi?"), memory, events::add); + + List deltas = + events.stream() + .filter(e -> e instanceof ContentDelta) + .map(e -> ((ContentDelta) e).text()) + .collect(Collectors.toList()); + + assertTrue("Expected multiple ContentDelta events", deltas.size() > 1); + assertFalse("Streamed text should not be empty", String.join("", deltas).isEmpty()); + assertTrue(events.stream().anyMatch(e -> e instanceof StepStart)); + assertTrue(events.stream().anyMatch(e -> e instanceof ContentComplete)); + assertTrue(events.get(events.size() - 1) instanceof AgentFinish); + assertEquals(2, memory.getHistory().size()); // user + assistant + } + + @Test + public void testFullReActLoopWithSchemaInspectAndSqlQuery() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run( + new AgentInvocation( + "What is the total revenue by product category? Which category has the highest revenue?"), + memory, + events::add); + + // Verify tool calls happened + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + assertFalse("Agent should have called at least one tool", toolCalls.isEmpty()); + assertFalse("Agent should have received tool results", toolResults.isEmpty()); + assertTrue("Tool calls should not error", toolResults.stream().noneMatch(ToolResult::isError)); + + // Verify SQL query was called + assertTrue( + "Agent should execute at least one SQL query", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // Verify final answer mentions "Electronics" (highest revenue) + List completions = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .collect(Collectors.toList()); + String lastAnswer = completions.get(completions.size() - 1); + assertTrue( + "Final answer should mention Electronics, got: " + lastAnswer, + lastAnswer.toLowerCase().contains("electronics")); + + // Verify agent finished successfully + AgentEvent last = events.get(events.size() - 1); + assertTrue(last instanceof AgentFinish); + assertTrue("Should take multiple steps", ((AgentFinish) last).totalSteps() > 1); + } + + @Test + public void testMultiTurnConversationWithToolUse() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + // Shared memory across turns + ConversationMemory memory = new ConversationMemory(); + + // Turn 1 + List events1 = new CopyOnWriteArrayList<>(); + agent.run(new AgentInvocation("How many orders are there in total?"), memory, events1::add); + + assertTrue( + "Turn 1 should query the database", + events1.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events1.get(events1.size() - 1) instanceof AgentFinish); + + // Turn 2: follow-up relying on conversation context + List events2 = new CopyOnWriteArrayList<>(); + agent.run( + new AgentInvocation("Now show me only orders above 500 dollars."), memory, events2::add); + + assertTrue( + "Turn 2 should also query the database", + events2.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events2.get(events2.size() - 1) instanceof AgentFinish); + + // Verify memory accumulated across both turns + assertTrue( + "Memory should contain messages from both turns, got " + memory.getHistory().size(), + memory.getHistory().size() > 4); + } + + @Test + public void testToolOutputOffloadThenGrep() throws Exception { + // Large result forces ToolResultOffloadMiddleware to truncate the tool output and + // emit a preview hint telling the LLM to use grep_tool_output / read_tool_output. + // A correct answer proves the LLM read the hint and drove the retrieval itself. + SQLiteDataSource ds = createNeedleInHaystackDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new ToolResultOffloadMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Explicit workflow: full-table scan first, then retrieval tool. Without this hint a + // capable LLM just issues SELECT note FROM events WHERE tag='NEEDLE' and skips the + // offload path entirely -- smart behavior, but defeats the purpose of this test. + agent.run( + new AgentInvocation( + "Step 1: issue exactly this query: SELECT id, tag, note FROM events (no" + + " WHERE clause, return every row). Step 2: the result will be" + + " truncated; call the grep_tool_output tool with pattern 'NEEDLE' on" + + " the saved output file to find the matching row. Step 3: respond with" + + " ONLY the note text from that row, nothing else. Do NOT add a WHERE" + + " clause. Do NOT issue any other SQL.") + // Offload middleware requires a non-null session id -- without it the offload + // path skips entirely (see ToolResultOffloadMiddleware.afterToolCall). + .sessionId("offload-live-" + java.util.UUID.randomUUID()), + memory, + events::add); + + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + // Dump the tool trace -- diagnoses whether the LLM followed the workflow, added a + // WHERE clause despite instructions, or deviated in some other way. + for (ToolCall tc : toolCalls) { + System.out.println("[ToolCall] " + tc.toolName() + " args=" + tc.toolArgs()); + } + for (ToolResult tr : toolResults) { + String out = tr.output(); + String preview = + out.length() > 300 + ? out.substring(0, 300) + "...(+" + (out.length() - 300) + " chars)" + : out; + System.out.println("[ToolResult] " + tr.toolName() + " -> " + preview); + } + + assertTrue( + "Agent should have run a select query first", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // The SELECT returned 800 rows, which must trip the offload threshold. + assertTrue( + "Expected at least one offload preview marker in tool results", + toolResults.stream().anyMatch(tr -> tr.output().contains("Tool output truncated"))); + + assertTrue( + "Agent should have used grep_tool_output or read_tool_output after seeing" + + " the offload preview; actual tool calls: " + + toolCalls.stream().map(ToolCall::toolName).collect(Collectors.toList()), + toolCalls.stream() + .anyMatch( + tc -> + "grep_tool_output".equals(tc.toolName()) + || "read_tool_output".equals(tc.toolName()))); + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue( + "Final answer should contain the needle note 'the-answer-is-42', got: " + finalAnswer, + finalAnswer.contains("the-answer-is-42")); + } + + @Test + public void testApprovalApproveFlow() throws Exception { + // Real LLM picks run_mutation_query (DESTRUCTIVE) -> ApprovalMiddleware pauses -> + // background thread resolves(approved=true) -> mutation actually runs in SQLite. + SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Auto-approve any approval request that shows up. + ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, true)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Increment the 'hits' counter in the counters table by 1, then tell me its" + + " new value. Respond with ONLY the new value, no explanation."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + List approvals = + events.stream() + .filter(e -> e instanceof ApprovalRequest) + .map(e -> (ApprovalRequest) e) + .collect(Collectors.toList()); + assertFalse("Expected at least one ApprovalRequest", approvals.isEmpty()); + assertTrue( + "ApprovalRequest should target run_mutation_query", + approvals.stream().anyMatch(a -> "run_mutation_query".equals(a.toolName()))); + + // Mutation must have actually executed — check the DB directly. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT value FROM counters WHERE name='hits'")) { + assertTrue(rs.next()); + assertEquals("Counter should be 1 after approved mutation", 1, rs.getInt(1)); + } + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue("Final answer should mention 1, got: " + finalAnswer, finalAnswer.contains("1")); + } + + @Test + public void testApprovalDenyFlow() throws Exception { + // Same setup as the approve test, but the approval listener denies. The mutation + // must NOT run, and the LLM must surface the denial to the user naturally. + SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, false)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Delete all rows from the counters table. If you cannot, explain why."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + assertTrue( + "Expected at least one ApprovalRequest", + events.stream().anyMatch(e -> e instanceof ApprovalRequest)); + + // DB must be untouched. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM counters")) { + assertTrue(rs.next()); + assertTrue("Counters rows must survive denied mutation", rs.getInt(1) > 0); + } + + // LLM should tell the user the operation didn't go through. Loose lexical check to + // absorb model wording drift. + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse("") + .toLowerCase(); + assertTrue( + "Final answer should indicate the deletion was refused/denied/not executed, got: " + + finalAnswer, + finalAnswer.contains("den") + || finalAnswer.contains("reject") + || finalAnswer.contains("not ") + || finalAnswer.contains("refus") + || finalAnswer.contains("unable") + || finalAnswer.contains("could not")); + } + + // --- Helpers --- + + private SQLiteDataSource createSalesDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE products (" + + "id INTEGER PRIMARY KEY, name TEXT NOT NULL, " + + "category TEXT NOT NULL, price REAL NOT NULL)"); + stmt.execute( + "INSERT INTO products VALUES " + + "(1, 'Laptop', 'Electronics', 999.99), " + + "(2, 'Headphones', 'Electronics', 199.99), " + + "(3, 'T-Shirt', 'Clothing', 29.99), " + + "(4, 'Jeans', 'Clothing', 59.99), " + + "(5, 'Novel', 'Books', 14.99), " + + "(6, 'Textbook', 'Books', 89.99)"); + stmt.execute( + "CREATE TABLE orders (" + + "id INTEGER PRIMARY KEY, product_id INTEGER NOT NULL, " + + "customer_name TEXT NOT NULL, quantity INTEGER NOT NULL, " + + "order_date TEXT NOT NULL, " + + "FOREIGN KEY (product_id) REFERENCES products(id))"); + stmt.execute( + "INSERT INTO orders VALUES " + + "(1, 1, 'Alice', 1, '2024-01-15'), " + + "(2, 2, 'Bob', 2, '2024-01-20'), " + + "(3, 3, 'Charlie', 3, '2024-02-01'), " + + "(4, 4, 'Alice', 1, '2024-02-10'), " + + "(5, 5, 'Bob', 5, '2024-02-15'), " + + "(6, 1, 'Diana', 1, '2024-03-01'), " + + "(7, 6, 'Charlie', 2, '2024-03-05'), " + + "(8, 2, 'Diana', 1, '2024-03-10')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createNeedleInHaystackDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE events (" + + "id INTEGER PRIMARY KEY, tag TEXT NOT NULL, note TEXT NOT NULL)"); + conn.setAutoCommit(false); + try (java.sql.PreparedStatement ps = + conn.prepareStatement("INSERT INTO events VALUES (?, ?, ?)")) { + // 800 filler rows, guaranteed to blow past ToolResultOffloadMiddleware's 500-line + // threshold when the LLM issues a SELECT *. Exactly one row carries the NEEDLE tag. + int needleId = 573; + for (int i = 1; i <= 800; i++) { + ps.setInt(1, i); + if (i == needleId) { + ps.setString(2, "NEEDLE"); + ps.setString(3, "the-answer-is-42"); + } else { + ps.setString(2, "FILLER"); + ps.setString(3, "filler-note-" + i); + } + ps.addBatch(); + } + ps.executeBatch(); + } + conn.commit(); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createCountersDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE counters (name TEXT PRIMARY KEY, value INTEGER NOT NULL)"); + stmt.execute("INSERT INTO counters VALUES ('hits', 0)"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createDataSource() { + try { + File tmpFile = File.createTempFile("kyuubi-agent-live-", ".db"); + tmpFile.deleteOnExit(); + tempFiles.add(tmpFile); + SQLiteDataSource ds = new SQLiteDataSource(); + ds.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + return ds; + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java new file mode 100644 index 00000000000..e3eda5ff5e4 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Before; +import org.junit.Test; + +public class ToolOutputStoreTest { + + private ToolOutputStore store; + + @Before + public void setUp() throws IOException { + store = ToolOutputStore.create(); + } + + @Test + public void writeAndReadWindow() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 100; i++) sb.append("row").append(i).append('\n'); + Path p = store.write("sess1", "call1", sb.toString()); + assertTrue(Files.exists(p)); + + String out = store.read("sess1", p.toString(), 10, 5); + assertTrue(out, out.contains("lines 11-15 of")); + assertTrue(out, out.contains("row11")); + assertTrue(out, out.contains("row15")); + assertFalse(out, out.contains("row16")); + assertFalse(out, out.contains("row10")); + } + + @Test + public void grepReturnsMatchingLinesWithLineNumbers() throws IOException { + String content = "apple\nbanana\ncherry\napple pie\ndate\n"; + Path p = store.write("sess1", "call1", content); + + String out = store.grep("sess1", p.toString(), "apple", 10); + assertTrue(out, out.contains("1:apple")); + assertTrue(out, out.contains("4:apple pie")); + assertFalse(out, out.contains("banana")); + } + + @Test + public void grepRespectsMaxMatches() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 20; i++) sb.append("hit\n"); + Path p = store.write("sess1", "call1", sb.toString()); + + String out = store.grep("sess1", p.toString(), "hit", 3); + assertTrue(out, out.contains("[3 matches]")); + assertTrue(out, out.contains("1:hit")); + assertTrue(out, out.contains("3:hit")); + assertFalse("should stop after 3 matches", out.contains("4:hit")); + } + + @Test + public void grepInvalidRegexReturnsError() throws IOException { + Path p = store.write("sess1", "call1", "x\n"); + String out = store.grep("sess1", p.toString(), "[", 10); + assertTrue(out, out.startsWith("Error:")); + } + + @Test + public void readRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "top secret\n"); + assertTrue(Files.exists(victim)); + + String out = store.read("attacker", victim.toString(), 0, 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("top secret")); + } + + @Test + public void grepRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "api_key=xyz\n"); + String out = store.grep("attacker", victim.toString(), "api_key", 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("xyz")); + } + + @Test + public void cleanupSessionRemovesSubtree() throws IOException { + Path p1 = store.write("sessA", "call1", "a\n"); + Path p2 = store.write("sessA", "call2", "b\n"); + Path p3 = store.write("sessB", "call1", "c\n"); + assertTrue(Files.exists(p1)); + assertTrue(Files.exists(p2)); + assertTrue(Files.exists(p3)); + + store.cleanupSession("sessA"); + + assertFalse(Files.exists(p1)); + assertFalse(Files.exists(p2)); + assertTrue("other sessions untouched", Files.exists(p3)); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java index 50d22da416d..b6a5f093b61 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java @@ -113,6 +113,7 @@ public void testEventTypeSseNames() { assertEquals("step_end", EventType.STEP_END.sseEventName()); assertEquals("error", EventType.ERROR.sseEventName()); assertEquals("approval_request", EventType.APPROVAL_REQUEST.sseEventName()); + assertEquals("compaction", EventType.COMPACTION.sseEventName()); assertEquals("agent_finish", EventType.AGENT_FINISH.sseEventName()); } @@ -123,6 +124,6 @@ public void testAllEventTypesHaveUniqueSseNames() { for (EventType type : values) { assertTrue("Duplicate SSE name: " + type.sseEventName(), names.add(type.sseEventName())); } - assertEquals(10, values.length); + assertEquals(11, values.length); } } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java new file mode 100644 index 00000000000..a84bbc25948 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.EventType; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.junit.Before; +import org.junit.Test; + +public class ApprovalMiddlewareTest { + + private ToolRegistry registry; + private List emittedEvents; + + @Before + public void setUp() { + registry = new ToolRegistry(30); + registry.register(safeTool("safe_tool")); + registry.register(destructiveTool("dangerous_tool")); + emittedEvents = Collections.synchronizedList(new ArrayList<>()); + } + + // --- Auto-approve mode: all tools pass --- + + @Test + public void testAutoApproveModeSkipsAllApproval() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE); + + assertNull(mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + assertNull(mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); + assertTrue("No approval events should be emitted", emittedEvents.isEmpty()); + } + + // --- Normal mode: safe auto-approved, destructive needs approval --- + + @Test + public void testNormalModeAutoApprovesSafeTool() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + assertNull(mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + assertTrue(emittedEvents.isEmpty()); + } + + @Test + public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + // Capture the emitted event to get the requestId + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + // Wait for the approval request event + assertTrue("Approval event should be emitted", eventEmitted.await(2, TimeUnit.SECONDS)); + assertEquals(1, emittedEvents.size()); + assertEquals(EventType.APPROVAL_REQUEST, emittedEvents.get(0).eventType()); + + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("dangerous_tool", req.toolName()); + assertEquals(ToolRiskLevel.DESTRUCTIVE, req.riskLevel()); + + // Approve + assertTrue(mw.resolve(req.requestId(), true)); + assertNull("Approved tool should return null (no denial)", future.get(2, TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + @Test + public void testDeniedToolReturnsToolCallDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + + // Deny + assertTrue(mw.resolve(req.requestId(), false)); + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull(denial); + assertTrue(denial.reason().contains("denied")); + } finally { + exec.shutdownNow(); + } + } + + // --- Strict mode: all tools need approval --- + + @Test + public void testStrictModeRequiresApprovalForSafeTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("safe_tool", req.toolName()); + + assertTrue(mw.resolve(req.requestId(), true)); + assertNull(future.get(2, TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + // --- Timeout --- + + @Test + public void testApprovalTimeoutReturnsDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(1); // 1 second timeout + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + // Don't resolve — let it time out + AgentMiddleware.ToolCallDenial denial = future.get(5, TimeUnit.SECONDS); + assertNotNull("Timeout should produce a denial", denial); + assertTrue(denial.reason().contains("timed out")); + } finally { + exec.shutdownNow(); + } + } + + // --- Cancel all --- + + @Test + public void testOnStopUnblocksPendingRequests() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(30); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch started = new CountDownLatch(1); + Future future = + exec.submit( + () -> { + started.countDown(); + return mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()); + }); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); // let the thread enter the blocking wait + + mw.onStop(); + + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull("onStop should unblock with a denial", denial); + } finally { + exec.shutdownNow(); + } + } + + // --- Helpers --- + + private ApprovalMiddleware newApprovalMiddleware() { + ApprovalMiddleware mw = new ApprovalMiddleware(); + mw.onRegister(registry); + return mw; + } + + private ApprovalMiddleware newApprovalMiddleware(long timeoutSeconds) { + ApprovalMiddleware mw = new ApprovalMiddleware(timeoutSeconds); + mw.onRegister(registry); + return mw; + } + + private AgentRunContext makeContext(ApprovalMode mode) { + AgentRunContext ctx = new AgentRunContext(new ConversationMemory(), mode); + ctx.setEventEmitter(emittedEvents::add); + return ctx; + } + + private static AgentTool safeTool(String name) { + return new DummyTool(name, ToolRiskLevel.SAFE); + } + + private static AgentTool destructiveTool(String name) { + return new DummyTool(name, ToolRiskLevel.DESTRUCTIVE); + } + + public static class DummyArgs { + public String value; + } + + private static class DummyTool implements AgentTool { + private final String name; + private final ToolRiskLevel riskLevel; + + DummyTool(String name, ToolRiskLevel riskLevel) { + this.name = name; + this.riskLevel = riskLevel; + } + + @Override + public String name() { + return name; + } + + @Override + public String description() { + return "dummy tool"; + } + + @Override + public ToolRiskLevel riskLevel() { + return riskLevel; + } + + @Override + public Class argsType() { + return DummyArgs.class; + } + + @Override + public String execute(DummyArgs args, ToolContext ctx) { + return "ok"; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java new file mode 100644 index 00000000000..1c4eeb7c5ad --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Live integration test for {@link CompactionMiddleware}: exercises the full compaction path + * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_LLM_API_KEY} and {@code + * DATA_AGENT_LLM_API_URL} environment variables; skipped otherwise. + */ +public class CompactionMiddlewareLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", ""); + + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @Test + public void compactsHistoryWhenThresholdCrossed() { + // Seed a realistic ReAct-style history so the summarizer has something non-trivial to + // summarize. ~20 alternating user/assistant turns. + ConversationMemory memory = new ConversationMemory(); + memory.setSystemPrompt( + "You are a data agent. You previously helped the user investigate the orders table."); + for (int i = 0; i < 10; i++) { + memory.addUserMessage( + "Follow-up question " + i + ": what about the column customer_id in orders?"); + memory.addAssistantMessage( + ChatCompletionAssistantMessageParam.builder() + .content( + "Assistant turn " + + i + + ": the orders table has a customer_id BIGINT column referencing customers.id.") + .build()); + } + int originalSize = memory.size(); + + AgentRunContext ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + // Simulate the previous LLM call having reported a large prompt_tokens so the next + // beforeLlmCall trips the threshold. + ctx.addTokenUsage(60_000, 0, 60_000); + + CompactionMiddleware mw = new CompactionMiddleware(client, MODEL_NAME, /* trigger */ 50_000L); + + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, memory.buildLlmMessages()); + + assertNotNull("expected compaction to fire", action); + assertTrue(action instanceof AgentMiddleware.LlmModifyMessages); + + // History got rewritten: [summary user msg] + kept tail. + List hist = memory.getHistory(); + assertTrue(hist.size() < originalSize); + assertTrue(hist.get(0).isUser()); + String first = hist.get(0).asUser().content().text().orElse(""); + assertTrue( + "summary message should be wrapped in ", + first.contains("")); + + // The LLM was told to emit 8 markdown sections; sanity-check a couple show up. + assertTrue( + "summary should contain '## User Intent' section", + first.contains("## User Intent") || first.contains("User Intent")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java new file mode 100644 index 00000000000..d4c0cc581db --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests that exercise only the deterministic, LLM-free parts of {@link CompactionMiddleware}: + * + *
      + *
    • Static helpers {@code computeSplit} and {@code estimateTailAfterLastAssistant}. + *
    • Pre-summarizer gating paths in {@code beforeLlmCall} (below threshold, empty old after + * split) that never reach the LLM call. + *
    • Constructor validation. + *
    + * + * The end-to-end behaviour (actual summarization output, compacted-history structure, no re-trigger + * next turn) is covered by {@code CompactionMiddlewareLiveTest}, which is gated on real LLM + * credentials. + */ +public class CompactionMiddlewareTest { + + /** + * A minimally-configured real {@link OpenAIClient}. Never invoked in these tests because every + * code path exercised here returns before reaching the summarizer. Uses dummy credentials and a + * bogus base URL so an accidental invocation would fail loudly. + */ + private static final OpenAIClient DUMMY_CLIENT = + OpenAIOkHttpClient.builder() + .apiKey("dummy-for-unit-tests") + .baseUrl("http://127.0.0.1:0") + .build(); + + private ConversationMemory memory; + private AgentRunContext ctx; + + @Before + public void setUp() { + memory = new ConversationMemory(); + memory.setSystemPrompt("SYS"); + ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + } + + // ----- computeSplit ----- + + @Test + public void computeSplit_shortHistoryReturnsAllInKept() { + List history = + Arrays.asList(userMsg("u0"), asstMsg(asstPlain("a0"))); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(2, s.kept.size()); + } + + @Test + public void computeSplit_simpleAlternationSplitsCorrectly() { + // 10 msgs: u0,a0,u1,a1,u2,a2,u3,a3,u4,a4. keep=2 → boundary at u3 (2 users from tail). + List history = alternatingHistory(10); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 2); + // 4 users from tail: u4,u3 → splitIdx = index of u3 = 6 → old=[0..5], kept=[6..9] + assertEquals(6, s.old.size()); + assertEquals(4, s.kept.size()); + assertTrue(s.kept.get(0).isUser()); + } + + @Test + public void computeSplit_neverOrphansToolResult() { + // Layout: u0 a0 u1 a1 u2 a2 u3 a3(tc1) tool(tc1) u4 a4 u5 a5 u6 a6 u7 a7 + // keep=4 → naive splitIdx lands between tool(tc1) and u4; pair-protection must shift it back + // to before a3(tc1), so a3(tc1) + tool(tc1) end up in kept. + List history = new ArrayList<>(); + history.add(userMsg("u0")); + history.add(asstMsg(asstPlain("a0"))); + history.add(userMsg("u1")); + history.add(asstMsg(asstPlain("a1"))); + history.add(userMsg("u2")); + history.add(asstMsg(asstPlain("a2"))); + history.add(userMsg("u3")); + history.add(asstMsg(asstWithToolCall("a3", "tc1", "sql_query", "{}"))); + history.add(toolMsg("tc1", "r1")); + history.add(userMsg("u4")); + history.add(asstMsg(asstPlain("a4"))); + history.add(userMsg("u5")); + history.add(asstMsg(asstPlain("a5"))); + history.add(userMsg("u6")); + history.add(asstMsg(asstPlain("a6"))); + history.add(userMsg("u7")); + history.add(asstMsg(asstPlain("a7"))); + + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertNoOrphanToolResult(s.kept); + // Verify the tc1 pair really did end up in kept, not old. + assertTrue("a3(tc1) must be in kept", containsToolCallId(s.kept, "tc1")); + assertTrue("tool_result(tc1) must be in kept", containsToolCallIdAsResult(s.kept, "tc1")); + } + + @Test + public void computeSplit_keepCountExceedsAvailableUsers() { + // Only 2 user msgs but we ask to keep 4 — boundary walks to the top, old=[], kept=all. + List history = alternatingHistory(4); // u0,a0,u1,a1 + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(4, s.kept.size()); + } + + // ----- estimateTailAfterLastAssistant ----- + + @Test + public void estimateTail_afterLastAssistant() { + // u(200 chars) a(50) u(100) → last assistant at index 1, tail is the final user = 100 chars + // → 100/4 = 25 tokens + List msgs = + Arrays.asList( + userMsg(repeat('x', 200)), + asstMsg(asstPlain(repeat('y', 50))), + userMsg(repeat('z', 100))); + assertEquals(25L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_noAssistantMeansEverythingIsTail() { + List msgs = + Arrays.asList(userMsg(repeat('x', 400)), userMsg(repeat('y', 400))); + assertEquals(200L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_emptyReturnsZero() { + assertEquals(0L, CompactionMiddleware.estimateTailAfterLastAssistant(Collections.emptyList())); + } + + // ----- beforeLlmCall pre-summarizer gating ----- + + @Test + public void belowThresholdReturnsNull() { + seedSimpleHistory(6); + ctx.addTokenUsage(1000, 0, 1000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + // Nothing was mutated. + assertEquals(6, memory.size()); + } + + @Test + public void aboveThresholdButHistoryTooShortReturnsNull() { + // Threshold crossed (60k cumulative) but history has only 2 user turns → computeSplit + // can't satisfy KEEP_RECENT_TURNS=4 and keeps everything, leaving split.old empty; so + // beforeLlmCall bails out before ever invoking the summarizer. + memory.addUserMessage("u0"); + memory.addAssistantMessage(asstPlain("a0")); + memory.addUserMessage("u1"); + ctx.addTokenUsage(60_000, 0, 60_000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(3, memory.size()); + } + + @Test + public void triggerUsesLastCallTotalNotCumulative() { + // Two consecutive calls with total_tokens below threshold. The middleware must key on the + // *last* call's total (prompt + completion), not the session cumulative — otherwise a session + // that has accumulated large cumulative cost but then compacted would misfire. Using total + // (not just prompt) also covers the last assistant message — e.g. a tool_call's completion + // tokens — which is part of the next prompt but sits beyond the tail estimator's window. + seedSimpleHistory(6); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + ctx.addTokenUsage(4_000, 1_000, 5_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + ctx.addTokenUsage(8_000, 2_000, 10_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + assertEquals(10_000L, memory.getLastTotalTokens()); + assertEquals(15_000L, memory.getCumulativeTotalTokens()); + } + + // ----- helpers ----- + + private void seedSimpleHistory(int n) { + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + memory.addUserMessage("u" + i); + } else { + memory.addAssistantMessage(asstPlain("a" + i)); + } + } + } + + private static List alternatingHistory(int n) { + List out = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + out.add(userMsg("u" + i)); + } else { + out.add(asstMsg(asstPlain("a" + i))); + } + } + return out; + } + + private static ChatCompletionMessageParam userMsg(String text) { + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(text).build()); + } + + private static ChatCompletionMessageParam asstMsg(ChatCompletionAssistantMessageParam p) { + return ChatCompletionMessageParam.ofAssistant(p); + } + + private static ChatCompletionMessageParam toolMsg(String toolCallId, String content) { + return ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder().toolCallId(toolCallId).content(content).build()); + } + + private static ChatCompletionAssistantMessageParam asstPlain(String text) { + return ChatCompletionAssistantMessageParam.builder().content(text).build(); + } + + private static ChatCompletionAssistantMessageParam asstWithToolCall( + String text, String toolCallId, String toolName, String args) { + List calls = new ArrayList<>(); + calls.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(toolCallId) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolName) + .arguments(args) + .build()) + .build())); + return ChatCompletionAssistantMessageParam.builder().content(text).toolCalls(calls).build(); + } + + private static String repeat(char c, int n) { + char[] arr = new char[n]; + Arrays.fill(arr, c); + return new String(arr); + } + + private static boolean containsToolCallId(List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + List calls = m.asAssistant().toolCalls().orElse(null); + if (calls == null) continue; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && id.equals(tc.asFunction().id())) return true; + } + } + } + return false; + } + + private static boolean containsToolCallIdAsResult( + List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool() && id.equals(m.asTool().toolCallId())) return true; + } + return false; + } + + private static void assertNoOrphanToolResult(List msgs) { + Set issued = new HashSet<>(); + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + m.asAssistant() + .toolCalls() + .ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) issued.add(tc.asFunction().id()); + } + }); + } + } + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool()) { + assertTrue( + "tool_result id=" + m.asTool().toolCallId() + " has no matching tool_call", + issued.contains(m.asTool().toolCallId())); + } + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java new file mode 100644 index 00000000000..cb107b775f5 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.Collections; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class ToolResultOffloadMiddlewareTest { + + private ToolResultOffloadMiddleware mw; + private AgentRunContext ctxWithSession; + private AgentRunContext ctxNoSession; + + @Before + public void setUp() { + mw = new ToolResultOffloadMiddleware(); + ctxWithSession = + new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, "sess-1"); + ctxNoSession = new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, null); + } + + @After + public void tearDown() { + mw.onStop(); + } + + @Test + public void underThresholdPassesThrough() { + String small = "row1\nrow2\nrow3\n"; + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small); + assertNull(out); + } + + @Test + public void overLineThresholdTriggersOffload() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull(out); + assertTrue(out, out.contains("Tool output truncated")); + assertTrue(out, out.contains("Saved to:")); + assertTrue(out, out.contains(ReadToolOutputTool.NAME)); + assertTrue(out, out.contains(GrepToolOutputTool.NAME)); + assertTrue(out, out.contains("row0")); + assertTrue(out, out.contains("row599")); + } + + @Test + public void overByteThresholdTriggersOffload() { + // 60 lines of ~1 KB each = ~60 KB — over the byte threshold but well under the line threshold. + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 60; i++) { + for (int j = 0; j < 1024; j++) sb.append('a'); + sb.append('\n'); + } + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull("byte threshold should trigger", out); + assertTrue(out, out.contains("Tool output truncated")); + } + + @Test + public void retrievalToolsAreExemptFromGate() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n'); + assertNull( + mw.afterToolCall( + ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + assertNull( + mw.afterToolCall( + ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + } + + @Test + public void missingSessionIdPassesThrough() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertNull("without sessionId, cannot offload safely — pass through", out); + } + + @Test + public void onSessionCloseClearsCounterAndFiles() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertEquals(1, mw.trackedSessions()); + + mw.onSessionClose("sess-1"); + assertEquals(0, mw.trackedSessions()); + } + + @Test + public void multipleOffloadsReuseSameSessionDir() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out1 = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + String out2 = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + // Both previews reference the same session dir, different file names. + assertNotEquals(extractPath(out1), extractPath(out2)); + assertTrue(extractPath(out1).contains("sess-1")); + assertTrue(extractPath(out2).contains("sess-1")); + } + + private static String extractPath(String preview) { + int i = preview.indexOf("Saved to:"); + int eol = preview.indexOf('\n', i); + return preview.substring(i + "Saved to:".length(), eol).trim(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java index c3ceb1a4dc0..7f790e1238e 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java @@ -66,7 +66,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "result_" + idx; } }); @@ -118,7 +118,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "existing_result"; } }); @@ -159,7 +159,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "dynamic"; } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java index 8e0a5cd01b0..777017c4438 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java @@ -109,7 +109,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { try { Thread.sleep(60_000); } catch (InterruptedException e) { @@ -147,7 +147,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { throw new RuntimeException("intentional failure"); } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java index c46fc2501bc..5384bd937d9 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; import org.junit.After; import org.junit.Before; @@ -60,7 +61,7 @@ public void testRiskLevelDestructive() { public void testInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO t VALUES (9999, 'hello')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("1 row(s) affected")); } @@ -68,21 +69,21 @@ public void testInsert() { public void testUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE t SET v = 'updated' WHERE id = 1"; - assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM t WHERE id = 1"; - assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)"; - assertTrue(tool.execute(args).contains("executed successfully")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("executed successfully")); } @Test @@ -90,7 +91,7 @@ public void testAlsoAcceptsSelect() { // Mutation tool does not enforce read-only check; SELECT works fine here. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT v FROM t WHERE id = 1"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); } @@ -98,18 +99,18 @@ public void testAlsoAcceptsSelect() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO nonexistent_table VALUES (1)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Helpers --- diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java index 3c6579cf9a7..d1015ede63d 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -56,7 +57,7 @@ public void tearDown() { public void testRejectsInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO large_table VALUES (9999, 'x')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("read-only")); assertTrue(result.contains("run_mutation_query")); @@ -66,28 +67,28 @@ public void testRejectsInsert() { public void testRejectsUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM large_table WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE x (id INT)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testAllowsSelect() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 100"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[100 row(s) returned]")); } @@ -96,7 +97,7 @@ public void testAllowsSelect() { public void testAllowsCte() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("row(s)")); } @@ -107,7 +108,7 @@ public void testAllowsCte() { public void testRespectsLimitInSql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 5"; - assertTrue(tool.execute(args).contains("[5 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[5 row(s) returned]")); } @Test @@ -116,7 +117,7 @@ public void testNoClientSideCapWhenLimitOmitted() { // Cap discipline is delegated to the LLM via the system prompt. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table"; - assertTrue(tool.execute(args).contains("[1500 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[1500 row(s) returned]")); } // --- Zero-row result --- @@ -125,7 +126,7 @@ public void testNoClientSideCapWhenLimitOmitted() { public void testZeroRowsResult() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table WHERE id < 0"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[0 row(s) returned]")); } @@ -136,15 +137,15 @@ public void testZeroRowsResult() { public void testSelectWithLeadingBlockComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "/* get count */ SELECT COUNT(*) FROM large_table"; - assertFalse(tool.execute(args).startsWith("Error:")); + assertFalse(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsMutationHiddenBehindComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "-- looks innocent\nDROP TABLE large_table"; - assertTrue(tool.execute(args).startsWith("Error:")); - assertTrue(tool.execute(args).contains("read-only")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("read-only")); } // --- Edge cases --- @@ -153,25 +154,25 @@ public void testRejectsMutationHiddenBehindComment() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsWhitespaceOnlySql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = " \t\n "; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM nonexistent_table"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Output formatting --- @@ -188,7 +189,7 @@ public void testNullValuesRenderedAsNULL() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("NULL")); assertTrue(result.contains("Alice")); } @@ -204,7 +205,7 @@ public void testPipeCharacterEscapedInOutput() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT val FROM pipe_test"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); } @@ -214,7 +215,7 @@ public void testPipeCharacterEscapedInOutput() { public void testExtractRootCauseFromNestedExceptions() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM this_table_does_not_exist_at_all"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("this_table_does_not_exist_at_all")); } @@ -223,7 +224,7 @@ public void testExtractRootCauseFromNestedExceptions() { public void testErrorMessageIsConcise() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELEC INVALID SYNTAX HERE !!!"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); long newlines = result.chars().filter(c -> c == '\n').count(); assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2); @@ -236,7 +237,7 @@ public void testCustomQueryTimeout() { RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT COUNT(*) FROM large_table"; - assertFalse(customTool.execute(args).startsWith("Error:")); + assertFalse(customTool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test @@ -314,7 +315,7 @@ public boolean isWrapperFor(Class iface) { RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM large_table"; - String result = timeoutTool.execute(args); + String result = timeoutTool.execute(args, ToolContext.EMPTY); assertTrue("Expected error on timeout", result.startsWith("Error:")); assertTrue("Expected timeout message", result.contains("timed out")); } diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala new file mode 100644 index 00000000000..5fd95bb0fdc --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.operation + +import java.sql.DriverManager + +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + +import org.apache.kyuubi.config.KyuubiConf._ +import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine +import org.apache.kyuubi.operation.HiveJDBCTestHelper + +/** + * End-to-end test that forces CompactionMiddleware to fire inside the engine and verifies + * two things simultaneously: + * 1. The `compaction` SSE event reaches the JDBC client (wiring works). + * 2. The agent still answers correctly *after* compaction, proving the summary preserved + * the facts the follow-up question depends on. + * + * The trigger threshold is set extremely low (500 tokens) so that the schema dump plus the + * first turn's completion already blow past it, forcing compaction before turn 2 or 3. + * + * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL. + */ +class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { + + private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") + private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") + private val dbPath = { + val tmp = System.getProperty("java.io.tmpdir") + val uid = java.util.UUID.randomUUID() + s"$tmp/dataagent_compaction_e2e_$uid.db" + } + + override def withKyuubiConf: Map[String, String] = Map( + ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE", + ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey, + ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl, + ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName, + ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10", + ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE", + // Force compaction to fire aggressively -- any realistic prompt + one LLM round-trip + // will exceed 500 tokens, so compaction must trigger by turn 2 or 3. + ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS.key -> "500", + ENGINE_DATA_AGENT_JDBC_URL.key -> s"jdbc:sqlite:$dbPath") + + override protected def jdbcUrl: String = jdbcConnectionUrl + + private val enabled: Boolean = apiKey.nonEmpty && apiUrl.nonEmpty + + override def beforeAll(): Unit = { + if (enabled) { + setupTestDatabase() + super.beforeAll() + } + } + + override def afterAll(): Unit = { + if (enabled) { + super.afterAll() + new java.io.File(dbPath).delete() + } + } + + private def setupTestDatabase(): Unit = { + new java.io.File(dbPath).delete() + val conn = DriverManager.getConnection(s"jdbc:sqlite:$dbPath") + try { + val stmt = conn.createStatement() + stmt.execute( + """ + |CREATE TABLE employees ( + | id INTEGER PRIMARY KEY, + | name TEXT NOT NULL, + | department TEXT NOT NULL, + | salary REAL NOT NULL + |)""".stripMargin) + // 6 employees across 3 departments; Frank is unambiguously the top earner. + stmt.execute("INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 25000)") + stmt.execute("INSERT INTO employees VALUES (2, 'Bob', 'Engineering', 30000)") + stmt.execute("INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 20000)") + stmt.execute("INSERT INTO employees VALUES (4, 'Diana', 'Sales', 22000)") + stmt.execute("INSERT INTO employees VALUES (5, 'Eve', 'Marketing', 18000)") + stmt.execute("INSERT INTO employees VALUES (6, 'Frank', 'Engineering', 35000)") + } finally { + conn.close() + } + } + + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + sb.toString() + } + + private def parseEvents(stream: String): Seq[JsonNode] = { + val parser = mapper.getFactory.createParser(stream) + try mapper.readValues(parser, classOf[JsonNode]).asScala.toList + finally parser.close() + } + + private def extractAnswer(events: Seq[JsonNode]): String = { + val sb = new StringBuilder + events.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + + test("E2E: compaction fires mid-conversation and preserves facts across turns") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping") + + // CompactionMiddleware.KEEP_RECENT_TURNS is hardcoded to 4. computeSplit needs a + // non-empty 'old' slice, so at least 5 distinct user turns must accumulate before + // compaction can fire -- turns 1..(N-4) become the old slice, 4 most recent are + // kept verbatim. Turn 5 is the observable trigger point. + // + // Turn 5 is phrased to force a fresh SQL query rather than relying on recall, + // because summary quality varies across LLMs (some drop the top-earner fact). + // Correctness of the final answer then validates that the post-compaction history + // still gives the agent enough context to pick the right tool and query -- which is + // the compaction contract we actually care about: mechanism fires, agent recovers. + withJdbcStatement() { stmt => + Seq( + "List every department that appears in the employees table.", + "How many employees work in Engineering?", + "What salaries do Sales employees earn?", + "Who works in Marketing?").zipWithIndex.foreach { case (q, i) => + val events = parseEvents(drainReply(stmt.executeQuery(q))) + info(s"Turn ${i + 1} answer: ${extractAnswer(events)}") + } + + // Turn 5 -- explicitly instruct the agent to re-query so the answer does not + // depend on summary fidelity, only on the agent still functioning after + // compaction rewrote history. + val events5 = parseEvents(drainReply(stmt.executeQuery( + "Run a SELECT against the employees table to find the single employee with" + + " the highest salary. Report ONLY that employee's name." + + s" $strictFormatHint"))) + val answer5 = extractAnswer(events5) + info(s"Turn 5 answer: $answer5") + + val compactionEvents = + events5.filter(_.path("type").asText() == "compaction") + assert( + compactionEvents.nonEmpty, + "Expected at least one compaction event in turn 5 (trigger=500 tokens, 5 turns)") + + // Sanity-check event shape -- field names must match ExecuteStatement's SSE encoder. + val c = compactionEvents.head + assert( + c.has("summarized") && c.get("summarized").asInt() > 0, + s"compaction event should carry a positive summarized count: $c") + assert(c.has("kept") && c.get("kept").asInt() >= 0, s"compaction event missing kept: $c") + assert( + c.has("triggerTokens") && c.get("triggerTokens").asLong() == 500L, + s"compaction event should echo configured trigger: $c") + + // Turn 5 was told to SELECT fresh; Frank (35000) is unambiguously the top earner. + // If we don't get "Frank", either the agent failed to re-query after compaction + // (real bug in post-compaction history) or the tool layer is broken. + assert( + answer5.contains("Frank"), + s"Turn 5 should identify Frank as the top earner after re-querying; the agent" + + s" must remain functional post-compaction. Got: $answer5") + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala index 977c655dabd..cfb83ecc9b2 100644 --- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala @@ -19,6 +19,10 @@ package org.apache.kyuubi.engine.dataagent.operation import java.sql.DriverManager +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + import org.apache.kyuubi.config.KyuubiConf._ import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine import org.apache.kyuubi.operation.HiveJDBCTestHelper @@ -34,7 +38,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") - private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "gpt-4o") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") private val dbPath = s"${System.getProperty("java.io.tmpdir")}/dataagent_e2e_test_${java.util.UUID.randomUUID()}.db" @@ -107,33 +111,71 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { new java.io.File(dbPath).delete() } + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + val stream = sb.toString() + info(s"Agent event stream: $stream") + stream + } + + /** + * The JDBC `reply` column is a concatenated stream of SSE events + * (`agent_start`, `tool_call`, `tool_result`, `content_delta`, ...). Only + * `content_delta.text` is actual model output - this pulls those out and + * joins them to recover the final natural-language answer. + */ + private def extractAnswer(eventStream: String): String = { + val parser = mapper.getFactory.createParser(eventStream) + val sb = new StringBuilder + try { + mapper.readValues(parser, classOf[JsonNode]).asScala.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + } finally { + parser.close() + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + test("E2E: agent answers data question through full Kyuubi pipeline") { assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") - // scalastyle:off println withJdbcStatement() { stmt => - // Ask a question that requires schema exploration + SQL execution - val result = stmt.executeQuery( - "Which department has the highest average salary?") - - val sb = new StringBuilder - while (result.next()) { - val chunk = result.getString("reply") - sb.append(chunk) - print(chunk) // real-time output for debugging - } - println() - - val reply = sb.toString() - - // The agent should have: - // 1. Explored the schema (mentioned table names or columns) - // 2. Executed SQL (the reply should contain actual data) - // 3. Answered with "Engineering" (avg salary 30000) - assert(reply.nonEmpty, "Agent should return a non-empty response") - assert( - reply.toLowerCase.contains("engineering") || reply.contains("30000"), - s"Expected the answer to mention 'Engineering' or '30000', got: ${reply.take(500)}") + val stream = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream) == "Engineering") + } + } + + test("E2E: agent resolves follow-up question using prior conversation context") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") + // Two executeQuery calls on the same Statement share the JDBC session, which means + // the provider reuses the same ConversationMemory across turns. Turn 2 uses the + // demonstrative "that department" - it can only be answered correctly if Turn 1's + // answer (Engineering) is carried over in the agent's conversation history. + withJdbcStatement() { stmt => + val stream1 = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream1) == "Engineering") + + // Engineering has 3 employees (Alice, Bob, Frank). If memory is not shared + // the agent cannot resolve "that department" and cannot produce the exact + // integer 3 - nothing in Turn 2's prompt points to Engineering. + val stream2 = drainReply( + stmt.executeQuery( + s"How many employees work in that department? $strictFormatHint")) + assert(extractAnswer(stream2) == "3") } - // scalastyle:on println } } diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 215a76e26d4..31d308e9c67 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3848,6 +3848,21 @@ object KyuubiConf { .checkValue(_ > 0, "must be positive number") .createWithDefault(100) + val ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS: ConfigEntry[Long] = + buildConf("kyuubi.engine.data.agent.compaction.trigger.tokens") + .doc("The prompt-token threshold above which the Data Agent's compaction middleware " + + "summarizes older conversation history into a compact message. The check is made each " + + "turn as " + + "real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; " + + "when this predicted prompt size reaches the configured value, older messages are " + + "replaced by a single summary message while the most recent exchanges are kept verbatim. " + + "Set to a very large value (e.g., 9223372036854775807) to effectively " + + "disable compaction.") + .version("1.12.0") + .longConf + .checkValue(_ > 0, "must be positive number") + .createWithDefault(128000L) + val ENGINE_DATA_AGENT_QUERY_TIMEOUT: ConfigEntry[Long] = buildConf("kyuubi.engine.data.agent.query.timeout") .doc("The JDBC query execution timeout for the Data Agent SQL tools. " + From 52235617db8f52a211d205581377eb4e0aada950 Mon Sep 17 00:00:00 2001 From: wangzhigang Date: Thu, 23 Apr 2026 10:59:03 +0800 Subject: [PATCH 2/2] [KYUUBI #7379][2b/4] Move mysql-connector-j to test scope MySQL Connector/J is GPL-licensed and cannot be bundled in an Apache binary release. Users who need the MySQL/StarRocks datasource at runtime should provide the driver jar themselves on the engine classpath. Addresses review feedback on #7417. --- externals/kyuubi-data-agent-engine/pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index 74da5005784..43a5008dda9 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -76,10 +76,11 @@ ${sqlite.version} - + com.mysql mysql-connector-j + test