diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md
index 3e34882b28f..a2910abfca6 100644
--- a/docs/configuration/settings.md
+++ b/docs/configuration/settings.md
@@ -144,6 +144,7 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co
| kyuubi.engine.chat.provider | ECHO | The provider for the Chat engine. Candidates: <ul><li>ECHO: simply replies a welcome message.</li><li>GPT: a.k.a ChatGPT, powered by OpenAI.</li><li>ERNIE: ErnieBot, powered by Baidu.</li></ul> | string | 1.8.0 |
| kyuubi.engine.connection.url.use.hostname | true | (deprecated) When true, the engine registers with hostname to zookeeper. When Spark runs on K8s with cluster mode, set to false to ensure that server can connect to engine | boolean | 1.3.0 |
| kyuubi.engine.data.agent.approval.mode | NORMAL | Default approval mode for tool execution in the Data Agent engine. Candidates: AUTO_APPROVE: all tools are auto-approved without user interaction. NORMAL: only destructive tools require explicit approval. STRICT: all tools require explicit user approval. | string | 1.12.0 |
+| kyuubi.engine.data.agent.compaction.trigger.tokens | 128000 | The prompt-token threshold above which the Data Agent's compaction middleware summarizes older conversation history into a compact message. The check is made each turn as real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; when this predicted prompt size reaches the configured value, older messages are replaced by a single summary message while the most recent exchanges are kept verbatim. Set to a very large value (e.g., 9223372036854775807) to effectively disable compaction. | long | 1.12.0 |
| kyuubi.engine.data.agent.extra.classpath | <undefined> | The extra classpath for the Data Agent engine, for configuring the location of the LLM SDK and etc. | string | 1.12.0 |
| kyuubi.engine.data.agent.java.options | <undefined> | The extra Java options for the Data Agent engine | string | 1.12.0 |
| kyuubi.engine.data.agent.jdbc.url | <undefined> | The JDBC URL for the Data Agent engine to connect to the target database. If not set, the Data Agent will connect back to Kyuubi server via ZooKeeper service discovery. | string | 1.12.0 |
diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml
index c34d049360c..43a5008dda9 100644
--- a/externals/kyuubi-data-agent-engine/pom.xml
+++ b/externals/kyuubi-data-agent-engine/pom.xml
@@ -50,45 +50,63 @@
${project.version}
+
com.openai
openai-java
+ ${openai.sdk.version}
+
com.github.victools
jsonschema-generator
+ ${victools.jsonschema.version}
-
com.github.victools
jsonschema-module-jackson
+ ${victools.jsonschema.version}
-
+
- org.apache.kyuubi
- kyuubi-common_${scala.binary.version}
- ${project.version}
- test-jar
- test
+ org.xerial
+ sqlite-jdbc
+ ${sqlite.version}
+
- org.xerial
- sqlite-jdbc
+ com.mysql
+ mysql-connector-j
test
+
- org.testcontainers
- testcontainers-mysql
+ io.trino
+ trino-jdbc
+
+
+
+
+ com.zaxxer
+ HikariCP
+
+
+
+
+ org.apache.kyuubi
+ kyuubi-common_${scala.binary.version}
+ ${project.version}
+ test-jar
test
- com.mysql
- mysql-connector-j
+ org.testcontainers
+ testcontainers-mysql
test
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java
index c3be1dad61a..c771ad222aa 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java
@@ -17,6 +17,12 @@
package org.apache.kyuubi.engine.dataagent.datasource;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.SparkDialect;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.SqliteDialect;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.TrinoDialect;
+
/**
* SQL dialect abstraction for datasource-specific SQL generation.
*
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java
similarity index 92%
rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java
rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java
index 3ea22ed54e3..d8c4512de03 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.kyuubi.engine.dataagent.datasource;
+package org.apache.kyuubi.engine.dataagent.datasource.dialect;
+
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
/**
* Fallback dialect for JDBC subprotocols that have no dedicated implementation. Carries the
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java
similarity index 85%
rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java
rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java
index 98747ffa30c..350789a6a87 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java
@@ -15,12 +15,14 @@
* limitations under the License.
*/
-package org.apache.kyuubi.engine.dataagent.datasource;
+package org.apache.kyuubi.engine.dataagent.datasource.dialect;
+
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
/** MySQL dialect. Uses backtick quoting for identifiers. */
public final class MysqlDialect implements JdbcDialect {
- static final MysqlDialect INSTANCE = new MysqlDialect();
+ public static final MysqlDialect INSTANCE = new MysqlDialect();
private MysqlDialect() {}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java
similarity index 85%
rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java
rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java
index 3adb43fa398..34e20034bfb 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java
@@ -15,12 +15,14 @@
* limitations under the License.
*/
-package org.apache.kyuubi.engine.dataagent.datasource;
+package org.apache.kyuubi.engine.dataagent.datasource.dialect;
+
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
/** Spark SQL dialect. Uses backtick quoting for identifiers. */
public final class SparkDialect implements JdbcDialect {
- static final SparkDialect INSTANCE = new SparkDialect();
+ public static final SparkDialect INSTANCE = new SparkDialect();
private SparkDialect() {}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java
similarity index 85%
rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java
rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java
index a53255a9c67..eb98ca8edfa 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java
@@ -15,12 +15,14 @@
* limitations under the License.
*/
-package org.apache.kyuubi.engine.dataagent.datasource;
+package org.apache.kyuubi.engine.dataagent.datasource.dialect;
+
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
/** SQLite dialect. Uses double-quote quoting for identifiers. */
public final class SqliteDialect implements JdbcDialect {
- static final SqliteDialect INSTANCE = new SqliteDialect();
+ public static final SqliteDialect INSTANCE = new SqliteDialect();
private SqliteDialect() {}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java
similarity index 85%
rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java
rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java
index edacf2f87e2..75fbd4bb242 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java
@@ -15,12 +15,14 @@
* limitations under the License.
*/
-package org.apache.kyuubi.engine.dataagent.datasource;
+package org.apache.kyuubi.engine.dataagent.datasource.dialect;
+
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
/** Trino SQL dialect. Uses double-quote quoting for identifiers. */
public final class TrinoDialect implements JdbcDialect {
- static final TrinoDialect INSTANCE = new TrinoDialect();
+ public static final TrinoDialect INSTANCE = new TrinoDialect();
private TrinoDialect() {}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java
index f4e40b2fae8..26ad8be77fb 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java
@@ -17,13 +17,23 @@
package org.apache.kyuubi.engine.dataagent.provider;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
/**
* User-facing request parameters for a provider-level agent invocation. Only contains fields from
* the caller (question, model override, etc.). Adding new per-request options does not require
* changing the {@link DataAgentProvider} interface.
+ *
+ * The approval mode is accepted as a raw string (natural for config-driven callers) and parsed
+ * into {@link ApprovalMode} by {@link #getApprovalMode()}. Unrecognised values fall back to {@link
+ * ApprovalMode#NORMAL} with a warning.
*/
public class ProviderRunRequest {
+ private static final Logger LOG = LoggerFactory.getLogger(ProviderRunRequest.class);
+
private final String question;
private String modelName;
private String approvalMode;
@@ -45,8 +55,20 @@ public ProviderRunRequest modelName(String modelName) {
return this;
}
- public String getApprovalMode() {
- return approvalMode;
+ /**
+ * Resolved approval mode. Returns {@link ApprovalMode#NORMAL} when the caller did not set one or
+ * supplied an unknown value.
+ */
+ public ApprovalMode getApprovalMode() {
+ if (approvalMode == null || approvalMode.isEmpty()) {
+ return ApprovalMode.NORMAL;
+ }
+ try {
+ return ApprovalMode.valueOf(approvalMode.toUpperCase());
+ } catch (IllegalArgumentException e) {
+ LOG.warn("Unknown approval mode '{}', using default NORMAL", approvalMode);
+ return ApprovalMode.NORMAL;
+ }
}
public ProviderRunRequest approvalMode(String approvalMode) {
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java
new file mode 100644
index 00000000000..bcd647b9326
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.provider.openai;
+
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.zaxxer.hikari.HikariDataSource;
+import java.time.Duration;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+import javax.sql.DataSource;
+import org.apache.kyuubi.config.KyuubiConf;
+import org.apache.kyuubi.config.KyuubiReservedKeys;
+import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory;
+import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
+import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder;
+import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider;
+import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentInvocation;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.apache.kyuubi.engine.dataagent.runtime.ReactAgent;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool;
+import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
+import org.apache.kyuubi.engine.dataagent.util.ConfUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An OpenAI-compatible provider that wires up the full ReactAgent with streaming LLM, tools, and
+ * middleware pipeline. Uses the official OpenAI Java SDK.
+ *
+ *
+ * <p>The ReactAgent, DataSource, and ToolRegistry are shared across all sessions within this engine
+ * instance. Each session only maintains its own {@link ConversationMemory}. This works because each
+ * engine is bound to one user + one datasource, so all sessions within the engine naturally share
+ * the same data connection.
+ */
+public class OpenAiProvider implements DataAgentProvider {
+
+ private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class);
+
+ private final ReactAgent agent;
+ private final ToolRegistry toolRegistry;
+ private final DataSource dataSource;
+ private final OpenAIClient client;
+  private final ConcurrentHashMap<String, ConversationMemory> sessions = new ConcurrentHashMap<>();
+
+ public OpenAiProvider(KyuubiConf conf) {
+ String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY());
+ String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL());
+ String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL());
+
+ int maxIterations = ConfUtils.intConf(conf, KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS());
+ long compactionTriggerTokens =
+ ConfUtils.longConf(conf, KyuubiConf.ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS());
+ int queryTimeoutSeconds =
+ (int) ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_QUERY_TIMEOUT());
+ long toolCallTimeoutSeconds =
+ ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_TOOL_CALL_TIMEOUT());
+
+ this.client =
+ OpenAIOkHttpClient.builder()
+ .apiKey(apiKey)
+ .baseUrl(baseUrl)
+ .maxRetries(3)
+ .timeout(Duration.ofSeconds(180))
+ .build();
+
+ this.toolRegistry = new ToolRegistry(toolCallTimeoutSeconds);
+
+ SystemPromptBuilder promptBuilder = SystemPromptBuilder.create();
+ this.dataSource = attachJdbcDataSource(conf, toolRegistry, promptBuilder, queryTimeoutSeconds);
+
+ this.agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(modelName)
+ .toolRegistry(toolRegistry)
+ .addMiddleware(new ToolResultOffloadMiddleware())
+ .addMiddleware(new LoggingMiddleware())
+ .addMiddleware(new CompactionMiddleware(client, modelName, compactionTriggerTokens))
+ .addMiddleware(new ApprovalMiddleware())
+ .maxIterations(maxIterations)
+ .systemPrompt(promptBuilder.build())
+ .build();
+ }
+
+ /**
+ * Register JDBC-backed SQL tools if a JDBC URL is configured. Returns the created {@link
+ * DataSource} so the provider can close it on shutdown, or {@code null} when no JDBC is wired.
+ */
+ private static DataSource attachJdbcDataSource(
+ KyuubiConf conf,
+ ToolRegistry registry,
+ SystemPromptBuilder promptBuilder,
+ int queryTimeoutSeconds) {
+ String jdbcUrl = ConfUtils.optionalString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL());
+ if (jdbcUrl == null) {
+ return null;
+ }
+ LOG.info("Data Agent JDBC URL configured ({})", jdbcUrl.replaceAll("//.*@", "//@"));
+
+ String sessionUser =
+ ConfUtils.optionalString(conf, KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY());
+
+ DataSource ds = DataSourceFactory.create(jdbcUrl, sessionUser);
+ registry.register(new RunSelectQueryTool(ds, queryTimeoutSeconds));
+ registry.register(new RunMutationQueryTool(ds, queryTimeoutSeconds));
+ promptBuilder.datasource(JdbcDialect.fromUrl(jdbcUrl).datasourceName());
+ return ds;
+ }
+
+ @Override
+ public void open(String sessionId, String user) {
+ sessions.put(sessionId, new ConversationMemory());
+ LOG.info("Opened Data Agent session {} for user {}", sessionId, user);
+ }
+
+ @Override
+  public void run(String sessionId, ProviderRunRequest request, Consumer<AgentEvent> onEvent) {
+ ConversationMemory memory = sessions.get(sessionId);
+ if (memory == null) {
+ throw new IllegalStateException("No open Data Agent session for id=" + sessionId);
+ }
+
+ AgentInvocation invocation =
+ new AgentInvocation(request.getQuestion())
+ .modelName(request.getModelName())
+ .approvalMode(request.getApprovalMode())
+ .sessionId(sessionId);
+ agent.run(invocation, memory, onEvent);
+ }
+
+ @Override
+ public boolean resolveApproval(String requestId, boolean approved) {
+ return agent.resolveApproval(requestId, approved);
+ }
+
+ @Override
+ public void close(String sessionId) {
+ sessions.remove(sessionId);
+ agent.closeSession(sessionId);
+ LOG.info("Closed Data Agent session {}", sessionId);
+ }
+
+ @Override
+ public void stop() {
+ agent.stop();
+ toolRegistry.close();
+ if (dataSource instanceof HikariDataSource) {
+ ((HikariDataSource) dataSource).close();
+ LOG.info("Closed Data Agent connection pool");
+ }
+ client.close();
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java
new file mode 100644
index 00000000000..0695d556058
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import java.util.Objects;
+
+/**
+ * User-facing request parameters for a single agent invocation. Only contains fields that come from
+ * the caller (question, model override, etc.). Framework-level concerns like memory and event
+ * consumer are separate method parameters.
+ *
+ * Adding new per-request options (e.g. temperature, maxTokens) does not require changing the
+ * {@code ReactAgent.run()} signature.
+ */
+public class AgentInvocation {
+
+ private final String userInput;
+ private String modelName;
+ private ApprovalMode approvalMode = ApprovalMode.NORMAL;
+ private String sessionId;
+
+ public AgentInvocation(String userInput) {
+ this.userInput = Objects.requireNonNull(userInput, "userInput must not be null");
+ }
+
+ public String getUserInput() {
+ return userInput;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public AgentInvocation modelName(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ public ApprovalMode getApprovalMode() {
+ return approvalMode;
+ }
+
+ public AgentInvocation approvalMode(ApprovalMode approvalMode) {
+ this.approvalMode = approvalMode;
+ return this;
+ }
+
+ public String getSessionId() {
+ return sessionId;
+ }
+
+ /** Upstream session id, propagated into {@link AgentRunContext#getSessionId()}. */
+ public AgentInvocation sessionId(String sessionId) {
+ this.sessionId = sessionId;
+ return this;
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java
new file mode 100644
index 00000000000..e7c92df8033
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import java.util.function.Consumer;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+
+/**
+ * Mutable context passed through the middleware pipeline and agent loop. Tracks the current state
+ * of agent execution including iteration count, token usage, and custom middleware state.
+ */
+public class AgentRunContext {
+
+ private final ConversationMemory memory;
+ private final String sessionId;
+  private Consumer<AgentEvent> eventEmitter;
+ private int iteration;
+ private long promptTokens;
+ private long completionTokens;
+ private long totalTokens;
+ private ApprovalMode approvalMode;
+
+ public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode) {
+ this(memory, approvalMode, null);
+ }
+
+ public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode, String sessionId) {
+ this.memory = memory;
+ this.iteration = 0;
+ this.approvalMode = approvalMode;
+ this.sessionId = sessionId;
+ }
+
+ public ConversationMemory getMemory() {
+ return memory;
+ }
+
+ /**
+ * The upstream session identifier this run belongs to. Threaded down from {@code
+ * DataAgentProvider.run(sessionId, ...)}. May be {@code null} in unit tests that do not exercise
+ * session-scoped middleware.
+ */
+ public String getSessionId() {
+ return sessionId;
+ }
+
+ public int getIteration() {
+ return iteration;
+ }
+
+ public void setIteration(int iteration) {
+ this.iteration = iteration;
+ }
+
+ public long getPromptTokens() {
+ return promptTokens;
+ }
+
+ public long getCompletionTokens() {
+ return completionTokens;
+ }
+
+ public long getTotalTokens() {
+ return totalTokens;
+ }
+
+ /**
+ * Record one LLM call's usage. Updates both the per-run counters on this context and the
+ * session-level cumulative on the underlying {@link ConversationMemory}, so middlewares that need
+ * a session-wide picture can read it directly from memory without keeping their own bookkeeping.
+ */
+ public void addTokenUsage(long prompt, long completion, long total) {
+ this.promptTokens += prompt;
+ this.completionTokens += completion;
+ this.totalTokens += total;
+ memory.addCumulativeTokens(prompt, completion, total);
+ }
+
+ public ApprovalMode getApprovalMode() {
+ return approvalMode;
+ }
+
+ public void setApprovalMode(ApprovalMode approvalMode) {
+ this.approvalMode = approvalMode;
+ }
+
+  public void setEventEmitter(Consumer<AgentEvent> eventEmitter) {
+ this.eventEmitter = eventEmitter;
+ }
+
+ /** Emit an event through the agent's event pipeline. Available for use by middlewares. */
+ public void emit(AgentEvent event) {
+ if (eventEmitter != null) {
+ eventEmitter.accept(event);
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java
new file mode 100644
index 00000000000..57bc20bc2bb
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+/** Approval modes for tool execution in the Data Agent engine. */
+public enum ApprovalMode {
+ /** All tools require explicit user approval. */
+ STRICT,
+ /** Only non-readonly tools require approval. */
+ NORMAL,
+ /** All tools are auto-approved. */
+ AUTO_APPROVE
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java
new file mode 100644
index 00000000000..0bfae26ec27
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
+import com.openai.models.chat.completions.ChatCompletionToolMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Manages conversation history for a Data Agent session. Ensures tool result messages are never
+ * orphaned from their corresponding AI messages.
+ *
+ * Each instance is session-scoped and accessed sequentially within a single ReAct loop — no
+ * synchronization is needed. Cross-session concurrency is handled by the provider's session map.
+ */
+public class ConversationMemory {
+
+  /**
+   * System prompt prepended to every LLM call built from this memory. Rebuilt by the provider on
+   * each invocation (datasource/tool metadata can change between turns), so it lives outside the
+   * {@link #messages} list rather than being inserted as the first entry. May be {@code null} until
+   * the first {@link #setSystemPrompt} call — {@link #buildLlmMessages()} will simply omit the
+   * system slot in that case.
+   */
+  private String systemPrompt;
+
+  /**
+   * The raw content of the most recent user message added via {@link #addUserMessage}. Cached
+   * separately from {@link #messages} so middleware and callers can recover the current turn's
+   * question after compaction rewrites history. Not used by the LLM call path itself.
+   */
+  private String lastUserInput;
+
+  /**
+   * The ordered conversation history: user / assistant / tool messages, in the order the LLM will
+   * see them. The system prompt is intentionally NOT stored here (see {@link #systemPrompt}).
+   * Mutated in place by {@link #replaceHistory} during compaction; otherwise append-only.
+   */
+  private final List messages = new ArrayList<>();
+
+  /**
+   * Session-level running total of {@code prompt_tokens} reported by every LLM call on this
+   * conversation (across ReAct turns). Intended for billing, quota, and observability — not used by
+   * any runtime decision. Updated via {@link #addCumulativeTokens}.
+   */
+  private long cumulativePromptTokens;
+
+  /**
+   * Session-level running total of {@code completion_tokens}. See {@link #cumulativePromptTokens}.
+   */
+  private long cumulativeCompletionTokens;
+
+  /**
+   * Session-level running total of {@code total_tokens} (prompt + completion as reported by the
+   * provider — not necessarily the sum of the two counters above, since providers may count
+   * cached/reasoning tokens differently). See {@link #cumulativePromptTokens}.
+   */
+  private long cumulativeTotalTokens;
+
+  /**
+   * The {@code total_tokens} reported by the single most recent LLM call, or {@code 0} if no call
+   * has completed yet. Distinct from the cumulative counters: this is a snapshot, overwritten every
+   * call. Used by {@link
+   * org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware} to estimate the
+   * next prompt size (the last response becomes part of the next prompt, so the next call's prompt
+   * is at least {@code lastTotalTokens}). Persists across ReAct turns until the next call
+   * overwrites it.
+   */
+  private long lastTotalTokens;
+
+  public ConversationMemory() {}
+
+  /** Returns the current system prompt, or {@code null} if none has been set yet. */
+  public String getSystemPrompt() {
+    return systemPrompt;
+  }
+
+  /** Replaces the system prompt used by subsequent {@link #buildLlmMessages()} calls. */
+  public void setSystemPrompt(String prompt) {
+    this.systemPrompt = prompt;
+  }
+
+  /**
+   * Appends a user message to history and caches its raw content so it survives compaction (see
+   * {@link #getLastUserInput}).
+   */
+  public void addUserMessage(String content) {
+    this.lastUserInput = content;
+    messages.add(
+        ChatCompletionMessageParam.ofUser(
+            ChatCompletionUserMessageParam.builder().content(content).build()));
+  }
+
+  /** Raw content of the most recent user message, or {@code null} before the first one. */
+  public String getLastUserInput() {
+    return lastUserInput;
+  }
+
+  /** Appends an assistant message (text content and/or tool calls) to history. */
+  public void addAssistantMessage(ChatCompletionAssistantMessageParam message) {
+    messages.add(ChatCompletionMessageParam.ofAssistant(message));
+  }
+
+  /**
+   * Appends a tool-result message. {@code toolCallId} must match a tool call carried by a prior
+   * assistant message, preserving the assistant/tool_result pairing the API requires.
+   */
+  public void addToolResult(String toolCallId, String content) {
+    messages.add(
+        ChatCompletionMessageParam.ofTool(
+            ChatCompletionToolMessageParam.builder()
+                .toolCallId(toolCallId)
+                .content(content)
+                .build()));
+  }
+
+  /**
+   * Build the full message list for LLM API invocation: [system prompt] + conversation history.
+   * The returned list is unmodifiable; later mutations of this memory are not reflected in it.
+   *
+   * No windowing is applied — callers are responsible for managing context length (e.g. via a
+   * token-based truncation strategy).
+   *
+   * @see #getHistory() for history-only access without system prompt
+   */
+  public List buildLlmMessages() {
+    List result = new ArrayList<>(messages.size() + 1);
+    if (systemPrompt != null) {
+      result.add(
+          ChatCompletionMessageParam.ofSystem(
+              ChatCompletionSystemMessageParam.builder().content(systemPrompt).build()));
+    }
+    result.addAll(messages);
+    return Collections.unmodifiableList(result);
+  }
+
+  /**
+   * Returns the conversation history (user, assistant, tool messages) without the system prompt.
+   * Useful for middleware that needs to inspect or compact history. The returned list is an
+   * unmodifiable snapshot (defensive copy) — later mutations of this memory are not reflected.
+   */
+  public List getHistory() {
+    return Collections.unmodifiableList(new ArrayList<>(messages));
+  }
+
+  /**
+   * Replace the conversation history with a compacted version. Useful for context-length management
+   * strategies (e.g., summarizing older messages).
+   *
+   * Also clears {@link #lastTotalTokens}: the prior snapshot referred to a prompt whose bulk we
+   * just discarded, so it no longer describes anything in memory. Leaving it stale would keep the
+   * compaction trigger armed until the next successful LLM call overwrites it — fine on the happy
+   * path, but if that call fails the next turn would re-enter compaction against already-compacted
+   * history. Zeroing means "unknown, wait for the next real usage report". Cumulative totals are
+   * intentionally preserved (session-level accounting, must not regress on internal compaction).
+   */
+  public void replaceHistory(List compacted) {
+    messages.clear();
+    messages.addAll(compacted);
+    this.lastTotalTokens = 0;
+  }
+
+  /**
+   * Drops all conversation history. The system prompt, {@link #lastUserInput}, and token counters
+   * are left untouched.
+   */
+  public void clear() {
+    messages.clear();
+  }
+
+  /** Number of messages currently in history (the system prompt is not counted). */
+  public int size() {
+    return messages.size();
+  }
+
+  /** Session cumulative of reported {@code prompt_tokens}. */
+  public long getCumulativePromptTokens() {
+    return cumulativePromptTokens;
+  }
+
+  /** Session cumulative of reported {@code completion_tokens}. */
+  public long getCumulativeCompletionTokens() {
+    return cumulativeCompletionTokens;
+  }
+
+  /** Session cumulative of reported {@code total_tokens}. */
+  public long getCumulativeTotalTokens() {
+    return cumulativeTotalTokens;
+  }
+
+  /** Snapshot of the most recent call's {@code total_tokens}; {@code 0} before the first call. */
+  public long getLastTotalTokens() {
+    return lastTotalTokens;
+  }
+
+  /**
+   * Add one LLM call's usage to the session cumulative and refresh the {@link #lastTotalTokens}
+   * snapshot. Intended for {@link AgentRunContext}.
+   */
+  public void addCumulativeTokens(long prompt, long completion, long total) {
+    this.cumulativePromptTokens += prompt;
+    this.cumulativeCompletionTokens += completion;
+    this.cumulativeTotalTokens += total;
+    this.lastTotalTokens = total;
+  }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java
new file mode 100644
index 00000000000..520cd963ef8
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java
@@ -0,0 +1,606 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.core.http.StreamResponse;
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionChunk;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
+import com.openai.models.chat.completions.ChatCompletionStreamOptions;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta;
+import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd;
+import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * ReAct (Reasoning + Acting) agent loop using the OpenAI official Java SDK. Iterates through LLM
+ * reasoning, tool execution, and result verification until the agent produces a final answer or
+ * hits the iteration limit.
+ *
+ * Emits {@link AgentEvent}s via the provided consumer for real-time token-level streaming.
+ */
+public class ReactAgent {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ReactAgent.class);
+  private static final ObjectMapper JSON = new ObjectMapper();
+
+  private final OpenAIClient client;
+  private final String defaultModelName;
+  private final ToolRegistry toolRegistry;
+  private final List middlewares;
+  // Cached reference into `middlewares`, or null if no ApprovalMiddleware is configured.
+  private final ApprovalMiddleware approvalMiddleware;
+  private final int maxIterations;
+  private final String systemPrompt;
+
+  public ReactAgent(
+      OpenAIClient client,
+      String modelName,
+      ToolRegistry toolRegistry,
+      List middlewares,
+      int maxIterations,
+      String systemPrompt) {
+    this.client = client;
+    this.defaultModelName = modelName;
+    this.toolRegistry = toolRegistry;
+    this.middlewares = middlewares;
+    this.approvalMiddleware = findApprovalMiddleware(middlewares);
+    this.maxIterations = maxIterations;
+    this.systemPrompt = systemPrompt;
+  }
+
+  /** Returns the first {@link ApprovalMiddleware} in the chain, or {@code null} if none. */
+  private static ApprovalMiddleware findApprovalMiddleware(List middlewares) {
+    for (AgentMiddleware mw : middlewares) {
+      if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw;
+    }
+    return null;
+  }
+
+  /** Resolve a pending approval request. Returns false if no pending request matches. */
+  public boolean resolveApproval(String requestId, boolean approved) {
+    if (approvalMiddleware == null) return false;
+    return approvalMiddleware.resolve(requestId, approved);
+  }
+
+  /** Fan out session-close to every middleware. Errors in one middleware don't block others. */
+  public void closeSession(String sessionId) {
+    for (AgentMiddleware mw : middlewares) {
+      try {
+        mw.onSessionClose(sessionId);
+      } catch (Exception e) {
+        LOG.warn("Middleware onSessionClose error", e);
+      }
+    }
+  }
+
+  /** Fan out engine-stop to every middleware. Errors in one middleware don't block others. */
+  public void stop() {
+    for (AgentMiddleware mw : middlewares) {
+      try {
+        mw.onStop();
+      } catch (Exception e) {
+        LOG.warn("Middleware onStop error", e);
+      }
+    }
+  }
+
+  /**
+   * Run the ReAct loop for the given request.
+   *
+   * @param request user-facing parameters (question, model override, etc.)
+   * @param memory the conversation memory (may contain prior context)
+   * @param eventConsumer callback for each agent event (token-level streaming)
+   */
+  public void run(
+      AgentInvocation request, ConversationMemory memory, Consumer eventConsumer) {
+    String userInput = request.getUserInput();
+    ApprovalMode approvalMode = request.getApprovalMode();
+    String modelNameOverride = request.getModelName();
+
+    // Per-request model override takes precedence over the engine-level default.
+    String effectiveModel =
+        (modelNameOverride != null && !modelNameOverride.isEmpty())
+            ? modelNameOverride
+            : defaultModelName;
+
+    // System prompt is immutable for the lifetime of this agent — only set it on first run
+    // to avoid redundant overwrites on multi-turn conversations.
+    // NOTE(review): ConversationMemory#systemPrompt's javadoc claims the prompt is "rebuilt by
+    // the provider on each invocation"; here it is only set once per session. Confirm which
+    // contract is intended — stale datasource/tool metadata could otherwise persist across turns.
+    if (memory.getSystemPrompt() == null) {
+      memory.setSystemPrompt(systemPrompt);
+    }
+    memory.addUserMessage(userInput);
+
+    AgentRunContext ctx = new AgentRunContext(memory, approvalMode, request.getSessionId());
+    ctx.setEventEmitter(event -> emit(ctx, event, eventConsumer));
+    dispatchAgentStart(ctx);
+    emit(ctx, new AgentStart(), eventConsumer);
+
+    try {
+      for (int step = 1; step <= maxIterations; step++) {
+        ctx.setIteration(step);
+        emit(ctx, new StepStart(step), eventConsumer);
+
+        List messages =
+            resolveMessagesForCall(ctx, memory.buildLlmMessages(), eventConsumer);
+        if (messages == null) {
+          // Middleware asked us to skip — AgentError + AgentFinish have already been emitted.
+          return;
+        }
+
+        StreamResult result = streamLlmResponse(ctx, messages, effectiveModel, eventConsumer);
+        if (result.isEmpty()) {
+          emit(ctx, new AgentError("LLM returned empty response"), eventConsumer);
+          emitFinish(ctx, eventConsumer);
+          return;
+        }
+
+        if (!result.content.isEmpty()) {
+          emit(ctx, new ContentComplete(result.content), eventConsumer);
+        }
+        ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result);
+        memory.addAssistantMessage(assistantMsg);
+        dispatchAfterLlmCall(ctx, assistantMsg);
+
+        if (result.toolCalls == null || result.toolCalls.isEmpty()) {
+          // No tool calls — agent is done.
+          emit(ctx, new StepEnd(step), eventConsumer);
+          emitFinish(ctx, eventConsumer);
+          return;
+        }
+
+        executeToolCalls(ctx, memory, result.toolCalls, eventConsumer);
+        emit(ctx, new StepEnd(step), eventConsumer);
+      }
+
+      emit(
+          ctx, new AgentError("Reached maximum iterations (" + maxIterations + ")"), eventConsumer);
+      emitFinish(ctx, eventConsumer);
+
+    } catch (Exception e) {
+      LOG.error("Agent run failed", e);
+      emit(
+          ctx, new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage()), eventConsumer);
+      emitFinish(ctx, eventConsumer);
+    } finally {
+      // Always runs, even after a throw, so middleware cleanup cannot be skipped.
+      dispatchAgentFinish(ctx);
+    }
+  }
+
+  /**
+   * Run {@code beforeLlmCall} middleware against {@code messages}. Returns the messages to send,
+   * possibly rewritten by middleware, or {@code null} if middleware aborted the call (in which case
+   * this method has already emitted the terminal events).
+   */
+  private List resolveMessagesForCall(
+      AgentRunContext ctx,
+      List messages,
+      Consumer eventConsumer) {
+    AgentMiddleware.LlmCallAction action = dispatchBeforeLlmCall(ctx, messages);
+    if (action instanceof AgentMiddleware.LlmSkip) {
+      String reason = ((AgentMiddleware.LlmSkip) action).reason();
+      emit(ctx, new AgentError("LLM call skipped by middleware: " + reason), eventConsumer);
+      emitFinish(ctx, eventConsumer);
+      return null;
+    }
+    if (action instanceof AgentMiddleware.LlmModifyMessages) {
+      return ((AgentMiddleware.LlmModifyMessages) action).messages();
+    }
+    return messages;
+  }
+
+  /** Materialize the accumulated stream result into the assistant message recorded in memory. */
+  private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) {
+    ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder();
+    if (!result.content.isEmpty()) {
+      b.content(result.content);
+    }
+    if (result.toolCalls != null && !result.toolCalls.isEmpty()) {
+      b.toolCalls(result.toolCalls);
+    }
+    return b.build();
+  }
+
+  /**
+   * Execute the assistant's tool calls in 3 phases:
+   *
+   * 1. Serial: run {@code beforeToolCall} middleware, emit {@link ToolCall} events, and collect
+   *    the calls that survived approval.
+   * 2. Concurrent: fan out to {@link ToolRegistry#submitTool}, which always returns a future
+   *    that completes normally — timeouts and execution errors surface as error strings.
+   * 3. Serial: join futures in order, run {@code afterToolCall}, and record results to memory.
+   */
+  private void executeToolCalls(
+      AgentRunContext ctx,
+      ConversationMemory memory,
+      List toolCalls,
+      Consumer eventConsumer) {
+    List approved = new ArrayList<>();
+    for (ChatCompletionMessageToolCall toolCall : toolCalls) {
+      ChatCompletionMessageFunctionToolCall fnCall = toolCall.asFunction();
+      String toolName = fnCall.function().name();
+      Map toolArgs;
+      try {
+        toolArgs = parseToolArgs(fnCall.function().arguments());
+      } catch (IllegalArgumentException e) {
+        // Malformed JSON from the LLM: record an error tool_result (preserves the
+        // assistant/tool_result pairing the next API call needs) and let the loop self-correct.
+        String err = "Tool call failed: " + e.getMessage();
+        memory.addToolResult(fnCall.id(), err);
+        emit(ctx, new ToolResult(fnCall.id(), toolName, err, true), eventConsumer);
+        continue;
+      }
+
+      AgentMiddleware.ToolCallDenial denial =
+          dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs);
+      if (denial != null) {
+        // Denied calls still get a tool_result so the pairing invariant holds.
+        String denied = "Tool call denied: " + denial.reason();
+        memory.addToolResult(fnCall.id(), denied);
+        emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer);
+        continue;
+      }
+
+      emit(ctx, new ToolCall(fnCall.id(), toolName, toolArgs), eventConsumer);
+      approved.add(new ToolCallEntry(fnCall, toolName, toolArgs));
+    }
+
+    ToolContext toolCtx = new ToolContext(ctx.getSessionId());
+    List> futures = new ArrayList<>(approved.size());
+    for (ToolCallEntry entry : approved) {
+      futures.add(
+          toolRegistry.submitTool(entry.toolName, entry.fnCall.function().arguments(), toolCtx));
+    }
+
+    for (int i = 0; i < approved.size(); i++) {
+      ToolCallEntry entry = approved.get(i);
+      String output = futures.get(i).join();
+      String modified = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, output);
+      if (modified != null) {
+        output = modified;
+      }
+      memory.addToolResult(entry.fnCall.id(), output);
+      emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer);
+    }
+  }
+
+  /** Result of a streaming LLM call, assembled from chunks. */
+  private static class StreamResult {
+    final String content;
+    final List toolCalls;
+
+    StreamResult(String content, List toolCalls) {
+      this.content = content;
+      this.toolCalls = toolCalls;
+    }
+
+    // True when the call produced neither text content nor tool calls.
+    boolean isEmpty() {
+      return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty());
+    }
+  }
+
+  /** Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. */
+  private static class ToolCallEntry {
+    final ChatCompletionMessageFunctionToolCall fnCall;
+    final String toolName;
+    final Map toolArgs;
+
+    ToolCallEntry(
+        ChatCompletionMessageFunctionToolCall fnCall,
+        String toolName,
+        Map toolArgs) {
+      this.fnCall = fnCall;
+      this.toolName = toolName;
+      this.toolArgs = toolArgs;
+    }
+  }
+
+  /**
+   * Stream LLM response, emitting ContentDelta for each text chunk. Assembles tool calls directly
+   * from streamed chunks — no non-streaming fallback. Exceptions propagate to the top-level handler
+   * in {@link #run}.
+   */
+  private StreamResult streamLlmResponse(
+      AgentRunContext ctx,
+      List messages,
+      String effectiveModel,
+      Consumer eventConsumer) {
+    ChatCompletionCreateParams.Builder paramsBuilder =
+        ChatCompletionCreateParams.builder()
+            .model(effectiveModel)
+            // includeUsage makes the provider append a usage chunk to the stream.
+            .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build());
+    for (ChatCompletionMessageParam msg : messages) {
+      paramsBuilder.addMessage(msg);
+    }
+    toolRegistry.addToolsTo(paramsBuilder);
+
+    LOG.info("LLM request: model={}", effectiveModel);
+    StreamAccumulator acc = new StreamAccumulator();
+    try (StreamResponse stream =
+        client.chat().completions().createStreaming(paramsBuilder.build())) {
+      stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc, eventConsumer));
+    }
+    return new StreamResult(acc.content.toString(), acc.buildToolCalls());
+  }
+
+  /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. */
+  private void consumeChunk(
+      AgentRunContext ctx,
+      ChatCompletionChunk chunk,
+      StreamAccumulator acc,
+      Consumer eventConsumer) {
+    // Log the server-echoed model once per call (it may differ from the requested alias).
+    if (!acc.serverModelLogged) {
+      LOG.info("LLM response: server-echoed model={}", chunk.model());
+      acc.serverModelLogged = true;
+    }
+    chunk
+        .usage()
+        .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens()));
+
+    for (ChatCompletionChunk.Choice c : chunk.choices()) {
+      c.delta()
+          .content()
+          .ifPresent(
+              text -> {
+                acc.content.append(text);
+                emit(ctx, new ContentDelta(text), eventConsumer);
+              });
+      c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas);
+    }
+  }
+
+  /**
+   * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's
+   * {@code index} because provider SDKs may deliver a single logical call across multiple chunks
+   * and only surface the {@code id}/{@code name} on the first one.
+   */
+  private static final class StreamAccumulator {
+    final StringBuilder content = new StringBuilder();
+    final Map toolCallIds = new HashMap<>();
+    final Map toolCallNames = new HashMap<>();
+    final Map toolCallArgs = new HashMap<>();
+    boolean serverModelLogged = false;
+
+    void mergeToolCallDeltas(List deltas) {
+      for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) {
+        int idx = (int) tc.index();
+        tc.id().ifPresent(id -> toolCallIds.put(idx, id));
+        tc.function()
+            .ifPresent(
+                fn -> {
+                  fn.name().ifPresent(name -> toolCallNames.put(idx, name));
+                  fn.arguments()
+                      .ifPresent(
+                          args ->
+                              toolCallArgs
+                                  .computeIfAbsent(idx, k -> new StringBuilder())
+                                  .append(args));
+                });
+      }
+    }
+
+    /**
+     * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty
+     * list) if no tool calls were seen, matching the existing {@link StreamResult} contract.
+     *
+     * NOTE(review): iterating a HashMap's entrySet does not guarantee ascending {@code index}
+     * order, so a turn with multiple tool calls may be materialized out of stream order — TODO
+     * confirm whether downstream (assistant message / tool_result pairing by id) tolerates this,
+     * or iterate sorted indices instead.
+     *
+     * NOTE(review): an index whose {@code id} never arrived on any chunk (absent Optional, not
+     * empty string) never enters {@code toolCallIds} and is silently dropped here; the
+     * {@link #synthId} fallback only covers present-but-empty ids — confirm all target providers
+     * always send some id field.
+     */
+    List buildToolCalls() {
+      if (toolCallIds.isEmpty()) return null;
+      List out = new ArrayList<>(toolCallIds.size());
+      for (Map.Entry e : toolCallIds.entrySet()) {
+        int idx = e.getKey();
+        String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue();
+        String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}";
+        out.add(
+            ChatCompletionMessageToolCall.ofFunction(
+                ChatCompletionMessageFunctionToolCall.builder()
+                    .id(id)
+                    .function(
+                        ChatCompletionMessageFunctionToolCall.Function.builder()
+                            .name(toolCallNames.getOrDefault(idx, ""))
+                            .arguments(args)
+                            .build())
+                    .build()));
+      }
+      return out;
+    }
+
+    /**
+     * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible
+     * providers omit it). The id has to be stable within a turn and unique across turns so the
+     * assistant/tool_result pairing downstream holds.
+     */
+    private static String synthId() {
+      return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24);
+    }
+  }
+
+  /**
+   * Parse a tool-call arguments JSON object into a map. Null/empty input yields an empty map;
+   * malformed JSON is rethrown as {@link IllegalArgumentException} for the caller to record as an
+   * error tool_result.
+   */
+  private static Map parseToolArgs(String json) {
+    if (json == null || json.isEmpty()) {
+      return new HashMap<>();
+    }
+    try {
+      return JSON.readValue(json, new TypeReference>() {});
+    } catch (java.io.IOException e) {
+      throw new IllegalArgumentException("Malformed tool-call arguments JSON: " + json, e);
+    }
+  }
+
+  // --- Middleware dispatch methods ---
+  //
+  // Middlewares are internal framework code. If one throws, the agent run fails via the
+  // top-level catch in run() — we do not wrap individual dispatch calls in try/catch.
+
+  /** Emit the terminal {@link AgentFinish} event carrying iteration and token-usage totals. */
+  private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) {
+    emit(
+        ctx,
+        new AgentFinish(
+            ctx.getIteration(),
+            ctx.getPromptTokens(),
+            ctx.getCompletionTokens(),
+            ctx.getTotalTokens()),
+        eventConsumer);
+  }
+
+  /** Pass an event through every middleware's onEvent filter; a null return swallows it. */
+  private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) {
+    AgentEvent filtered = event;
+    for (AgentMiddleware mw : middlewares) {
+      filtered = mw.onEvent(ctx, filtered);
+      if (filtered == null) return;
+    }
+    consumer.accept(filtered);
+  }
+
+  private void dispatchAgentStart(AgentRunContext ctx) {
+    for (AgentMiddleware mw : middlewares) {
+      mw.onAgentStart(ctx);
+    }
+  }
+
+  private void dispatchAgentFinish(AgentRunContext ctx) {
+    // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup
+    // gets a chance to run; otherwise we'd leak session state in later middlewares.
+    // Reverse order mirrors the onAgentStart order, like unwinding a stack.
+    for (int i = middlewares.size() - 1; i >= 0; i--) {
+      try {
+        middlewares.get(i).onAgentFinish(ctx);
+      } catch (Exception e) {
+        LOG.warn("Middleware onAgentFinish error", e);
+      }
+    }
+  }
+
+  /** First middleware to return a non-null action wins; later middlewares are not consulted. */
+  private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall(
+      AgentRunContext ctx, List messages) {
+    for (AgentMiddleware mw : middlewares) {
+      AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages);
+      if (action != null) return action;
+    }
+    return null;
+  }
+
+  /** After-hooks run in reverse registration order (onion model). */
+  private void dispatchAfterLlmCall(
+      AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {
+    for (int i = middlewares.size() - 1; i >= 0; i--) {
+      middlewares.get(i).afterLlmCall(ctx, response);
+    }
+  }
+
+  /** First middleware to return a non-null denial wins; later middlewares are not consulted. */
+  private AgentMiddleware.ToolCallDenial dispatchBeforeToolCall(
+      AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) {
+    for (AgentMiddleware mw : middlewares) {
+      AgentMiddleware.ToolCallDenial denial =
+          mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs);
+      if (denial != null) return denial;
+    }
+    return null;
+  }
+
+  /**
+   * Chain afterToolCall in reverse order, threading each middleware's rewrite into the next.
+   * Returns the final rewritten result, or {@code null} if no middleware modified it.
+   */
+  private String dispatchAfterToolCall(
+      AgentRunContext ctx, String toolName, Map toolArgs, String result) {
+    String modified = null;
+    for (int i = middlewares.size() - 1; i >= 0; i--) {
+      String mwResult =
+          middlewares
+              .get(i)
+              .afterToolCall(ctx, toolName, toolArgs, modified != null ? modified : result);
+      if (mwResult != null) {
+        modified = mwResult;
+      }
+    }
+    return modified;
+  }
+
+  // --- Builder ---
+
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  /** Fluent builder for {@link ReactAgent}; {@code client} and {@code modelName} are required. */
+  public static class Builder {
+    private OpenAIClient client;
+    private String modelName;
+    private ToolRegistry toolRegistry = new ToolRegistry(ToolRegistry.DEFAULT_TIMEOUT_SECONDS);
+    private final List middlewares = new ArrayList<>();
+    private int maxIterations = 20;
+    private String systemPrompt;
+
+    public Builder client(OpenAIClient client) {
+      this.client = client;
+      return this;
+    }
+
+    public Builder modelName(String modelName) {
+      this.modelName = modelName;
+      return this;
+    }
+
+    public Builder toolRegistry(ToolRegistry toolRegistry) {
+      this.toolRegistry = toolRegistry;
+      return this;
+    }
+
+    /** Middlewares run in registration order for before-hooks, reverse order for after-hooks. */
+    public Builder addMiddleware(AgentMiddleware middleware) {
+      this.middlewares.add(middleware);
+      return this;
+    }
+
+    public Builder maxIterations(int maxIterations) {
+      if (maxIterations < 1) {
+        throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations);
+      }
+      this.maxIterations = maxIterations;
+      return this;
+    }
+
+    public Builder systemPrompt(String systemPrompt) {
+      this.systemPrompt = systemPrompt;
+      return this;
+    }
+
+    /**
+     * Validate required fields, let every middleware register its tools, and construct the agent.
+     * NOTE(review): onRegister fires on every build() call — confirm a Builder is never reused,
+     * otherwise middlewares would register tools twice into the same registry.
+     */
+    public ReactAgent build() {
+      if (client == null) throw new IllegalStateException("client is required");
+      if (modelName == null) throw new IllegalStateException("modelName is required");
+      if (toolRegistry == null) throw new IllegalStateException("toolRegistry is required");
+      for (AgentMiddleware mw : middlewares) {
+        mw.onRegister(toolRegistry);
+      }
+      return new ReactAgent(
+          client, modelName, toolRegistry, middlewares, maxIterations, systemPrompt);
+    }
+  }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java
new file mode 100644
index 00000000000..d4a8a97e97d
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.regex.Pattern;
+import java.util.regex.PatternSyntaxException;
+import java.util.stream.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Per-session temp-file store for large tool outputs that the input gate has offloaded from
+ * conversation history.
+ *
+ * Layout: {@code //tool_.txt}, where {@code } is a
+ * per-engine randomly-named temp directory (see {@link #create()}). UTF-8 text.
+ *
+ * Isolation: every {@code read}/{@code grep} call requires the current session id and
+ * validates that the caller-supplied path resolves under {@code /}, not merely
+ * under {@code }. A session cannot read another session's offloaded output even if it somehow
+ * obtained the absolute path. The per-engine random root additionally isolates different engine
+ * processes sharing the same host.
+ *
+ * Failures (traversal, missing file, session-id mismatch, IO) are reported as error strings
+ * rather than thrown — the tool must keep the agent loop alive.
+ */
+public class ToolOutputStore implements AutoCloseable {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ToolOutputStore.class);
+ private static final String ROOT_PREFIX = "kyuubi-data-agent-";
+
+ private final Path root;
+
+ /**
+ * Create a store backed by a fresh, per-engine random temp directory under {@code
+ * java.io.tmpdir}.
+ */
+ public static ToolOutputStore create() {
+ try {
+ return new ToolOutputStore(Files.createTempDirectory(ROOT_PREFIX));
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to create ToolOutputStore temp root", e);
+ }
+ }
+
+ private ToolOutputStore(Path root) {
+ try {
+ Files.createDirectories(root);
+ this.root = root.toRealPath();
+ } catch (IOException e) {
+ throw new IllegalStateException("Failed to initialize ToolOutputStore root: " + root, e);
+ }
+ }
+
+ public Path getRoot() {
+ return root;
+ }
+
+ /** Write {@code content} to {@code //tool_.txt}. */
+ public Path write(String sessionId, String toolCallId, String content) throws IOException {
+ Path dir = root.resolve(safeSegment(sessionId));
+ Files.createDirectories(dir);
+ Path file = dir.resolve("tool_" + safeSegment(toolCallId) + ".txt");
+ Files.write(file, content.getBytes(StandardCharsets.UTF_8));
+ return file;
+ }
+
+ /**
+ * Read a line window. Returns a human-readable block including a 1-based {@code [lines X-Y of Z
+ * total]} header, or an error string on traversal / IO failure / cross-session access.
+ */
+ public String read(String sessionId, String pathStr, long offset, int limit) {
+ Path file = validatePath(sessionId, pathStr);
+ if (file == null) {
+ return "Error: path is outside this session's tool-output directory or does not exist: "
+ + pathStr;
+ }
+ if (offset < 0) offset = 0;
+ if (limit <= 0) limit = 1;
+
+ List taken = new ArrayList<>(limit);
+ long totalLines = 0;
+ try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
+ String line;
+ while ((line = br.readLine()) != null) {
+ if (totalLines >= offset && taken.size() < limit) {
+ taken.add(line);
+ }
+ totalLines++;
+ }
+ } catch (IOException e) {
+ return "Error reading " + pathStr + ": " + e.getMessage();
+ }
+
+ long fromLine = offset + 1; // 1-based
+ long toLineExclusive = Math.min(offset + limit, totalLines);
+ StringBuilder sb = new StringBuilder();
+ sb.append("[lines ")
+ .append(fromLine)
+ .append("-")
+ .append(toLineExclusive)
+ .append(" of ")
+ .append(totalLines)
+ .append(" total]\n");
+ for (String line : taken) {
+ sb.append(line).append('\n');
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Stream-grep the file. Returns at most {@code maxMatches} matches as {@code lineNo:content}, one
+ * per line; or an error string on traversal / regex / IO failure / cross-session access.
+ */
+ public String grep(String sessionId, String pathStr, String patternStr, int maxMatches) {
+ Path file = validatePath(sessionId, pathStr);
+ if (file == null) {
+ return "Error: path is outside this session's tool-output directory or does not exist: "
+ + pathStr;
+ }
+ if (patternStr == null || patternStr.isEmpty()) {
+ return "Error: 'pattern' parameter is required.";
+ }
+ if (maxMatches <= 0) maxMatches = 50;
+
+ Pattern pattern;
+ try {
+ pattern = Pattern.compile(patternStr);
+ } catch (PatternSyntaxException e) {
+ return "Error: invalid regex pattern: " + e.getMessage();
+ }
+
+ StringBuilder sb = new StringBuilder();
+ int matches = 0;
+ long lineNo = 0;
+ try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
+ String line;
+ while ((line = br.readLine()) != null) {
+ lineNo++;
+ if (pattern.matcher(line).find()) {
+ sb.append(lineNo).append(':').append(line).append('\n');
+ matches++;
+ if (matches >= maxMatches) break;
+ }
+ }
+ } catch (IOException e) {
+ return "Error reading " + pathStr + ": " + e.getMessage();
+ }
+ if (matches == 0) {
+ return "[no matches for pattern: " + patternStr + "]";
+ }
+ return "[" + matches + " match" + (matches == 1 ? "" : "es") + "]\n" + sb;
+ }
+
+ /** Recursively delete the session's subtree. Safe to call on missing sessions. */
+ public void cleanupSession(String sessionId) {
+ if (sessionId == null) return;
+ Path dir = root.resolve(safeSegment(sessionId));
+ deleteTree(dir);
+ }
+
+ /** Delete everything below (and including) the root. Idempotent; safe to call multiple times. */
+ @Override
+ public void close() {
+ deleteTree(root);
+ }
+
+ private static void deleteTree(Path dir) {
+ if (!Files.exists(dir)) return;
+ try (Stream stream = Files.walk(dir)) {
+ stream.sorted(Comparator.reverseOrder()).forEach(ToolOutputStore::deleteQuietly);
+ } catch (IOException e) {
+ LOG.warn("Failed to clean up dir {}", dir, e);
+ }
+ }
+
+ private static void deleteQuietly(Path p) {
+ try {
+ Files.deleteIfExists(p);
+ } catch (IOException e) {
+ LOG.debug("Failed to delete {}", p, e);
+ }
+ }
+
+ /**
+ * Resolve {@code pathStr} and return it only if (a) it exists as a regular file and (b) the real
+ * path is under {@code /}. Returns null on any violation — including a null or
+ * empty session id, since without one we cannot scope the check.
+ */
+ private Path validatePath(String sessionId, String pathStr) {
+ if (pathStr == null || pathStr.isEmpty()) return null;
+ if (sessionId == null || sessionId.isEmpty()) return null;
+ Path sessionRoot = root.resolve(safeSegment(sessionId));
+ try {
+ Path real = Paths.get(pathStr).toRealPath();
+ if (!real.startsWith(sessionRoot)) return null;
+ if (!Files.isRegularFile(real)) return null;
+ return real;
+ } catch (IOException | SecurityException e) {
+ return null;
+ }
+ }
+
+ /** Strip anything that could escape a single path segment. */
+ private static String safeSegment(String raw) {
+ if (raw == null || raw.isEmpty()) return "_";
+ StringBuilder sb = new StringBuilder(raw.length());
+ for (int i = 0; i < raw.length(); i++) {
+ char c = raw.charAt(i);
+ if (Character.isLetterOrDigit(c) || c == '-' || c == '_' || c == '.') {
+ sb.append(c);
+ } else {
+ sb.append('_');
+ }
+ }
+ return sb.toString();
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java
new file mode 100644
index 00000000000..26eb6024e98
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.event;
+
+/**
+ * Observational event emitted by {@code CompactionMiddleware} once it has summarized a prefix of
+ * the conversation history and swapped it out in memory. The LLM call that immediately follows
+ * already sees the compacted history, so consumers treat this purely as a side-channel notice
+ * that compaction took place. The summary text is deliberately omitted: it can be large and
+ * would bloat the event stream; operators who need it can consult the middleware log.
+ */
+public final class Compaction extends AgentEvent {
+  private final int summarizedCount;
+  private final int keptCount;
+  private final long triggerTokens;
+  private final long observedTokens;
+
+  public Compaction(int summarizedCount, int keptCount, long triggerTokens, long observedTokens) {
+    super(EventType.COMPACTION);
+    this.summarizedCount = summarizedCount;
+    this.keptCount = keptCount;
+    this.triggerTokens = triggerTokens;
+    this.observedTokens = observedTokens;
+  }
+
+  /** Number of messages folded into the summary. */
+  public int summarizedCount() {
+    return summarizedCount;
+  }
+
+  /** Number of recent messages kept verbatim. */
+  public int keptCount() {
+    return keptCount;
+  }
+
+  /** Configured token threshold that triggered compaction. */
+  public long triggerTokens() {
+    return triggerTokens;
+  }
+
+  /** Predicted prompt-token count observed at trigger time. */
+  public long observedTokens() {
+    return observedTokens;
+  }
+
+  @Override
+  public String toString() {
+    return String.format(
+        "Compaction{summarized=%d, kept=%d, triggerTokens=%d, observedTokens=%d}",
+        summarizedCount, keptCount, triggerTokens, observedTokens);
+  }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java
index 937422e2bf5..d58e5de2ee7 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java
@@ -21,6 +21,8 @@
* Enumerates the types of events emitted by the ReAct agent loop. Each value maps to a
* corresponding {@link AgentEvent} subclass and carries an SSE event name used for wire
* serialization.
+ *
+ * @see org.apache.kyuubi.engine.dataagent.runtime.ReactAgent
*/
public enum EventType {
@@ -51,6 +53,9 @@ public enum EventType {
/** The agent requires user approval before executing a tool. */
APPROVAL_REQUEST("approval_request"),
+ /** The conversation history was compacted by the compaction middleware. */
+ COMPACTION("compaction"),
+
/** The agent has finished its analysis. */
AGENT_FINISH("agent_finish");
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java
new file mode 100644
index 00000000000..c934bb0882a
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import java.util.List;
+import java.util.Map;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+
+/**
+ * Middleware interface for the Data Agent ReAct loop. Middlewares are executed in onion-model
+ * order: before_* hooks run first-to-last, after_* hooks run last-to-first.
+ *
+ * All hooks have default no-op implementations. Override only what you need.
+ */
+public interface AgentMiddleware {
+
+ /**
+ * Called once when the middleware is wired into the agent. Register companion tools that are part
+ * of the middleware's contract, or capture a reference to the registry for later use. Dispatched
+ * by {@code ReactAgent.Builder.build} before the agent accepts any requests.
+ */
+ default void onRegister(ToolRegistry registry) {}
+
+ /** Called when the agent starts processing a user query. Runs first-to-last. */
+ default void onAgentStart(AgentRunContext ctx) {}
+
+ /** Called when the agent finishes. Runs last-to-first (cleanup order). */
+ default void onAgentFinish(AgentRunContext ctx) {}
+
+ /**
+ * Called before each LLM invocation. Return non-null to skip or modify the LLM call. Runs
+ * first-to-last.
+ *
+ * @return {@code null} to proceed normally, {@link LlmSkip} to abort, or {@link
+ * LlmModifyMessages} to replace the message list for this call.
+ */
+ default LlmCallAction beforeLlmCall(
+ AgentRunContext ctx, List messages) {
+ return null;
+ }
+
+ /** Called after each LLM invocation. Runs last-to-first. */
+ default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {}
+
+ /** Called before each tool execution. Return non-null to deny the call. Runs first-to-last. */
+ default ToolCallDenial beforeToolCall(
+ AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) {
+ return null;
+ }
+
+ /**
+ * Called after each tool execution. Runs last-to-first.
+ *
+ * Returns {@code String} (not {@code void}) so that middlewares can intercept and transform
+ * the tool result before it is fed back to the LLM — e.g. for data masking, output truncation, or
+ * injecting metadata. Return {@code null} to keep the original result unchanged; return a
+ * non-null value to replace it.
+ */
+ default String afterToolCall(
+ AgentRunContext ctx, String toolName, Map toolArgs, String result) {
+ return null;
+ }
+
+ /**
+ * Called for every event before it is emitted. Return null to suppress the event. Runs
+ * first-to-last.
+ */
+ default AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) {
+ return event;
+ }
+
+ /**
+ * Called when a session is closed. Clean up per-session state (scratch files, pending tasks,
+ * counters). Idempotent. Dispatched by {@code ReactAgent.closeSession}.
+ */
+ default void onSessionClose(String sessionId) {}
+
+ /**
+ * Called when the engine is stopping. Release global resources and unblock any threads still
+ * waiting on this middleware. Dispatched by {@code ReactAgent.stop}.
+ */
+ default void onStop() {}
+
+ /**
+ * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmSkip} to abort the LLM
+ * call, {@link LlmModifyMessages} to replace the message list for this call.
+ */
+ abstract class LlmCallAction {
+ private LlmCallAction() {}
+ }
+
+ /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */
+ class LlmSkip extends LlmCallAction {
+ private final String reason;
+
+ public LlmSkip(String reason) {
+ this.reason = reason;
+ }
+
+ public String reason() {
+ return reason;
+ }
+ }
+
+ /**
+ * Returned from {@code beforeLlmCall} to replace the message list for this LLM invocation. The
+ * agent loop continues normally with the modified messages.
+ */
+ class LlmModifyMessages extends LlmCallAction {
+ private final List messages;
+
+ public LlmModifyMessages(List messages) {
+ this.messages = messages;
+ }
+
+ public List messages() {
+ return messages;
+ }
+ }
+
+ /** Returned from {@code beforeToolCall} to deny a tool call. Non-null means denied. */
+ class ToolCallDenial {
+ private final String reason;
+
+ public ToolCallDenial(String reason) {
+ this.reason = reason;
+ }
+
+ public String reason() {
+ return reason;
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java
new file mode 100644
index 00000000000..92d25b47b9d
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Middleware that enforces human-in-the-loop approval for tool calls based on the {@link
+ * ApprovalMode} and the tool's {@link ToolRiskLevel}.
+ *
+ * <p>When approval is required, an {@link ApprovalRequest} event is emitted to the client via
+ * {@link AgentRunContext#emit}, and the agent thread blocks until the client responds via {@link
+ * #resolve} or the timeout expires.
+ */
+public class ApprovalMiddleware implements AgentMiddleware {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ApprovalMiddleware.class);
+
+  private static final long DEFAULT_TIMEOUT_SECONDS = 300; // 5 minutes
+
+  private final long timeoutSeconds;
+  // NOTE(review): restored type arguments lost in patch extraction (was a raw
+  // ConcurrentHashMap); keys are request ids, values complete with the approval verdict.
+  private final ConcurrentHashMap<String, CompletableFuture<Boolean>> pending =
+      new ConcurrentHashMap<>();
+  private ToolRegistry toolRegistry;
+
+  public ApprovalMiddleware() {
+    this(DEFAULT_TIMEOUT_SECONDS);
+  }
+
+  public ApprovalMiddleware(long timeoutSeconds) {
+    this.timeoutSeconds = timeoutSeconds;
+  }
+
+  @Override
+  public void onRegister(ToolRegistry registry) {
+    this.toolRegistry = registry;
+  }
+
+  @Override
+  public ToolCallDenial beforeToolCall(
+      AgentRunContext ctx, String toolCallId, String toolName, Map<String, Object> toolArgs) {
+    ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName);
+
+    if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) {
+      return null;
+    }
+
+    String requestId = UUID.randomUUID().toString();
+    CompletableFuture<Boolean> future = new CompletableFuture<>();
+    pending.put(requestId, future);
+
+    ctx.emit(new ApprovalRequest(requestId, toolCallId, toolName, toolArgs, riskLevel));
+    LOG.info("Approval requested for tool '{}' (requestId={})", toolName, requestId);
+
+    try {
+      boolean approved = future.get(timeoutSeconds, TimeUnit.SECONDS);
+      if (!approved) {
+        LOG.info("Tool '{}' denied by user (requestId={})", toolName, requestId);
+        return new ToolCallDenial("User denied execution of " + toolName);
+      }
+      LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId);
+      return null;
+    } catch (TimeoutException e) {
+      // Complete the future so that a late resolve() call is a harmless no-op
+      // instead of completing a dangling future.
+      future.completeExceptionally(e);
+      LOG.warn("Approval timed out for tool '{}' (requestId={})", toolName, requestId);
+      return new ToolCallDenial("Approval timed out after " + timeoutSeconds + "s for " + toolName);
+    } catch (InterruptedException e) {
+      // Preserve the interrupt flag for the agent loop's own shutdown handling.
+      Thread.currentThread().interrupt();
+      return new ToolCallDenial("Approval interrupted for " + toolName);
+    } catch (Exception e) {
+      LOG.error("Unexpected error waiting for approval", e);
+      return new ToolCallDenial("Approval error: " + e.getMessage());
+    } finally {
+      pending.remove(requestId);
+    }
+  }
+
+  /**
+   * Resolve a pending approval request. Called by the external approval channel (e.g. a Kyuubi
+   * operation or REST endpoint).
+   *
+   * @param requestId the request ID from the {@link ApprovalRequest} event
+   * @param approved true to approve, false to deny
+   * @return true if the request was found and resolved, false if not found (already timed out or
+   *     invalid ID)
+   */
+  public boolean resolve(String requestId, boolean approved) {
+    CompletableFuture<Boolean> future = pending.get(requestId);
+    if (future != null) {
+      return future.complete(approved);
+    }
+    LOG.warn("No pending approval found for requestId={}", requestId);
+    return false;
+  }
+
+  /**
+   * Cancel all pending approval requests to unblock any waiting agent threads. Invoked as part of
+   * engine shutdown via {@code ReactAgent.stop}.
+   */
+  @Override
+  public void onStop() {
+    InterruptedException ex = new InterruptedException("Session closed");
+    // Long.MAX_VALUE parallelism threshold => sequential traversal; remove-then-complete keeps
+    // the map clean even if a resolve() races with shutdown.
+    pending.forEachKey(
+        Long.MAX_VALUE,
+        key -> {
+          CompletableFuture<Boolean> future = pending.remove(key);
+          if (future != null) {
+            future.completeExceptionally(ex);
+          }
+        });
+  }
+
+  /** AUTO_APPROVE passes everything; NORMAL passes only SAFE tools; STRICT passes nothing. */
+  private static boolean shouldAutoApprove(ApprovalMode mode, ToolRiskLevel riskLevel) {
+    if (mode == ApprovalMode.AUTO_APPROVE) {
+      return true;
+    }
+    if (mode == ApprovalMode.NORMAL && riskLevel == ToolRiskLevel.SAFE) {
+      return true;
+    }
+    // STRICT: all tools require approval
+    return false;
+  }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java
new file mode 100644
index 00000000000..acab7ae8b75
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import com.openai.client.OpenAIClient;
+import com.openai.models.chat.completions.ChatCompletion;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
+import com.openai.models.chat.completions.ChatCompletionSystemMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.apache.kyuubi.engine.dataagent.runtime.event.Compaction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Middleware that compacts conversation history when the prompt grows large.
+ *
+ * Trigger formula:
+ *
+ *
+ * predicted_this_turn_prompt_tokens
+ * = last_llm_call_total_tokens // prompt + completion of the previous call
+ * + estimate_new_tail(messages) // chars / 4, for content appended after the
+ * // last assistant message (tool results, new user)
+ *
+ *
+ * History (already sent to the LLM) must use real token counts — we read them straight off {@link
+ * ConversationMemory#getLastTotalTokens()}, which the provider updates after every call. We use
+ * total (prompt + completion) rather than just prompt because the previous call's completion
+ * — e.g. an assistant {@code tool_call} message — is already appended to history and will be
+ * tokenized into the next prompt; the tail estimator starts strictly after the last
+ * assistant message, so without completion we would miss it. Only content appended since that
+ * assistant message is estimated, which catches spikes before the next API report arrives.
+ *
+ * Post-compaction persistence: the summary + kept tail replace {@link
+ * ConversationMemory}'s history via {@link ConversationMemory#replaceHistory}. All subsequent turns
+ * in this session read the compacted history — we do not re-summarize each turn. Because the next
+ * LLM call uses the compacted messages, its reported {@code prompt_tokens} will be small, naturally
+ * preventing retriggering.
+ *
+ *
+ * <p>Thread safety / shared instance: Instances of this middleware are shared across all
+ * sessions inside a provider (see {@code OpenAiProvider} javadoc). All per-session state
+ * (cumulative and last-call totals) lives on {@link ConversationMemory}, so this middleware itself
+ * is stateless across sessions and requires no per-session cleanup.
+ *
+ *
+ * <p>Tool-call pair invariant: The split never separates an assistant message bearing {@code
+ * tool_calls} from the {@code tool_result} messages that satisfy those calls — an orphan
+ * tool_result is rejected by the OpenAI API with HTTP 400.
+ *
+ *
+ * <p>Failure handling: summarizer failures propagate out of {@code beforeLlmCall} to the
+ * agent's top-level catch, which surfaces an {@code AgentError} event. We don't silently skip
+ * compaction — a broken summarizer is a real problem the operator needs to see.
+ *
+ *
+ * <p>Disabling: to effectively turn compaction off, construct with a very large {@code
+ * triggerPromptTokens} (e.g., {@link Long#MAX_VALUE}).
+ *
+ *
+ * <p>TODO: support a separate (cheaper) summarization model distinct from the main agent model.
+ */
+public class CompactionMiddleware implements AgentMiddleware {
+
+ private static final Logger LOG = LoggerFactory.getLogger(CompactionMiddleware.class);
+
+ /** Number of recent user turns (and their assistant/tool companions) to preserve verbatim. */
+ private static final int KEEP_RECENT_TURNS = 4;
+
+ private static final String COMPACTION_SYSTEM_PROMPT =
+ "SYSTEM OPERATION — this is an automated context compaction step, NOT a user message.\n"
+ + "\n"
+ + "You are summarizing a conversation between a user and a Data Agent (a ReAct agent"
+ + " that executes SQL and analytics tools against Kyuubi). The user has NOT asked you"
+ + " to summarize anything. Do not address the user. Do not ask questions. Produce only"
+ + " the summary in the schema below.\n"
+ + "\n"
+ + "Goal: produce a dense, structured summary the agent can resume from without losing"
+ + " critical context. Preserve concrete details verbatim — file paths, table names,"
+ + " schema definitions, SQL snippets, column names, error messages.\n"
+ + "\n"
+ + "Output EXACTLY these 8 sections, in order, as markdown headers:\n"
+ + "\n"
+ + "1. ## User Intent\n"
+ + " The user's original request, restated in full. Preserve the literal phrasing of"
+ + " the ask. Include follow-up refinements.\n"
+ + "\n"
+ + "2. ## Key Concepts\n"
+ + " Domain terms, data sources, tables, schemas, SQL dialects, and business logic"
+ + " the agent has been reasoning about.\n"
+ + "\n"
+ + "3. ## Files and Code\n"
+ + " File paths, query text, DDL, or code artifacts referenced. Include verbatim SQL"
+ + " snippets that produced meaningful results.\n"
+ + "\n"
+ + "4. ## Errors and Recoveries\n"
+ + " Errors encountered (SQL syntax, permission, timeout, tool failures), what was"
+ + " tried, and what resolved them. Preserve error messages verbatim.\n"
+ + "\n"
+ + "5. ## Pending Work\n"
+ + " Tasks the agent identified but has not completed yet.\n"
+ + "\n"
+ + "6. ## Current State\n"
+ + " Where the agent is right now — what question is open, what data has been"
+ + " retrieved, what hypothesis is being tested.\n"
+ + "\n"
+ + "7. ## Next Step\n"
+ + " The immediate next action the agent should take when resuming.\n"
+ + "\n"
+ + "8. ## Tool Usage Summary\n"
+ + " Which tools were called, how many times, and notable results.\n"
+ + "\n"
+ + "CRITICAL:\n"
+ + "- DO NOT ask the user about this summary.\n"
+ + "- DO NOT mention that compaction occurred in any future assistant response.\n"
+ + "- DO NOT invent details not present in the conversation.\n"
+ + "- DO NOT output anything outside the 8 sections.\n";
+
+ private final OpenAIClient client;
+ private final String summarizerModel;
+ private final long triggerPromptTokens;
+
+ public CompactionMiddleware(
+ OpenAIClient client, String summarizerModel, long triggerPromptTokens) {
+ this.client = client;
+ this.summarizerModel = summarizerModel;
+ this.triggerPromptTokens = triggerPromptTokens;
+ }
+
+ @Override
+ public LlmCallAction beforeLlmCall(
+ AgentRunContext ctx, List<ChatCompletionMessageParam> messages) {
+ ConversationMemory mem = ctx.getMemory();
+ // 1) Real token count of the previous LLM call (prompt + completion, i.e. everything through
+ // the last assistant message, which is now part of history). 0 on the first call.
+ long lastTotal = mem.getLastTotalTokens();
+ // 2) Estimated tokens appended to the tail after the last assistant (tool_results, new user).
+ long newTailEstimate = estimateTailAfterLastAssistant(messages);
+
+ if (lastTotal + newTailEstimate < triggerPromptTokens) {
+ return null;
+ }
+
+ List<ChatCompletionMessageParam> history = mem.getHistory();
+
+ // 3) Split history into old (to summarize) and kept (recent tail), never orphaning a
+ // tool_result.
+ Split split = computeSplit(history, KEEP_RECENT_TURNS);
+ if (split.old.isEmpty()) {
+ return null;
+ }
+
+ String summary = summarize(mem.getSystemPrompt(), split.old);
+
+ // 4) Build the compacted history and persist into ConversationMemory.
+ List<ChatCompletionMessageParam> compacted = new ArrayList<>(1 + split.kept.size());
+ compacted.add(wrapSummaryAsUserMessage(summary));
+ compacted.addAll(split.kept);
+ mem.replaceHistory(compacted);
+
+ LOG.info(
+ "Compacted {} old msgs into 1 summary; kept {} tail msgs (lastTotal={}, newTail~={})",
+ split.old.size(),
+ split.kept.size(),
+ lastTotal,
+ newTailEstimate);
+
+ ctx.emit(
+ new Compaction(
+ split.old.size(), split.kept.size(), triggerPromptTokens, lastTotal + newTailEstimate));
+
+ return new LlmModifyMessages(mem.buildLlmMessages());
+ }
+
+ /** Call the LLM to produce a summary of {@code oldMessages}. Failures propagate. */
+ private String summarize(String agentSystemPrompt, List<ChatCompletionMessageParam> oldMessages) {
+ String systemPrompt = COMPACTION_SYSTEM_PROMPT;
+ if (agentSystemPrompt != null && !agentSystemPrompt.isEmpty()) {
+ systemPrompt =
+ systemPrompt
+ + "\n---\nFor context, the agent's own system prompt is:\n"
+ + agentSystemPrompt;
+ }
+
+ String rendered = renderAsText(oldMessages);
+
+ ChatCompletionCreateParams params =
+ ChatCompletionCreateParams.builder()
+ .model(summarizerModel)
+ .temperature(0.0)
+ .addMessage(
+ ChatCompletionMessageParam.ofSystem(
+ ChatCompletionSystemMessageParam.builder().content(systemPrompt).build()))
+ .addMessage(
+ ChatCompletionMessageParam.ofUser(
+ ChatCompletionUserMessageParam.builder().content(rendered).build()))
+ .build();
+
+ ChatCompletion response = client.chat().completions().create(params);
+ return response.choices().get(0).message().content().get();
+ }
+
+ // ----- helpers -----
+
+ /** Sum of content characters in messages after the last assistant, using ~4 chars per token. */
+ static long estimateTailAfterLastAssistant(List<ChatCompletionMessageParam> messages) {
+ int lastAssistantIdx = -1;
+ for (int i = messages.size() - 1; i >= 0; i--) {
+ if (messages.get(i).isAssistant()) {
+ lastAssistantIdx = i;
+ break;
+ }
+ }
+ long totalChars = 0;
+ for (int i = lastAssistantIdx + 1; i < messages.size(); i++) {
+ totalChars += contentCharCount(messages.get(i));
+ }
+ return totalChars / 4;
+ }
+
+ private static long contentCharCount(ChatCompletionMessageParam msg) {
+ if (msg.isUser()) {
+ return msg.asUser().content().text().map(String::length).orElse(0);
+ }
+ if (msg.isTool()) {
+ return msg.asTool().content().text().map(String::length).orElse(0);
+ }
+ if (msg.isAssistant()) {
+ return msg.asAssistant().content().flatMap(c -> c.text()).map(String::length).orElse(0);
+ }
+ if (msg.isSystem()) {
+ return msg.asSystem().content().text().map(String::length).orElse(0);
+ }
+ return 0;
+ }
+
+ /**
+ * Render a list of messages as plain text for the summarizer's user turn. Tool calls and tool
+ * results are rendered as tagged text so the summarizer LLM doesn't try to continue them as live
+ * agent state.
+ */
+ static String renderAsText(List<ChatCompletionMessageParam> messages) {
+ StringBuilder sb = new StringBuilder();
+ for (ChatCompletionMessageParam msg : messages) {
+ if (sb.length() > 0) sb.append("\n\n");
+ if (msg.isUser()) {
+ sb.append("USER: ").append(extractUserContent(msg));
+ } else if (msg.isAssistant()) {
+ sb.append("ASSISTANT: ").append(extractAssistantContent(msg));
+ msg.asAssistant()
+ .toolCalls()
+ .ifPresent(
+ calls -> {
+ for (ChatCompletionMessageToolCall tc : calls) {
+ if (tc.isFunction()) {
+ sb.append("\n[tool_call: ")
+ .append(tc.asFunction().function().name())
+ .append("(")
+ .append(tc.asFunction().function().arguments())
+ .append(") id=")
+ .append(tc.asFunction().id())
+ .append("]");
+ }
+ }
+ });
+ } else if (msg.isTool()) {
+ sb.append("[tool_result id=")
+ .append(msg.asTool().toolCallId())
+ .append("]: ")
+ .append(extractToolContent(msg));
+ } else if (msg.isSystem()) {
+ // system prompt should not appear in oldMessages, but render defensively
+ sb.append("SYSTEM: ").append(extractSystemContent(msg));
+ }
+ }
+ return sb.toString();
+ }
+
+ private static String extractUserContent(ChatCompletionMessageParam msg) {
+ return msg.asUser().content().text().orElse("[non-text content]");
+ }
+
+ private static String extractAssistantContent(ChatCompletionMessageParam msg) {
+ return msg.asAssistant().content().map(c -> c.text().orElse("[non-text content]")).orElse("");
+ }
+
+ private static String extractToolContent(ChatCompletionMessageParam msg) {
+ return msg.asTool().content().text().orElse("[non-text content]");
+ }
+
+ private static String extractSystemContent(ChatCompletionMessageParam msg) {
+ return msg.asSystem().content().text().orElse("[non-text content]");
+ }
+
+ /** Result of splitting the history into a summarizable prefix and a kept tail. */
+ static final class Split {
+
+ final List<ChatCompletionMessageParam> old;
+ final List<ChatCompletionMessageParam> kept;
+
+ Split(List<ChatCompletionMessageParam> old, List<ChatCompletionMessageParam> kept) {
+ this.old = old;
+ this.kept = kept;
+ }
+ }
+
+ /**
+ * Split the history at a boundary that preserves the last {@code keepRecentTurns} user messages,
+ * with adjustments so that no assistant-tool_use is separated from its tool_results.
+ */
+ static Split computeSplit(List<ChatCompletionMessageParam> history, int keepRecentTurns) {
+ if (history.size() <= 2) {
+ return new Split(new ArrayList<>(), new ArrayList<>(history));
+ }
+
+ // Walk from the tail, count user boundaries. If the history does not contain enough user
+ // messages to satisfy keepRecentTurns, keep everything (splitIdx = 0); the empty-old check
+ // in beforeLlmCall will then skip this turn gracefully.
+ int userBoundariesFound = 0;
+ int splitIdx = 0;
+ for (int i = history.size() - 1; i >= 0; i--) {
+ if (history.get(i).isUser()) {
+ userBoundariesFound++;
+ if (userBoundariesFound == keepRecentTurns) {
+ splitIdx = i;
+ break;
+ }
+ }
+ }
+
+ // Protect tool-call / tool-result pairing: never split between an assistant that issued
+ // tool_calls and the tool_results that satisfy them.
+ while (splitIdx > 0) {
+ ChatCompletionMessageParam prev = history.get(splitIdx - 1);
+ if (prev.isTool()) {
+ splitIdx--;
+ continue;
+ }
+ if (prev.isAssistant()) {
+ boolean hasToolCalls = prev.asAssistant().toolCalls().map(List::size).orElse(0) > 0;
+ if (hasToolCalls) {
+ splitIdx--;
+ continue;
+ }
+ }
+ break;
+ }
+
+ // Also guard against the edge case: if kept contains a tool_result whose tool_call id is
+ // defined only in old, pull that assistant (and its siblings) into kept too.
+ Set<String> keptCallIds = collectToolCallIds(history.subList(splitIdx, history.size()));
+ if (!keptCallIds.isEmpty()) {
+ while (splitIdx > 0) {
+ ChatCompletionMessageParam prev = history.get(splitIdx - 1);
+ if (!prev.isAssistant()) break;
+ List<ChatCompletionMessageToolCall> calls = prev.asAssistant().toolCalls().orElse(null);
+ if (calls == null || calls.isEmpty()) break;
+ boolean satisfiesKept = false;
+ for (ChatCompletionMessageToolCall tc : calls) {
+ if (tc.isFunction() && keptCallIds.contains(tc.asFunction().id())) {
+ satisfiesKept = true;
+ break;
+ }
+ }
+ if (!satisfiesKept) break;
+ splitIdx--;
+ }
+ }
+
+ List<ChatCompletionMessageParam> oldPart = new ArrayList<>(history.subList(0, splitIdx));
+ List<ChatCompletionMessageParam> keptPart =
+ new ArrayList<>(history.subList(splitIdx, history.size()));
+ return new Split(oldPart, keptPart);
+ }
+
+ private static Set<String> collectToolCallIds(List<ChatCompletionMessageParam> slice) {
+ Set<String> ids = new HashSet<>();
+ for (ChatCompletionMessageParam m : slice) {
+ if (m.isTool()) {
+ ids.add(m.asTool().toolCallId());
+ }
+ }
+ return ids;
+ }
+
+ private static ChatCompletionMessageParam wrapSummaryAsUserMessage(String summary) {
+ String body = "\n" + summary + "\n ";
+ return ChatCompletionMessageParam.ofUser(
+ ChatCompletionUserMessageParam.builder().content(body).build());
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java
new file mode 100644
index 00000000000..e0a5c2364eb
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import java.util.List;
+import java.util.Map;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.slf4j.MDC;
+
+/**
+ * Logging middleware that prints agent lifecycle events for debugging and observability.
+ *
+ * Picks up {@code operationId} and {@code sessionId} from SLF4J MDC (set by ExecuteStatement) to
+ * tag every log line with the Kyuubi operation/session context.
+ *
+ *
+ * <p>Log structure:
+ *
+ *
+ * [op:abcd1234] START user_input="..."
+ * [op:abcd1234] Step 1
+ * [op:abcd1234] LLM call: step=1, messages=3
+ * [op:abcd1234] LLM response: step=1, content="...(truncated)", tool_calls=1
+ * [op:abcd1234] Tool call: sql_query {sql=SELECT ...}
+ * [op:abcd1234] Tool result: sql_query -> "| col1 | col2 |...(truncated)"
+ * [op:abcd1234] FINISH steps=2, tokens=1234
+ *
+ */
+public class LoggingMiddleware implements AgentMiddleware {
+
+ private static final Logger LOG = LoggerFactory.getLogger("DataAgent");
+
+ private static final int MAX_PREVIEW_LENGTH = 500;
+
+ private static String prefix() {
+ String sessionId = MDC.get("sessionId");
+ String opId = MDC.get("operationId");
+ StringBuilder sb = new StringBuilder();
+ if (sessionId != null) {
+ sb.append("[s:").append(shortId(sessionId)).append("]");
+ }
+ if (opId != null) {
+ sb.append("[op:").append(shortId(opId)).append("]");
+ }
+ if (sb.length() > 0) {
+ sb.append(" ");
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Take the first segment of a UUID (before the first dash). e.g. "327d8c5b-91ef-..." → "327d8c5b"
+ */
+ private static String shortId(String id) {
+ int dash = id.indexOf('-');
+ return dash > 0 ? id.substring(0, dash) : id;
+ }
+
+ @Override
+ public void onAgentStart(AgentRunContext ctx) {
+ LOG.debug("{}START user_input=\"{}\"", prefix(), truncate(ctx.getMemory().getLastUserInput()));
+ }
+
+ @Override
+ public void onAgentFinish(AgentRunContext ctx) {
+ LOG.info(
+ "{}FINISH steps={}, prompt_tokens={}, completion_tokens={}, total_tokens={}",
+ prefix(),
+ ctx.getIteration(),
+ ctx.getPromptTokens(),
+ ctx.getCompletionTokens(),
+ ctx.getTotalTokens());
+ }
+
+ @Override
+ public LlmCallAction beforeLlmCall(
+ AgentRunContext ctx, List<ChatCompletionMessageParam> messages) {
+ LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size());
+ return null;
+ }
+
+ @Override
+ public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {
+ String content = response.content().map(Object::toString).orElse("");
+ int toolCallCount = response.toolCalls().map(List::size).orElse(0);
+ LOG.info(
+ "{}LLM response: step={}, content=\"{}\", tool_calls={}, "
+ + "usage(cumulative): prompt={}, completion={}, total={}",
+ prefix(),
+ ctx.getIteration(),
+ truncate(content),
+ toolCallCount,
+ ctx.getPromptTokens(),
+ ctx.getCompletionTokens(),
+ ctx.getTotalTokens());
+ }
+
+ @Override
+ public ToolCallDenial beforeToolCall(
+ AgentRunContext ctx, String toolCallId, String toolName, Map<String, Object> toolArgs) {
+ LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName);
+ LOG.debug("{}Tool args: {}", prefix(), toolArgs);
+ return null;
+ }
+
+ @Override
+ public String afterToolCall(
+ AgentRunContext ctx, String toolName, Map<String, Object> toolArgs, String result) {
+ LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result));
+ return null;
+ }
+
+ @Override
+ public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) {
+ switch (event.eventType()) {
+ case STEP_START:
+ LOG.info("{}Step {}", prefix(), ((StepStart) event).stepNumber());
+ break;
+ case ERROR:
+ LOG.error("{}ERROR: {}", prefix(), ((AgentError) event).message());
+ break;
+ case TOOL_RESULT:
+ ToolResult tr = (ToolResult) event;
+ if (tr.isError()) {
+ LOG.warn("{}Tool error: {} -> \"{}\"", prefix(), tr.toolName(), truncate(tr.output()));
+ }
+ break;
+ default:
+ break;
+ }
+ return event;
+ }
+
+ private static String truncate(String s) {
+ if (s == null) return "";
+ return s.length() <= MAX_PREVIEW_LENGTH ? s : s.substring(0, MAX_PREVIEW_LENGTH) + "...";
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java
new file mode 100644
index 00000000000..87aad9f3255
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool;
+import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Input-gate middleware that offloads oversized tool outputs to per-session temp files and replaces
+ * the in-memory tool result with a small head + tail preview plus a retrieval hint. The cheapest
+ * and highest-ROI defense against context rot.
+ *
+ * Trigger: {@code result.lines > MAX_LINES} OR {@code result.bytes > MAX_BYTES}, first to
+ * trip wins. Thresholds are hardcoded — the ReAct loop can compensate for suboptimal defaults by
+ * calling the retrieval tools more aggressively.
+ *
+ *
+ * <p>Exempt tools: {@link ReadToolOutputTool} and {@link GrepToolOutputTool} never go
+ * through the gate — the agent would otherwise recursively re-offload its own retrieval output.
+ *
+ *
+ * <p>Session lifecycle: a monotonic counter per session is used to name temp files. {@link
+ * #onSessionClose(String)} wipes the counter and the per-session temp dir; call it from the
+ * provider's {@code close(sessionId)} hook (not from {@link #onAgentFinish}, which fires on every
+ * turn and would invalidate paths the LLM still needs to reference).
+ */
+public class ToolResultOffloadMiddleware implements AgentMiddleware {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ToolResultOffloadMiddleware.class);
+
+ static final int MAX_LINES = 500;
+ static final int MAX_BYTES = 50 * 1024;
+ static final int PREVIEW_HEAD_LINES = 20;
+ static final int PREVIEW_TAIL_LINES = 20;
+
+ private static final Set<String> EXEMPT_TOOLS =
+ new HashSet<>(Arrays.asList(ReadToolOutputTool.NAME, GrepToolOutputTool.NAME));
+
+ private final ToolOutputStore store = ToolOutputStore.create();
+ private final ConcurrentHashMap<String, AtomicLong> counters = new ConcurrentHashMap<>();
+
+ /**
+ * Register the companion retrieval tools so the LLM can reach back into offloaded files. Paired
+ * with the preview hint emitted from {@link #afterToolCall}; skipping this registration would
+ * leave the LLM dangling on file paths it can never read.
+ */
+ @Override
+ public void onRegister(ToolRegistry registry) {
+ registry.register(new ReadToolOutputTool(store));
+ registry.register(new GrepToolOutputTool(store));
+ }
+
+ @Override
+ public String afterToolCall(
+ AgentRunContext ctx, String toolName, Map<String, Object> toolArgs, String result) {
+ if (result.isEmpty()) return null;
+ if (EXEMPT_TOOLS.contains(toolName)) return null;
+
+ int bytes = result.getBytes(StandardCharsets.UTF_8).length;
+ int lines = countLines(result);
+ if (lines <= MAX_LINES && bytes <= MAX_BYTES) {
+ return null;
+ }
+
+ // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload.
+ // In production the provider always threads it through, so treat null as "skip offload".
+ String sessionId = ctx.getSessionId();
+ if (sessionId == null) return null;
+
+ long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet();
+ String toolCallId = toolName + "_" + n;
+
+ Path file;
+ try {
+ file = store.write(sessionId, toolCallId, result);
+ } catch (IOException e) {
+ LOG.warn(
+ "Tool output offload failed for tool={} session={}; passing through full output",
+ toolName,
+ sessionId,
+ e);
+ return null;
+ }
+
+ LOG.info(
+ "Offloaded tool={} session={} ({} lines / {} bytes) -> {}",
+ toolName,
+ sessionId,
+ lines,
+ bytes,
+ file.getFileName());
+ return buildPreview(result, lines, bytes, file);
+ }
+
+ /** Clean up counter and temp dir for a closed session. Idempotent. */
+ @Override
+ public void onSessionClose(String sessionId) {
+ if (sessionId == null) return;
+ counters.remove(sessionId);
+ store.cleanupSession(sessionId);
+ }
+
+ /** Engine-wide shutdown: drop the temp root. */
+ @Override
+ public void onStop() {
+ store.close();
+ }
+
+ static int countLines(String s) {
+ if (s.isEmpty()) return 0;
+ int count = 1;
+ for (int i = 0; i < s.length(); i++) {
+ if (s.charAt(i) == '\n') count++;
+ }
+ // Trailing newline means the last "line" is empty, but still counted. Good enough for
+ // gating decisions — we're not trying to match `wc -l` exactly.
+ return count;
+ }
+
+ static String buildPreview(String full, int lines, int bytes, Path file) {
+ String[] split = full.split("\n", -1);
+ int headEnd = Math.min(PREVIEW_HEAD_LINES, split.length);
+ int tailStart = Math.max(headEnd, split.length - PREVIEW_TAIL_LINES);
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("[Tool output truncated: ")
+ .append(lines)
+ .append(" lines, ")
+ .append(humanBytes(bytes))
+ .append("]\n")
+ .append("Saved to: ")
+ .append(file.toString())
+ .append("\n\n--- First ")
+ .append(headEnd)
+ .append(" lines ---\n");
+ for (int i = 0; i < headEnd; i++) {
+ sb.append(split[i]).append('\n');
+ }
+ if (tailStart < split.length) {
+ int tailCount = split.length - tailStart;
+ sb.append("--- Last ").append(tailCount).append(" lines ---\n");
+ for (int i = tailStart; i < split.length; i++) {
+ sb.append(split[i]).append('\n');
+ }
+ }
+ sb.append("\nUse ")
+ .append(ReadToolOutputTool.NAME)
+ .append("(path, offset, limit) to read windows, or ")
+ .append(GrepToolOutputTool.NAME)
+ .append("(path, pattern, max_matches) to search.");
+ return sb.toString();
+ }
+
+ private static String humanBytes(long bytes) {
+ if (bytes < 1024) return bytes + " B";
+ if (bytes < 1024 * 1024) return String.format("%.1f KB", bytes / 1024.0);
+ return String.format("%.1f MB", bytes / (1024.0 * 1024.0));
+ }
+
+ /** Visible for testing. */
+ int trackedSessions() {
+ return counters.size();
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java
index 297a7c8d74a..b19f6787c87 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java
@@ -50,7 +50,10 @@ default ToolRiskLevel riskLevel() {
* Execute the tool with the given deserialized arguments.
*
* @param args the deserialized arguments from the LLM's tool call
+ * @param ctx per-invocation context (session id, etc.); never null — use {@link
+ * ToolContext#EMPTY} for calls without a session. Tools that are session-agnostic may ignore
+ * it.
* @return the result string to feed back to the LLM
*/
- String execute(T args);
+ String execute(T args, ToolContext ctx);
}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java
new file mode 100644
index 00000000000..2625f3bab59
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.tool;
+
+/**
+ * Per-invocation context handed to {@link AgentTool#execute(Object, ToolContext)}. Today it carries
+ * just the session id so session-scoped tools (e.g. the offloaded tool-output retrievers) can
+ * restrict their filesystem view; extend here when a tool needs user/approval/etc.
+ */
+public final class ToolContext {
+
+ /** Sentinel for call sites that have no session to attribute — tests, direct CLI use. */
+ public static final ToolContext EMPTY = new ToolContext(null);
+
+ private final String sessionId;
+
+ public ToolContext(String sessionId) {
+ this.sessionId = sessionId;
+ }
+
+ /** Upstream session id, or {@code null} when invoked outside a session. */
+ public String sessionId() {
+ return sessionId;
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java
index a403c66b58d..3a11bab567c 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java
@@ -26,15 +26,16 @@
import com.openai.models.chat.completions.ChatCompletionTool;
import java.util.LinkedHashMap;
import java.util.Map;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,15 +69,20 @@ public class ToolRegistry implements AutoCloseable {
*/
private static final int MAX_POOL_SIZE = 8;
+ /** Default wall-clock cap for a single tool call, used when no explicit value is configured. */
+ public static final long DEFAULT_TIMEOUT_SECONDS = 300;
+
private final Map<String, AgentTool<?>> tools = new LinkedHashMap<>();
private volatile Map<String, ChatCompletionTool> cachedSpecs;
private final long toolCallTimeoutSeconds;
private final ExecutorService executor;
+ private final ScheduledExecutorService timeoutScheduler;
/**
- * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} call, sourced from
- * {@code kyuubi.engine.data.agent.tool.call.timeout}. When the timeout fires, the thread is
- * interrupted and a descriptive error is returned to the LLM.
+ * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} / {@link
+ * #submitTool} call, sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. When
+ * the timeout fires, the worker thread is interrupted and a descriptive error is returned to
+ * the LLM.
*/
public ToolRegistry(long toolCallTimeoutSeconds) {
this.toolCallTimeoutSeconds = toolCallTimeoutSeconds;
@@ -93,12 +99,20 @@ public ToolRegistry(long toolCallTimeoutSeconds) {
t.setDaemon(true);
return t;
});
+ this.timeoutScheduler =
+ Executors.newSingleThreadScheduledExecutor(
+ r -> {
+ Thread t = new Thread(r, "tool-call-timeout");
+ t.setDaemon(true);
+ return t;
+ });
}
- /** Shut down the worker pool. Idempotent. */
+ /** Shut down the worker pool and the timeout scheduler. Idempotent. */
@Override
public void close() {
executor.shutdownNow();
+ timeoutScheduler.shutdownNow();
}
/** Register a tool. Keyed by {@link AgentTool#name()}. */
@@ -138,61 +152,100 @@ private synchronized Map ensureSpecs() {
}
/**
- * Execute a tool call: deserialize the JSON args, then delegate to the tool, with a wall-clock
- * timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. If the tool does not
- * finish within the timeout, the worker thread is interrupted and a descriptive error is returned
- * to the LLM so it can react (e.g. simplify the query, retry with LIMIT).
+ * Synchronous entry point. Blocks until the tool finishes, times out, or the registry rejects the
+ * submission. Errors are surfaced as strings (never as exceptions) so the LLM can observe and
+ * react to them.
*
* @param toolName the function name from the LLM response
* @param argsJson the raw JSON arguments string
* @return the result string, or an error message
*/
- @SuppressWarnings("unchecked")
public String executeTool(String toolName, String argsJson) {
- AgentTool> tool;
+ return submitTool(toolName, argsJson, ToolContext.EMPTY).join();
+ }
+
+ /** Synchronous entry point with an explicit {@link ToolContext}. */
+ public String executeTool(String toolName, String argsJson, ToolContext ctx) {
+ return submitTool(toolName, argsJson, ctx).join();
+ }
+
+ public CompletableFuture<String> submitTool(String toolName, String argsJson) {
+ return submitTool(toolName, argsJson, ToolContext.EMPTY);
+ }
+
+ /**
+ * Asynchronous entry point. Deserialize args, run the tool on the worker pool, and apply a
+ * wall-clock timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. The
+ * returned future is guaranteed to complete normally — timeouts, pool saturation, unknown tool,
+ * and execution failures are all translated into error strings. Callers can therefore use {@code
+ * .join()} / {@code .get()} without handling {@link java.util.concurrent.TimeoutException} or
+ * {@link java.util.concurrent.ExecutionException}.
+ *
+ * @param toolName the function name from the LLM response
+ * @param argsJson the raw JSON arguments string
+ */
+ @SuppressWarnings("unchecked")
+ public CompletableFuture<String> submitTool(String toolName, String argsJson, ToolContext ctx) {
+ AgentTool<Object> tool;
synchronized (this) {
- tool = tools.get(toolName);
+ tool = (AgentTool<Object>) tools.get(toolName);
}
if (tool == null) {
- return "Error: unknown tool '" + toolName + "'";
+ return CompletableFuture.completedFuture("Error: unknown tool '" + toolName + "'");
}
- return executeWithTimeout((AgentTool) tool, argsJson);
- }
+ ToolContext toolCtx = ctx != null ? ctx : ToolContext.EMPTY;
- private String executeWithTimeout(AgentTool tool, String argsJson) {
- Callable task =
- () -> {
- T args = JSON.readValue(argsJson, tool.argsType());
- return tool.execute(args);
- };
- Future future;
+ CompletableFuture<String> result = new CompletableFuture<>();
+ Future<?> submitted;
try {
- future = executor.submit(task);
+ submitted =
+ executor.submit(
+ () -> {
+ try {
+ Object args = JSON.readValue(argsJson, tool.argsType());
+ String out = tool.execute(args, toolCtx);
+ // When the timeout handler interrupts us, the tool may still unwind cleanly and
+ // produce a stale return value — don't race the scheduler's timeout message with
+ // it. Let the timeout path be the single authority for the final result.
+ if (!Thread.currentThread().isInterrupted()) {
+ result.complete(out);
+ }
+ } catch (Exception e) {
+ result.complete("Error executing " + toolName + ": " + e.getMessage());
+ }
+ });
} catch (RejectedExecutionException e) {
- LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", tool.name(), MAX_POOL_SIZE);
- return "Error: tool call '"
- + tool.name()
- + "' rejected — server is handling too many concurrent tool calls. "
- + "Retry in a moment.";
- }
- try {
- return future.get(toolCallTimeoutSeconds, TimeUnit.SECONDS);
- } catch (TimeoutException e) {
- future.cancel(true);
- LOG.warn("Tool call '{}' timed out after {} seconds", tool.name(), toolCallTimeoutSeconds);
- return "Error: tool call '"
- + tool.name()
- + "' timed out after "
- + toolCallTimeoutSeconds
- + " seconds. "
- + "Try simplifying the query or adding filters to reduce execution time.";
- } catch (ExecutionException e) {
- Throwable cause = e.getCause() != null ? e.getCause() : e;
- return "Error executing " + tool.name() + ": " + cause.getMessage();
- } catch (InterruptedException e) {
- Thread.currentThread().interrupt();
- return "Error: tool call '" + tool.name() + "' was interrupted.";
+ LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", toolName, MAX_POOL_SIZE);
+ return CompletableFuture.completedFuture(
+ "Error: tool call '"
+ + toolName
+ + "' rejected — server is handling too many concurrent tool calls. "
+ + "Retry in a moment.");
}
+
+ ScheduledFuture<?> timer =
+ timeoutScheduler.schedule(
+ () -> {
+ if (!result.isDone()) {
+ // cancel(true) interrupts the worker thread directly — the inner task's
+ // catch-all will see the interrupt and call result.complete(...), but the
+ // timeout message below wins because complete() is idempotent on first-winner.
+ submitted.cancel(true);
+ LOG.warn(
+ "Tool call '{}' timed out after {} seconds", toolName, toolCallTimeoutSeconds);
+ result.complete(
+ "Error: tool call '"
+ + toolName
+ + "' timed out after "
+ + toolCallTimeoutSeconds
+ + " seconds. "
+ + "Try simplifying the query or adding filters to reduce execution time.");
+ }
+ },
+ toolCallTimeoutSeconds,
+ TimeUnit.SECONDS);
+ result.whenComplete((r, e) -> timer.cancel(false));
+ return result;
}
private static ChatCompletionTool buildChatCompletionTool(AgentTool> tool) {
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java
new file mode 100644
index 00000000000..b05a199090a
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.tool.output;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+
+/** Args for {@link GrepToolOutputTool}. */
+public class GrepToolOutputArgs {
+
+ @JsonProperty(required = true)
+ @JsonPropertyDescription(
+ "Absolute path to the offloaded tool-output file, as reported by the truncation notice.")
+ public String path;
+
+ @JsonProperty(required = true)
+ @JsonPropertyDescription(
+ "Java regex pattern to search for. Matches are returned as '<line_number>:<line>'.")
+ public String pattern;
+
+ @JsonProperty("max_matches")
+ @JsonPropertyDescription("Maximum number of matching lines to return. Defaults to 50.")
+ public Integer maxMatches;
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java
new file mode 100644
index 00000000000..f94fb21b315
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.tool.output;
+
+import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore;
+import org.apache.kyuubi.engine.dataagent.tool.AgentTool;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
+
+/**
+ * Regex-search a previously offloaded tool-output file. Companion to {@code
+ * ToolResultOffloadMiddleware}.
+ */
+public class GrepToolOutputTool implements AgentTool<GrepToolOutputArgs> {
+
+ public static final String NAME = "grep_tool_output";
+ private static final int DEFAULT_MAX_MATCHES = 50;
+
+ private final ToolOutputStore store;
+
+ public GrepToolOutputTool(ToolOutputStore store) {
+ this.store = store;
+ }
+
+ @Override
+ public String name() {
+ return NAME;
+ }
+
+ @Override
+ public String description() {
+ return "Regex-search a previously offloaded tool-output file "
+ + "(the path is supplied in the truncation notice of a prior tool result). "
+ + "Cheaper than read_tool_output when you know what you're looking for. "
+ + "Returns matching lines as '<line_number>:<line>'.";
+ }
+
+ @Override
+ public Class<GrepToolOutputArgs> argsType() {
+ return GrepToolOutputArgs.class;
+ }
+
+ @Override
+ public String execute(GrepToolOutputArgs args, ToolContext ctx) {
+ if (args == null || args.path == null || args.path.isEmpty()) {
+ return "Error: 'path' parameter is required.";
+ }
+ if (ctx == null || ctx.sessionId() == null) {
+ return "Error: grep_tool_output requires a session context.";
+ }
+ int max = args.maxMatches != null ? args.maxMatches : DEFAULT_MAX_MATCHES;
+ return store.grep(ctx.sessionId(), args.path, args.pattern, max);
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java
new file mode 100644
index 00000000000..458fbfa4f66
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.tool.output;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+
+/** Args for {@link ReadToolOutputTool}. */
+public class ReadToolOutputArgs {
+
+ @JsonProperty(required = true)
+ @JsonPropertyDescription(
+ "Absolute path to the offloaded tool-output file, as reported by the truncation notice.")
+ public String path;
+
+ @JsonPropertyDescription("0-based line offset into the file. Defaults to 0.")
+ public Integer offset;
+
+ @JsonPropertyDescription(
+ "Number of lines to return starting at 'offset'. Defaults to 200; capped at 500.")
+ public Integer limit;
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java
new file mode 100644
index 00000000000..93e837998f3
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.tool.output;
+
+import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore;
+import org.apache.kyuubi.engine.dataagent.tool.AgentTool;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
+
+/**
+ * Read a line window from a previously offloaded tool-output file. Companion to {@code
+ * ToolResultOffloadMiddleware}.
+ */
+public class ReadToolOutputTool implements AgentTool<ReadToolOutputArgs> {
+
+ public static final String NAME = "read_tool_output";
+ private static final int DEFAULT_LIMIT = 200;
+ private static final int MAX_LIMIT = 500;
+
+ private final ToolOutputStore store;
+
+ public ReadToolOutputTool(ToolOutputStore store) {
+ this.store = store;
+ }
+
+ @Override
+ public String name() {
+ return NAME;
+ }
+
+ @Override
+ public String description() {
+ return "Read a line window from a previously offloaded tool-output file "
+ + "(the path is supplied in the truncation notice of a prior tool result). "
+ + "Returns '[lines X-Y of Z total]' header followed by the requested window. "
+ + "Use when a prior tool's output was truncated and you need to inspect more of it.";
+ }
+
+ @Override
+ public Class<ReadToolOutputArgs> argsType() {
+ return ReadToolOutputArgs.class;
+ }
+
+ @Override
+ public String execute(ReadToolOutputArgs args, ToolContext ctx) {
+ if (args == null || args.path == null || args.path.isEmpty()) {
+ return "Error: 'path' parameter is required.";
+ }
+ if (ctx == null || ctx.sessionId() == null) {
+ return "Error: read_tool_output requires a session context.";
+ }
+ int offset = args.offset != null ? args.offset : 0;
+ int limit = args.limit != null ? args.limit : DEFAULT_LIMIT;
+ if (limit > MAX_LIMIT) limit = MAX_LIMIT;
+ return store.read(ctx.sessionId(), args.path, offset, limit);
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java
index 06b12f2be72..88838ca477a 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java
@@ -19,6 +19,7 @@
import javax.sql.DataSource;
import org.apache.kyuubi.engine.dataagent.tool.AgentTool;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel;
/**
@@ -69,7 +70,7 @@ public Class argsType() {
}
@Override
- public String execute(SqlQueryArgs args) {
+ public String execute(SqlQueryArgs args, ToolContext ctx) {
return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds);
}
}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java
index 0c57cc049ed..9136b0aa903 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java
@@ -19,6 +19,7 @@
import javax.sql.DataSource;
import org.apache.kyuubi.engine.dataagent.tool.AgentTool;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel;
/**
@@ -69,7 +70,7 @@ public Class argsType() {
}
@Override
- public String execute(SqlQueryArgs args) {
+ public String execute(SqlQueryArgs args, ToolContext ctx) {
String sql = args.sql;
if (sql == null || sql.trim().isEmpty()) {
return "Error: 'sql' parameter is required.";
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java
index d2cac9ae518..52b83182add 100644
--- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java
@@ -53,6 +53,7 @@ final class SqlReadOnlyChecker {
* <li>{@code USE} — switch session catalog/database; not data-mutating
* <li>{@code LIST} — Spark {@code LIST FILE} / {@code LIST JAR} inspection
* <li>{@code HELP} — some engines expose interactive help
+ * <li>{@code PRAGMA} — SQLite schema/metadata inspection
* </ul>
*/
private static final Set READ_ONLY_KEYWORDS =
@@ -70,7 +71,8 @@ final class SqlReadOnlyChecker {
"EXPLAIN",
"USE",
"LIST",
- "HELP")));
+ "HELP",
+ "PRAGMA")));
private SqlReadOnlyChecker() {}
diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java
new file mode 100644
index 00000000000..f4366094670
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.util;
+
+import org.apache.kyuubi.config.ConfigEntry;
+import org.apache.kyuubi.config.KyuubiConf;
+import org.apache.kyuubi.config.OptionalConfigEntry;
+
+/** Small helpers for reading typed values out of {@link KyuubiConf}. */
+public final class ConfUtils {
+
+ private ConfUtils() {}
+
+ /** Return the string value, or throw if the entry is not set. */
+ public static String requireString(KyuubiConf conf, OptionalConfigEntry<String> key) {
+ scala.Option<String> opt = conf.get(key);
+ if (opt.isEmpty()) {
+ throw new IllegalArgumentException(key.key() + " is required");
+ }
+ return opt.get();
+ }
+
+ /** Return the string value, or {@code null} if the entry is not set. */
+ public static String optionalString(KyuubiConf conf, OptionalConfigEntry<String> key) {
+ scala.Option<String> opt = conf.get(key);
+ return opt.isDefined() ? opt.get() : null;
+ }
+
+ /** Return the value for a raw key, or {@code null} if not set. */
+ public static String optionalString(KyuubiConf conf, String key) {
+ scala.Option<String> opt = conf.getOption(key);
+ return opt.isDefined() ? opt.get() : null;
+ }
+
+ public static int intConf(KyuubiConf conf, ConfigEntry<?> key) {
+ return ((Number) conf.get(key)).intValue();
+ }
+
+ public static long longConf(KyuubiConf conf, ConfigEntry<?> key) {
+ return ((Number) conf.get(key)).longValue();
+ }
+
+ /** Read a millisecond-valued entry and return it as whole seconds. */
+ public static long millisAsSeconds(KyuubiConf conf, ConfigEntry<?> key) {
+ return longConf(conf, key) / 1000L;
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala
index 04d0defa2dc..3d902677a0a 100644
--- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala
@@ -25,7 +25,7 @@ import org.slf4j.MDC
import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.dataagent.provider.{DataAgentProvider, ProviderRunRequest}
-import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult}
+import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, Compaction, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult}
import org.apache.kyuubi.operation.OperationState
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
@@ -117,7 +117,7 @@ class ExecuteStatement(
n.put("type", sseType)
n.put("id", toolCall.toolCallId())
n.put("name", toolCall.toolName())
- n.set("args", JSON.valueToTree(toolCall.toolArgs()))
+ n.set[ObjectNode]("args", JSON.valueToTree(toolCall.toolArgs()))
}))
case EventType.TOOL_RESULT =>
val toolResult = event.asInstanceOf[ToolResult]
@@ -148,6 +148,15 @@ class ExecuteStatement(
n.set("args", JSON.valueToTree(req.toolArgs()))
n.put("riskLevel", req.riskLevel().name())
}))
+ case EventType.COMPACTION =>
+ val c = event.asInstanceOf[Compaction]
+ incrementalIter.append(Array(toJson { n =>
+ n.put("type", sseType)
+ n.put("summarized", c.summarizedCount())
+ n.put("kept", c.keptCount())
+ n.put("triggerTokens", c.triggerTokens())
+ n.put("observedTokens", c.observedTokens())
+ }))
case EventType.AGENT_FINISH =>
val finish = event.asInstanceOf[AgentFinish]
incrementalIter.append(Array(toJson { n =>
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java
index e728ff871e7..c43942a8f35 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java
@@ -19,6 +19,7 @@
import static org.junit.Assert.*;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect;
import org.junit.Test;
public class JdbcDialectTest {
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java
index 4e713b9cca3..cc45ebdc7e7 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java
@@ -20,8 +20,9 @@
import static org.junit.Assert.*;
import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect;
-import org.apache.kyuubi.engine.dataagent.datasource.MysqlDialect;
+import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect;
import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
import org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs;
import org.junit.BeforeClass;
@@ -66,7 +67,7 @@ public void testBacktickQuotingWithReservedWord() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1";
- String result = selectTool.execute(args);
+ String result = selectTool.execute(args, ToolContext.EMPTY);
assertFalse(result.startsWith("Error:"));
assertTrue(result.contains("value1"));
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java
new file mode 100644
index 00000000000..2756b3e2087
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.provider.mock;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import javax.sql.DataSource;
+import org.apache.kyuubi.config.KyuubiConf;
+import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory;
+import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider;
+import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta;
+import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd;
+import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
+import org.apache.kyuubi.engine.dataagent.util.ConfUtils;
+
+/**
+ * A mock LLM provider for testing the full tool-call pipeline without a real LLM. Simulates the
+ * ReAct loop: extracts SQL from the user question, executes it via SqlQueryTool, and returns the
+ * result as a formatted answer.
+ *
+ * Recognizes two patterns:
+ *
+ * <ul>
+ *   <li>Questions containing SQL keywords (SELECT, SHOW, DESCRIBE) — extracts and executes the SQL</li>
+ *   <li>All other questions — returns a canned response without tool calls</li>
+ * </ul>
+ */
+public class MockLlmProvider implements DataAgentProvider {
+
+ private static final Pattern SQL_PATTERN =
+ Pattern.compile(
+ "(SELECT\\b.+|SHOW\\b.+|DESCRIBE\\b.+)", Pattern.CASE_INSENSITIVE | Pattern.DOTALL);
+
+ /**
+ * Simple natural-language-to-SQL mappings so tests can use human-readable questions instead of
+ * raw SQL. Checked before the regex pattern — if a question matches a key (case-insensitive
+ * prefix), the mapped SQL is executed.
+ */
+ private static final Map<String, String> NL_TO_SQL = new java.util.LinkedHashMap<>();
+
+ static {
+ NL_TO_SQL.put(
+ "list all employee names and departments",
+ "SELECT name, department FROM employees ORDER BY id");
+ NL_TO_SQL.put(
+ "how many employees in each department",
+ "SELECT department, COUNT(*) as cnt FROM employees GROUP BY department");
+ NL_TO_SQL.put("count the total number of employees", "SELECT COUNT(*) FROM employees");
+ }
+
+ private final ConcurrentHashMap<String, Object> sessions = new ConcurrentHashMap<>();
+ private final ToolRegistry toolRegistry;
+ private final DataSource dataSource;
+
+ public MockLlmProvider(KyuubiConf conf) {
+ String jdbcUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL());
+ this.dataSource = DataSourceFactory.create(jdbcUrl);
+ this.toolRegistry = new ToolRegistry(30);
+ this.toolRegistry.register(new RunSelectQueryTool(dataSource, 0));
+ }
+
+ @Override
+ public void open(String sessionId, String user) {
+ sessions.put(sessionId, new Object());
+ }
+
+ @Override
+ public void run(String sessionId, ProviderRunRequest request, Consumer<AgentEvent> onEvent) {
+ String question = request.getQuestion();
+ onEvent.accept(new AgentStart());
+
+ // Trigger an error for testing the error path in ExecuteStatement
+ if (question.trim().equalsIgnoreCase("__error__")) {
+ throw new RuntimeException("MockLlmProvider simulated failure");
+ }
+
+ // First check natural-language mappings, then fall back to SQL pattern extraction
+ String sql = resolveToSql(question);
+ if (sql != null) {
+ runWithToolCall(sql, onEvent);
+ } else {
+ runWithoutToolCall(question, onEvent);
+ }
+ }
+
+ private void runWithToolCall(String sql, Consumer<AgentEvent> onEvent) {
+ // Step 1: LLM "decides" to call sql_query tool
+ onEvent.accept(new StepStart(1));
+ String toolCallId = "mock_call_" + System.nanoTime();
+ Map<String, Object> toolArgs = new HashMap<>();
+ toolArgs.put("sql", sql);
+ onEvent.accept(new ToolCall(toolCallId, "run_select_query", toolArgs));
+
+ // Execute the tool
+ String toolOutput =
+ toolRegistry.executeTool("run_select_query", "{\"sql\":\"" + escapeJson(sql) + "\"}");
+ onEvent.accept(new ToolResult(toolCallId, "run_select_query", toolOutput, false));
+ onEvent.accept(new StepEnd(1));
+
+ // Step 2: LLM "summarizes" the result
+ onEvent.accept(new StepStart(2));
+ String answer = "Based on the query result:\n\n" + toolOutput;
+ for (String token : answer.split("(?<=\\n)")) {
+ onEvent.accept(new ContentDelta(token));
+ }
+ onEvent.accept(new ContentComplete(answer));
+ onEvent.accept(new StepEnd(2));
+ onEvent.accept(new AgentFinish(2, 100, 50, 150));
+ }
+
+ private void runWithoutToolCall(String question, Consumer<AgentEvent> onEvent) {
+ onEvent.accept(new StepStart(1));
+ String answer = "[MockLLM] No SQL detected in: " + question;
+ onEvent.accept(new ContentDelta(answer));
+ onEvent.accept(new ContentComplete(answer));
+ onEvent.accept(new StepEnd(1));
+ onEvent.accept(new AgentFinish(1, 50, 20, 70));
+ }
+
+ @Override
+ public void close(String sessionId) {
+ sessions.remove(sessionId);
+ }
+
+ @Override
+ public void stop() {
+ if (dataSource instanceof com.zaxxer.hikari.HikariDataSource) {
+ ((com.zaxxer.hikari.HikariDataSource) dataSource).close();
+ }
+ }
+
+ /**
+ * Resolve a user question to SQL. Checks NL_TO_SQL mappings first, then falls back to regex
+ * extraction of raw SQL from the input.
+ */
+ private static String resolveToSql(String question) {
+ String lower = question.toLowerCase().trim();
+ for (Map.Entry<String, String> entry : NL_TO_SQL.entrySet()) {
+ if (lower.startsWith(entry.getKey())) {
+ return entry.getValue();
+ }
+ }
+ Matcher matcher = SQL_PATTERN.matcher(question);
+ if (matcher.find()) {
+ return matcher.group(1).trim();
+ }
+ return null;
+ }
+
+ private static String escapeJson(String s) {
+ return s.replace("\\", "\\\\")
+ .replace("\"", "\\\"")
+ .replace("\n", "\\n")
+ .replace("\r", "\\r")
+ .replace("\t", "\\t");
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java
new file mode 100644
index 00000000000..d8eb96cfbc1
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import static org.junit.Assert.*;
+
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import java.util.Collections;
+import org.junit.Test;
+
+public class ConversationMemoryTest {
+
+ @Test
+ public void testReplaceHistoryClearsLastTotalTokensButKeepsCumulative() {
+ ConversationMemory memory = new ConversationMemory();
+ memory.addCumulativeTokens(100, 50, 150);
+ memory.addCumulativeTokens(200, 80, 280);
+ assertEquals(280L, memory.getLastTotalTokens());
+ assertEquals(300L, memory.getCumulativePromptTokens());
+ assertEquals(430L, memory.getCumulativeTotalTokens());
+
+ ChatCompletionMessageParam summary =
+ ChatCompletionMessageParam.ofUser(
+ ChatCompletionUserMessageParam.builder().content("summary").build());
+ memory.replaceHistory(Collections.singletonList(summary));
+
+ assertEquals("lastTotalTokens reset after compaction", 0L, memory.getLastTotalTokens());
+ assertEquals("cumulative totals preserved", 300L, memory.getCumulativePromptTokens());
+ assertEquals("cumulative totals preserved", 430L, memory.getCumulativeTotalTokens());
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java
new file mode 100644
index 00000000000..75553f46998
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgentLiveTest.java
@@ -0,0 +1,568 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import static org.junit.Assert.*;
+import static org.junit.Assume.assumeTrue;
+
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import java.io.File;
+import java.sql.Connection;
+import java.sql.Statement;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder;
+import org.apache.kyuubi.engine.dataagent.runtime.event.*;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware;
+import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool;
+import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.sqlite.SQLiteDataSource;
+
+/**
+ * Live integration test with a real LLM and real SQLite database. Exercises the full ReAct loop:
+ * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_LLM_API_KEY and
+ * DATA_AGENT_LLM_API_URL environment variables. Works with any OpenAI-compatible LLM service.
+ */
+public class ReactAgentLiveTest {
+
+ private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", "");
+ private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", "");
+ private static final String MODEL_NAME =
+ System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "gpt-4o");
+
+ private static final String SYSTEM_PROMPT =
+ SystemPromptBuilder.create().datasource("sqlite").build();
+
+ private final List<File> tempFiles = new ArrayList<>();
+ private OpenAIClient client;
+
+ @Before
+ public void setUp() {
+ assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty());
+ assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty());
+ client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build();
+ }
+
+ @After
+ public void tearDown() {
+ tempFiles.forEach(File::delete);
+ }
+
+ @Test
+ public void testPlainTextStreamingWithoutTools() {
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(new ToolRegistry(30))
+ .addMiddleware(new LoggingMiddleware())
+ .maxIterations(3)
+ .systemPrompt("You are a helpful assistant. Answer concisely in 1-2 sentences.")
+ .build();
+
+ List<AgentEvent> events = new CopyOnWriteArrayList<>();
+ ConversationMemory memory = new ConversationMemory();
+
+ agent.run(new AgentInvocation("What is Apache Kyuubi?"), memory, events::add);
+
+ List<String> deltas =
+ events.stream()
+ .filter(e -> e instanceof ContentDelta)
+ .map(e -> ((ContentDelta) e).text())
+ .collect(Collectors.toList());
+
+ assertTrue("Expected multiple ContentDelta events", deltas.size() > 1);
+ assertFalse("Streamed text should not be empty", String.join("", deltas).isEmpty());
+ assertTrue(events.stream().anyMatch(e -> e instanceof StepStart));
+ assertTrue(events.stream().anyMatch(e -> e instanceof ContentComplete));
+ assertTrue(events.get(events.size() - 1) instanceof AgentFinish);
+ assertEquals(2, memory.getHistory().size()); // user + assistant
+ }
+
+ @Test
+ public void testFullReActLoopWithSchemaInspectAndSqlQuery() {
+ SQLiteDataSource ds = createSalesDatabase();
+ ToolRegistry registry = new ToolRegistry(30);
+ registry.register(new RunSelectQueryTool(ds, 0));
+
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(registry)
+ .addMiddleware(new LoggingMiddleware())
+ .maxIterations(10)
+ .systemPrompt(SYSTEM_PROMPT)
+ .build();
+
+ List<AgentEvent> events = new CopyOnWriteArrayList<>();
+ ConversationMemory memory = new ConversationMemory();
+
+ agent.run(
+ new AgentInvocation(
+ "What is the total revenue by product category? Which category has the highest revenue?"),
+ memory,
+ events::add);
+
+ // Verify tool calls happened
+ List<ToolCall> toolCalls =
+ events.stream()
+ .filter(e -> e instanceof ToolCall)
+ .map(e -> (ToolCall) e)
+ .collect(Collectors.toList());
+ List<ToolResult> toolResults =
+ events.stream()
+ .filter(e -> e instanceof ToolResult)
+ .map(e -> (ToolResult) e)
+ .collect(Collectors.toList());
+
+ assertFalse("Agent should have called at least one tool", toolCalls.isEmpty());
+ assertFalse("Agent should have received tool results", toolResults.isEmpty());
+ assertTrue("Tool calls should not error", toolResults.stream().noneMatch(ToolResult::isError));
+
+ // Verify SQL query was called
+ assertTrue(
+ "Agent should execute at least one SQL query",
+ toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName())));
+
+ // Verify final answer mentions "Electronics" (highest revenue)
+ List<String> completions =
+ events.stream()
+ .filter(e -> e instanceof ContentComplete)
+ .map(e -> ((ContentComplete) e).fullText())
+ .collect(Collectors.toList());
+ String lastAnswer = completions.get(completions.size() - 1);
+ assertTrue(
+ "Final answer should mention Electronics, got: " + lastAnswer,
+ lastAnswer.toLowerCase().contains("electronics"));
+
+ // Verify agent finished successfully
+ AgentEvent last = events.get(events.size() - 1);
+ assertTrue(last instanceof AgentFinish);
+ assertTrue("Should take multiple steps", ((AgentFinish) last).totalSteps() > 1);
+ }
+
+ @Test
+ public void testMultiTurnConversationWithToolUse() {
+ SQLiteDataSource ds = createSalesDatabase();
+ ToolRegistry registry = new ToolRegistry(30);
+ registry.register(new RunSelectQueryTool(ds, 0));
+
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(registry)
+ .addMiddleware(new LoggingMiddleware())
+ .maxIterations(10)
+ .systemPrompt(SYSTEM_PROMPT)
+ .build();
+
+ // Shared memory across turns
+ ConversationMemory memory = new ConversationMemory();
+
+ // Turn 1
+ List<AgentEvent> events1 = new CopyOnWriteArrayList<>();
+ agent.run(new AgentInvocation("How many orders are there in total?"), memory, events1::add);
+
+ assertTrue(
+ "Turn 1 should query the database",
+ events1.stream()
+ .anyMatch(
+ e ->
+ e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName())));
+ assertTrue(events1.get(events1.size() - 1) instanceof AgentFinish);
+
+ // Turn 2: follow-up relying on conversation context
+ List<AgentEvent> events2 = new CopyOnWriteArrayList<>();
+ agent.run(
+ new AgentInvocation("Now show me only orders above 500 dollars."), memory, events2::add);
+
+ assertTrue(
+ "Turn 2 should also query the database",
+ events2.stream()
+ .anyMatch(
+ e ->
+ e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName())));
+ assertTrue(events2.get(events2.size() - 1) instanceof AgentFinish);
+
+ // Verify memory accumulated across both turns
+ assertTrue(
+ "Memory should contain messages from both turns, got " + memory.getHistory().size(),
+ memory.getHistory().size() > 4);
+ }
+
+ @Test
+ public void testToolOutputOffloadThenGrep() throws Exception {
+ // Large result forces ToolResultOffloadMiddleware to truncate the tool output and
+ // emit a preview hint telling the LLM to use grep_tool_output / read_tool_output.
+ // A correct answer proves the LLM read the hint and drove the retrieval itself.
+ SQLiteDataSource ds = createNeedleInHaystackDatabase();
+ ToolRegistry registry = new ToolRegistry(30);
+ registry.register(new RunSelectQueryTool(ds, 0));
+
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(registry)
+ .addMiddleware(new LoggingMiddleware())
+ .addMiddleware(new ToolResultOffloadMiddleware())
+ .maxIterations(10)
+ .systemPrompt(SYSTEM_PROMPT)
+ .build();
+
+ List<AgentEvent> events = new CopyOnWriteArrayList<>();
+ ConversationMemory memory = new ConversationMemory();
+
+ // Explicit workflow: full-table scan first, then retrieval tool. Without this hint a
+ // capable LLM just issues SELECT note FROM events WHERE tag='NEEDLE' and skips the
+ // offload path entirely -- smart behavior, but defeats the purpose of this test.
+ agent.run(
+ new AgentInvocation(
+ "Step 1: issue exactly this query: SELECT id, tag, note FROM events (no"
+ + " WHERE clause, return every row). Step 2: the result will be"
+ + " truncated; call the grep_tool_output tool with pattern 'NEEDLE' on"
+ + " the saved output file to find the matching row. Step 3: respond with"
+ + " ONLY the note text from that row, nothing else. Do NOT add a WHERE"
+ + " clause. Do NOT issue any other SQL.")
+ // Offload middleware requires a non-null session id -- without it the offload
+ // path skips entirely (see ToolResultOffloadMiddleware.afterToolCall).
+ .sessionId("offload-live-" + java.util.UUID.randomUUID()),
+ memory,
+ events::add);
+
+ List<ToolCall> toolCalls =
+ events.stream()
+ .filter(e -> e instanceof ToolCall)
+ .map(e -> (ToolCall) e)
+ .collect(Collectors.toList());
+ List<ToolResult> toolResults =
+ events.stream()
+ .filter(e -> e instanceof ToolResult)
+ .map(e -> (ToolResult) e)
+ .collect(Collectors.toList());
+
+ // Dump the tool trace -- diagnoses whether the LLM followed the workflow, added a
+ // WHERE clause despite instructions, or deviated in some other way.
+ for (ToolCall tc : toolCalls) {
+ System.out.println("[ToolCall] " + tc.toolName() + " args=" + tc.toolArgs());
+ }
+ for (ToolResult tr : toolResults) {
+ String out = tr.output();
+ String preview =
+ out.length() > 300
+ ? out.substring(0, 300) + "...(+" + (out.length() - 300) + " chars)"
+ : out;
+ System.out.println("[ToolResult] " + tr.toolName() + " -> " + preview);
+ }
+
+ assertTrue(
+ "Agent should have run a select query first",
+ toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName())));
+
+ // The SELECT returned 800 rows, which must trip the offload threshold.
+ assertTrue(
+ "Expected at least one offload preview marker in tool results",
+ toolResults.stream().anyMatch(tr -> tr.output().contains("Tool output truncated")));
+
+ assertTrue(
+ "Agent should have used grep_tool_output or read_tool_output after seeing"
+ + " the offload preview; actual tool calls: "
+ + toolCalls.stream().map(ToolCall::toolName).collect(Collectors.toList()),
+ toolCalls.stream()
+ .anyMatch(
+ tc ->
+ "grep_tool_output".equals(tc.toolName())
+ || "read_tool_output".equals(tc.toolName())));
+
+ String finalAnswer =
+ events.stream()
+ .filter(e -> e instanceof ContentComplete)
+ .map(e -> ((ContentComplete) e).fullText())
+ .reduce((a, b) -> b)
+ .orElse("");
+ assertTrue(
+ "Final answer should contain the needle note 'the-answer-is-42', got: " + finalAnswer,
+ finalAnswer.contains("the-answer-is-42"));
+ }
+
+ @Test
+ public void testApprovalApproveFlow() throws Exception {
+ // Real LLM picks run_mutation_query (DESTRUCTIVE) -> ApprovalMiddleware pauses ->
+ // background thread resolves(approved=true) -> mutation actually runs in SQLite.
+ SQLiteDataSource ds = createCountersDatabase();
+ ToolRegistry registry = new ToolRegistry(30);
+ registry.register(new RunSelectQueryTool(ds, 0));
+ registry.register(new RunMutationQueryTool(ds, 0));
+
+ ApprovalMiddleware approval = new ApprovalMiddleware(30);
+
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(registry)
+ .addMiddleware(new LoggingMiddleware())
+ .addMiddleware(approval)
+ .maxIterations(10)
+ .systemPrompt(SYSTEM_PROMPT)
+ .build();
+
+ List<AgentEvent> events = new CopyOnWriteArrayList<>();
+ ConversationMemory memory = new ConversationMemory();
+
+ // Auto-approve any approval request that shows up.
+ ExecutorService approver = Executors.newSingleThreadExecutor();
+ Consumer<AgentEvent> listener =
+ event -> {
+ events.add(event);
+ if (event instanceof ApprovalRequest) {
+ String rid = ((ApprovalRequest) event).requestId();
+ approver.submit(() -> approval.resolve(rid, true));
+ }
+ };
+
+ try {
+ agent.run(
+ new AgentInvocation(
+ "Increment the 'hits' counter in the counters table by 1, then tell me its"
+ + " new value. Respond with ONLY the new value, no explanation."),
+ memory,
+ listener);
+ } finally {
+ approver.shutdown();
+ approver.awaitTermination(5, TimeUnit.SECONDS);
+ }
+
+ List<ApprovalRequest> approvals =
+ events.stream()
+ .filter(e -> e instanceof ApprovalRequest)
+ .map(e -> (ApprovalRequest) e)
+ .collect(Collectors.toList());
+ assertFalse("Expected at least one ApprovalRequest", approvals.isEmpty());
+ assertTrue(
+ "ApprovalRequest should target run_mutation_query",
+ approvals.stream().anyMatch(a -> "run_mutation_query".equals(a.toolName())));
+
+ // Mutation must have actually executed — check the DB directly.
+ try (Connection conn = ds.getConnection();
+ Statement stmt = conn.createStatement();
+ java.sql.ResultSet rs = stmt.executeQuery("SELECT value FROM counters WHERE name='hits'")) {
+ assertTrue(rs.next());
+ assertEquals("Counter should be 1 after approved mutation", 1, rs.getInt(1));
+ }
+
+ String finalAnswer =
+ events.stream()
+ .filter(e -> e instanceof ContentComplete)
+ .map(e -> ((ContentComplete) e).fullText())
+ .reduce((a, b) -> b)
+ .orElse("");
+ assertTrue("Final answer should mention 1, got: " + finalAnswer, finalAnswer.contains("1"));
+ }
+
+ @Test
+ public void testApprovalDenyFlow() throws Exception {
+ // Same setup as the approve test, but the approval listener denies. The mutation
+ // must NOT run, and the LLM must surface the denial to the user naturally.
+ SQLiteDataSource ds = createCountersDatabase();
+ ToolRegistry registry = new ToolRegistry(30);
+ registry.register(new RunSelectQueryTool(ds, 0));
+ registry.register(new RunMutationQueryTool(ds, 0));
+
+ ApprovalMiddleware approval = new ApprovalMiddleware(30);
+
+ ReactAgent agent =
+ ReactAgent.builder()
+ .client(client)
+ .modelName(MODEL_NAME)
+ .toolRegistry(registry)
+ .addMiddleware(new LoggingMiddleware())
+ .addMiddleware(approval)
+ .maxIterations(10)
+ .systemPrompt(SYSTEM_PROMPT)
+ .build();
+
+ List<AgentEvent> events = new CopyOnWriteArrayList<>();
+ ConversationMemory memory = new ConversationMemory();
+
+ ExecutorService approver = Executors.newSingleThreadExecutor();
+ Consumer<AgentEvent> listener =
+ event -> {
+ events.add(event);
+ if (event instanceof ApprovalRequest) {
+ String rid = ((ApprovalRequest) event).requestId();
+ approver.submit(() -> approval.resolve(rid, false));
+ }
+ };
+
+ try {
+ agent.run(
+ new AgentInvocation(
+ "Delete all rows from the counters table. If you cannot, explain why."),
+ memory,
+ listener);
+ } finally {
+ approver.shutdown();
+ approver.awaitTermination(5, TimeUnit.SECONDS);
+ }
+
+ assertTrue(
+ "Expected at least one ApprovalRequest",
+ events.stream().anyMatch(e -> e instanceof ApprovalRequest));
+
+ // DB must be untouched.
+ try (Connection conn = ds.getConnection();
+ Statement stmt = conn.createStatement();
+ java.sql.ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM counters")) {
+ assertTrue(rs.next());
+ assertTrue("Counters rows must survive denied mutation", rs.getInt(1) > 0);
+ }
+
+ // LLM should tell the user the operation didn't go through. Loose lexical check to
+ // absorb model wording drift.
+ String finalAnswer =
+ events.stream()
+ .filter(e -> e instanceof ContentComplete)
+ .map(e -> ((ContentComplete) e).fullText())
+ .reduce((a, b) -> b)
+ .orElse("")
+ .toLowerCase();
+ assertTrue(
+ "Final answer should indicate the deletion was refused/denied/not executed, got: "
+ + finalAnswer,
+ finalAnswer.contains("den")
+ || finalAnswer.contains("reject")
+ || finalAnswer.contains("not ")
+ || finalAnswer.contains("refus")
+ || finalAnswer.contains("unable")
+ || finalAnswer.contains("could not"));
+ }
+
+ // --- Helpers ---
+
+ private SQLiteDataSource createSalesDatabase() {
+ SQLiteDataSource ds = createDataSource();
+ try (Connection conn = ds.getConnection();
+ Statement stmt = conn.createStatement()) {
+ stmt.execute(
+ "CREATE TABLE products ("
+ + "id INTEGER PRIMARY KEY, name TEXT NOT NULL, "
+ + "category TEXT NOT NULL, price REAL NOT NULL)");
+ stmt.execute(
+ "INSERT INTO products VALUES "
+ + "(1, 'Laptop', 'Electronics', 999.99), "
+ + "(2, 'Headphones', 'Electronics', 199.99), "
+ + "(3, 'T-Shirt', 'Clothing', 29.99), "
+ + "(4, 'Jeans', 'Clothing', 59.99), "
+ + "(5, 'Novel', 'Books', 14.99), "
+ + "(6, 'Textbook', 'Books', 89.99)");
+ stmt.execute(
+ "CREATE TABLE orders ("
+ + "id INTEGER PRIMARY KEY, product_id INTEGER NOT NULL, "
+ + "customer_name TEXT NOT NULL, quantity INTEGER NOT NULL, "
+ + "order_date TEXT NOT NULL, "
+ + "FOREIGN KEY (product_id) REFERENCES products(id))");
+ stmt.execute(
+ "INSERT INTO orders VALUES "
+ + "(1, 1, 'Alice', 1, '2024-01-15'), "
+ + "(2, 2, 'Bob', 2, '2024-01-20'), "
+ + "(3, 3, 'Charlie', 3, '2024-02-01'), "
+ + "(4, 4, 'Alice', 1, '2024-02-10'), "
+ + "(5, 5, 'Bob', 5, '2024-02-15'), "
+ + "(6, 1, 'Diana', 1, '2024-03-01'), "
+ + "(7, 6, 'Charlie', 2, '2024-03-05'), "
+ + "(8, 2, 'Diana', 1, '2024-03-10')");
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return ds;
+ }
+
+ private SQLiteDataSource createNeedleInHaystackDatabase() {
+ SQLiteDataSource ds = createDataSource();
+ try (Connection conn = ds.getConnection();
+ Statement stmt = conn.createStatement()) {
+ stmt.execute(
+ "CREATE TABLE events ("
+ + "id INTEGER PRIMARY KEY, tag TEXT NOT NULL, note TEXT NOT NULL)");
+ conn.setAutoCommit(false);
+ try (java.sql.PreparedStatement ps =
+ conn.prepareStatement("INSERT INTO events VALUES (?, ?, ?)")) {
+ // 800 filler rows, guaranteed to blow past ToolResultOffloadMiddleware's 500-line
+ // threshold when the LLM issues a SELECT *. Exactly one row carries the NEEDLE tag.
+ int needleId = 573;
+ for (int i = 1; i <= 800; i++) {
+ ps.setInt(1, i);
+ if (i == needleId) {
+ ps.setString(2, "NEEDLE");
+ ps.setString(3, "the-answer-is-42");
+ } else {
+ ps.setString(2, "FILLER");
+ ps.setString(3, "filler-note-" + i);
+ }
+ ps.addBatch();
+ }
+ ps.executeBatch();
+ }
+ conn.commit();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return ds;
+ }
+
+ private SQLiteDataSource createCountersDatabase() {
+ SQLiteDataSource ds = createDataSource();
+ try (Connection conn = ds.getConnection();
+ Statement stmt = conn.createStatement()) {
+ stmt.execute("CREATE TABLE counters (name TEXT PRIMARY KEY, value INTEGER NOT NULL)");
+ stmt.execute("INSERT INTO counters VALUES ('hits', 0)");
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return ds;
+ }
+
+ private SQLiteDataSource createDataSource() {
+ try {
+ File tmpFile = File.createTempFile("kyuubi-agent-live-", ".db");
+ tmpFile.deleteOnExit();
+ tempFiles.add(tmpFile);
+ SQLiteDataSource ds = new SQLiteDataSource();
+ ds.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath());
+ return ds;
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java
new file mode 100644
index 00000000000..e3eda5ff5e4
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime;
+
+import static org.junit.Assert.*;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.junit.Before;
+import org.junit.Test;
+
+public class ToolOutputStoreTest {
+
+ private ToolOutputStore store;
+
+ @Before
+ public void setUp() throws IOException {
+ store = ToolOutputStore.create();
+ }
+
+ @Test
+ public void writeAndReadWindow() throws IOException {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 1; i <= 100; i++) sb.append("row").append(i).append('\n');
+ Path p = store.write("sess1", "call1", sb.toString());
+ assertTrue(Files.exists(p));
+
+ String out = store.read("sess1", p.toString(), 10, 5);
+ assertTrue(out, out.contains("lines 11-15 of"));
+ assertTrue(out, out.contains("row11"));
+ assertTrue(out, out.contains("row15"));
+ assertFalse(out, out.contains("row16"));
+ assertFalse(out, out.contains("row10"));
+ }
+
+ @Test
+ public void grepReturnsMatchingLinesWithLineNumbers() throws IOException {
+ String content = "apple\nbanana\ncherry\napple pie\ndate\n";
+ Path p = store.write("sess1", "call1", content);
+
+ String out = store.grep("sess1", p.toString(), "apple", 10);
+ assertTrue(out, out.contains("1:apple"));
+ assertTrue(out, out.contains("4:apple pie"));
+ assertFalse(out, out.contains("banana"));
+ }
+
+ @Test
+ public void grepRespectsMaxMatches() throws IOException {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 20; i++) sb.append("hit\n");
+ Path p = store.write("sess1", "call1", sb.toString());
+
+ String out = store.grep("sess1", p.toString(), "hit", 3);
+ assertTrue(out, out.contains("[3 matches]"));
+ assertTrue(out, out.contains("1:hit"));
+ assertTrue(out, out.contains("3:hit"));
+ assertFalse("should stop after 3 matches", out.contains("4:hit"));
+ }
+
+ @Test
+ public void grepInvalidRegexReturnsError() throws IOException {
+ Path p = store.write("sess1", "call1", "x\n");
+ String out = store.grep("sess1", p.toString(), "[", 10);
+ assertTrue(out, out.startsWith("Error:"));
+ }
+
+ @Test
+ public void readRejectsCrossSessionPath() throws IOException {
+ Path victim = store.write("victim", "secret_call", "top secret\n");
+ assertTrue(Files.exists(victim));
+
+ String out = store.read("attacker", victim.toString(), 0, 10);
+ assertTrue(out, out.startsWith("Error:"));
+ assertFalse(out, out.contains("top secret"));
+ }
+
+ @Test
+ public void grepRejectsCrossSessionPath() throws IOException {
+ Path victim = store.write("victim", "secret_call", "api_key=xyz\n");
+ String out = store.grep("attacker", victim.toString(), "api_key", 10);
+ assertTrue(out, out.startsWith("Error:"));
+ assertFalse(out, out.contains("xyz"));
+ }
+
+ @Test
+ public void cleanupSessionRemovesSubtree() throws IOException {
+ Path p1 = store.write("sessA", "call1", "a\n");
+ Path p2 = store.write("sessA", "call2", "b\n");
+ Path p3 = store.write("sessB", "call1", "c\n");
+ assertTrue(Files.exists(p1));
+ assertTrue(Files.exists(p2));
+ assertTrue(Files.exists(p3));
+
+ store.cleanupSession("sessA");
+
+ assertFalse(Files.exists(p1));
+ assertFalse(Files.exists(p2));
+ assertTrue("other sessions untouched", Files.exists(p3));
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java
index 50d22da416d..b6a5f093b61 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java
@@ -113,6 +113,7 @@ public void testEventTypeSseNames() {
assertEquals("step_end", EventType.STEP_END.sseEventName());
assertEquals("error", EventType.ERROR.sseEventName());
assertEquals("approval_request", EventType.APPROVAL_REQUEST.sseEventName());
+ assertEquals("compaction", EventType.COMPACTION.sseEventName());
assertEquals("agent_finish", EventType.AGENT_FINISH.sseEventName());
}
@@ -123,6 +124,6 @@ public void testAllEventTypesHaveUniqueSseNames() {
for (EventType type : values) {
assertTrue("Duplicate SSE name: " + type.sseEventName(), names.add(type.sseEventName()));
}
- assertEquals(10, values.length);
+ assertEquals(11, values.length);
}
}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java
new file mode 100644
index 00000000000..a84bbc25948
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import static org.junit.Assert.*;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent;
+import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest;
+import org.apache.kyuubi.engine.dataagent.runtime.event.EventType;
+import org.apache.kyuubi.engine.dataagent.tool.AgentTool;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry;
+import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel;
+import org.junit.Before;
+import org.junit.Test;
+
+public class ApprovalMiddlewareTest {
+
+ private ToolRegistry registry;
+ private List<AgentEvent> emittedEvents;
+
+ @Before
+ public void setUp() {
+ registry = new ToolRegistry(30);
+ registry.register(safeTool("safe_tool"));
+ registry.register(destructiveTool("dangerous_tool"));
+ emittedEvents = Collections.synchronizedList(new ArrayList<>());
+ }
+
+ // --- Auto-approve mode: all tools pass ---
+
+ @Test
+ public void testAutoApproveModeSkipsAllApproval() {
+ ApprovalMiddleware mw = newApprovalMiddleware();
+ AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE);
+
+ assertNull(mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap()));
+ assertNull(mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap()));
+ assertTrue("No approval events should be emitted", emittedEvents.isEmpty());
+ }
+
+ // --- Normal mode: safe auto-approved, destructive needs approval ---
+
+ @Test
+ public void testNormalModeAutoApprovesSafeTool() {
+ ApprovalMiddleware mw = newApprovalMiddleware();
+ AgentRunContext ctx = makeContext(ApprovalMode.NORMAL);
+
+ assertNull(mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()));
+ assertTrue(emittedEvents.isEmpty());
+ }
+
+ @Test
+ public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception {
+ ApprovalMiddleware mw = newApprovalMiddleware(5);
+ AgentRunContext ctx = makeContext(ApprovalMode.NORMAL);
+
+ ExecutorService exec = Executors.newSingleThreadExecutor();
+ try {
+ CountDownLatch eventEmitted = new CountDownLatch(1);
+ // Capture the emitted event to get the requestId
+ ctx.setEventEmitter(
+ event -> {
+ emittedEvents.add(event);
+ eventEmitted.countDown();
+ });
+
+ Future<AgentMiddleware.ToolCallDenial> future =
+ exec.submit(
+ () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap()));
+
+ // Wait for the approval request event
+ assertTrue("Approval event should be emitted", eventEmitted.await(2, TimeUnit.SECONDS));
+ assertEquals(1, emittedEvents.size());
+ assertEquals(EventType.APPROVAL_REQUEST, emittedEvents.get(0).eventType());
+
+ ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0);
+ assertEquals("dangerous_tool", req.toolName());
+ assertEquals(ToolRiskLevel.DESTRUCTIVE, req.riskLevel());
+
+ // Approve
+ assertTrue(mw.resolve(req.requestId(), true));
+ assertNull("Approved tool should return null (no denial)", future.get(2, TimeUnit.SECONDS));
+ } finally {
+ exec.shutdownNow();
+ }
+ }
+
+ @Test
+ public void testDeniedToolReturnsToolCallDenial() throws Exception {
+ ApprovalMiddleware mw = newApprovalMiddleware(5);
+ AgentRunContext ctx = makeContext(ApprovalMode.NORMAL);
+
+ ExecutorService exec = Executors.newSingleThreadExecutor();
+ try {
+ CountDownLatch eventEmitted = new CountDownLatch(1);
+ ctx.setEventEmitter(
+ event -> {
+ emittedEvents.add(event);
+ eventEmitted.countDown();
+ });
+
+ Future<AgentMiddleware.ToolCallDenial> future =
+ exec.submit(
+ () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap()));
+
+ assertTrue(eventEmitted.await(2, TimeUnit.SECONDS));
+ ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0);
+
+ // Deny
+ assertTrue(mw.resolve(req.requestId(), false));
+ AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS);
+ assertNotNull(denial);
+ assertTrue(denial.reason().contains("denied"));
+ } finally {
+ exec.shutdownNow();
+ }
+ }
+
+ // --- Strict mode: all tools need approval ---
+
+ @Test
+ public void testStrictModeRequiresApprovalForSafeTool() throws Exception {
+ ApprovalMiddleware mw = newApprovalMiddleware(5);
+ AgentRunContext ctx = makeContext(ApprovalMode.STRICT);
+
+ ExecutorService exec = Executors.newSingleThreadExecutor();
+ try {
+ CountDownLatch eventEmitted = new CountDownLatch(1);
+ ctx.setEventEmitter(
+ event -> {
+ emittedEvents.add(event);
+ eventEmitted.countDown();
+ });
+
+ Future<AgentMiddleware.ToolCallDenial> future =
+ exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()));
+
+ assertTrue(eventEmitted.await(2, TimeUnit.SECONDS));
+ ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0);
+ assertEquals("safe_tool", req.toolName());
+
+ assertTrue(mw.resolve(req.requestId(), true));
+ assertNull(future.get(2, TimeUnit.SECONDS));
+ } finally {
+ exec.shutdownNow();
+ }
+ }
+
+ // --- Timeout ---
+
+ @Test
+ public void testApprovalTimeoutReturnsDenial() throws Exception {
+ ApprovalMiddleware mw = newApprovalMiddleware(1); // 1 second timeout
+ AgentRunContext ctx = makeContext(ApprovalMode.STRICT);
+ ctx.setEventEmitter(emittedEvents::add);
+
+ ExecutorService exec = Executors.newSingleThreadExecutor();
+ try {
+ Future<AgentMiddleware.ToolCallDenial> future =
+ exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()));
+
+ // Don't resolve — let it time out
+ AgentMiddleware.ToolCallDenial denial = future.get(5, TimeUnit.SECONDS);
+ assertNotNull("Timeout should produce a denial", denial);
+ assertTrue(denial.reason().contains("timed out"));
+ } finally {
+ exec.shutdownNow();
+ }
+ }
+
+ // --- Cancel all ---
+
+ @Test
+ public void testOnStopUnblocksPendingRequests() throws Exception {
+ ApprovalMiddleware mw = newApprovalMiddleware(30);
+ AgentRunContext ctx = makeContext(ApprovalMode.STRICT);
+ ctx.setEventEmitter(emittedEvents::add);
+
+ ExecutorService exec = Executors.newSingleThreadExecutor();
+ try {
+ CountDownLatch started = new CountDownLatch(1);
+ Future<AgentMiddleware.ToolCallDenial> future =
+ exec.submit(
+ () -> {
+ started.countDown();
+ return mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap());
+ });
+
+ assertTrue(started.await(2, TimeUnit.SECONDS));
+ Thread.sleep(100); // let the thread enter the blocking wait
+
+ mw.onStop();
+
+ AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS);
+ assertNotNull("onStop should unblock with a denial", denial);
+ } finally {
+ exec.shutdownNow();
+ }
+ }
+
+ // --- Helpers ---
+
+ private ApprovalMiddleware newApprovalMiddleware() {
+ ApprovalMiddleware mw = new ApprovalMiddleware();
+ mw.onRegister(registry);
+ return mw;
+ }
+
+ private ApprovalMiddleware newApprovalMiddleware(long timeoutSeconds) {
+ ApprovalMiddleware mw = new ApprovalMiddleware(timeoutSeconds);
+ mw.onRegister(registry);
+ return mw;
+ }
+
+ private AgentRunContext makeContext(ApprovalMode mode) {
+ AgentRunContext ctx = new AgentRunContext(new ConversationMemory(), mode);
+ ctx.setEventEmitter(emittedEvents::add);
+ return ctx;
+ }
+
+ private static AgentTool<DummyArgs> safeTool(String name) {
+ return new DummyTool(name, ToolRiskLevel.SAFE);
+ }
+
+ private static AgentTool<DummyArgs> destructiveTool(String name) {
+ return new DummyTool(name, ToolRiskLevel.DESTRUCTIVE);
+ }
+
+ public static class DummyArgs {
+ public String value;
+ }
+
+ private static class DummyTool implements AgentTool<DummyArgs> {
+ private final String name;
+ private final ToolRiskLevel riskLevel;
+
+ DummyTool(String name, ToolRiskLevel riskLevel) {
+ this.name = name;
+ this.riskLevel = riskLevel;
+ }
+
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
+ public String description() {
+ return "dummy tool";
+ }
+
+ @Override
+ public ToolRiskLevel riskLevel() {
+ return riskLevel;
+ }
+
+ @Override
+ public Class<DummyArgs> argsType() {
+ return DummyArgs.class;
+ }
+
+ @Override
+ public String execute(DummyArgs args, ToolContext ctx) {
+ return "ok";
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java
new file mode 100644
index 00000000000..1c4eeb7c5ad
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import static org.junit.Assert.*;
+import static org.junit.Assume.assumeTrue;
+
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import java.util.List;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Live integration test for {@link CompactionMiddleware}: exercises the full compaction path
+ * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_LLM_API_KEY} and {@code
+ * DATA_AGENT_LLM_API_URL} environment variables; skipped otherwise.
+ */
+public class CompactionMiddlewareLiveTest {
+
+ private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", "");
+ private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", "");
+ private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "");
+
+ private OpenAIClient client;
+
+ @Before
+ public void setUp() {
+ assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty());
+ assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty());
+ client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build();
+ }
+
+ @Test
+ public void compactsHistoryWhenThresholdCrossed() {
+ // Seed a realistic ReAct-style history so the summarizer has something non-trivial to
+ // summarize. ~20 alternating user/assistant turns.
+ ConversationMemory memory = new ConversationMemory();
+ memory.setSystemPrompt(
+ "You are a data agent. You previously helped the user investigate the orders table.");
+ for (int i = 0; i < 10; i++) {
+ memory.addUserMessage(
+ "Follow-up question " + i + ": what about the column customer_id in orders?");
+ memory.addAssistantMessage(
+ ChatCompletionAssistantMessageParam.builder()
+ .content(
+ "Assistant turn "
+ + i
+ + ": the orders table has a customer_id BIGINT column referencing customers.id.")
+ .build());
+ }
+ int originalSize = memory.size();
+
+ AgentRunContext ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE);
+ // Simulate the previous LLM call having reported a large prompt_tokens so the next
+ // beforeLlmCall trips the threshold.
+ ctx.addTokenUsage(60_000, 0, 60_000);
+
+ CompactionMiddleware mw = new CompactionMiddleware(client, MODEL_NAME, /* trigger */ 50_000L);
+
+ AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, memory.buildLlmMessages());
+
+ assertNotNull("expected compaction to fire", action);
+ assertTrue(action instanceof AgentMiddleware.LlmModifyMessages);
+
+ // History got rewritten: [summary user msg] + kept tail.
+ List<ChatCompletionMessageParam> hist = memory.getHistory();
+ assertTrue(hist.size() < originalSize);
+ assertTrue(hist.get(0).isUser());
+ String first = hist.get(0).asUser().content().text().orElse("");
+ assertTrue(
+ "summary message should be wrapped in ",
+ first.contains(""));
+
+ // The LLM was told to emit 8 markdown sections; sanity-check a couple show up.
+ assertTrue(
+ "summary should contain '## User Intent' section",
+ first.contains("## User Intent") || first.contains("User Intent"));
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java
new file mode 100644
index 00000000000..d4c0cc581db
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import static org.junit.Assert.*;
+
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall;
+import com.openai.models.chat.completions.ChatCompletionMessageParam;
+import com.openai.models.chat.completions.ChatCompletionMessageToolCall;
+import com.openai.models.chat.completions.ChatCompletionToolMessageParam;
+import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Unit tests that exercise only the deterministic, LLM-free parts of {@link CompactionMiddleware}:
+ *
+ * <ul>
+ * <li>Static helpers {@code computeSplit} and {@code estimateTailAfterLastAssistant}.
+ * <li>Pre-summarizer gating paths in {@code beforeLlmCall} (below threshold, empty old after
+ * split) that never reach the LLM call.
+ * <li>Constructor validation.
+ * </ul>
+ *
+ * The end-to-end behaviour (actual summarization output, compacted-history structure, no re-trigger
+ * next turn) is covered by {@code CompactionMiddlewareLiveTest}, which is gated on real LLM
+ * credentials.
+ */
+public class CompactionMiddlewareTest {
+
+ /**
+ * A minimally-configured real {@link OpenAIClient}. Never invoked in these tests because every
+ * code path exercised here returns before reaching the summarizer. Uses dummy credentials and a
+ * bogus base URL so an accidental invocation would fail loudly.
+ */
+ private static final OpenAIClient DUMMY_CLIENT =
+ OpenAIOkHttpClient.builder()
+ .apiKey("dummy-for-unit-tests")
+ .baseUrl("http://127.0.0.1:0")
+ .build();
+
+ private ConversationMemory memory;
+ private AgentRunContext ctx;
+
+ @Before
+ public void setUp() {
+ memory = new ConversationMemory();
+ memory.setSystemPrompt("SYS");
+ ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE);
+ }
+
+ // ----- computeSplit -----
+
+ @Test
+ public void computeSplit_shortHistoryReturnsAllInKept() {
+ List<ChatCompletionMessageParam> history =
+ Arrays.asList(userMsg("u0"), asstMsg(asstPlain("a0")));
+ CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4);
+ assertEquals(0, s.old.size());
+ assertEquals(2, s.kept.size());
+ }
+
+ @Test
+ public void computeSplit_simpleAlternationSplitsCorrectly() {
+ // 10 msgs: u0,a0,u1,a1,u2,a2,u3,a3,u4,a4. keep=2 → boundary at u3 (2 users from tail).
+ List<ChatCompletionMessageParam> history = alternatingHistory(10);
+ CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 2);
+ // 4 users from tail: u4,u3 → splitIdx = index of u3 = 6 → old=[0..5], kept=[6..9]
+ assertEquals(6, s.old.size());
+ assertEquals(4, s.kept.size());
+ assertTrue(s.kept.get(0).isUser());
+ }
+
+ @Test
+ public void computeSplit_neverOrphansToolResult() {
+ // Layout: u0 a0 u1 a1 u2 a2 u3 a3(tc1) tool(tc1) u4 a4 u5 a5 u6 a6 u7 a7
+ // keep=4 → naive splitIdx lands between tool(tc1) and u4; pair-protection must shift it back
+ // to before a3(tc1), so a3(tc1) + tool(tc1) end up in kept.
+ List<ChatCompletionMessageParam> history = new ArrayList<>();
+ history.add(userMsg("u0"));
+ history.add(asstMsg(asstPlain("a0")));
+ history.add(userMsg("u1"));
+ history.add(asstMsg(asstPlain("a1")));
+ history.add(userMsg("u2"));
+ history.add(asstMsg(asstPlain("a2")));
+ history.add(userMsg("u3"));
+ history.add(asstMsg(asstWithToolCall("a3", "tc1", "sql_query", "{}")));
+ history.add(toolMsg("tc1", "r1"));
+ history.add(userMsg("u4"));
+ history.add(asstMsg(asstPlain("a4")));
+ history.add(userMsg("u5"));
+ history.add(asstMsg(asstPlain("a5")));
+ history.add(userMsg("u6"));
+ history.add(asstMsg(asstPlain("a6")));
+ history.add(userMsg("u7"));
+ history.add(asstMsg(asstPlain("a7")));
+
+ CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4);
+ assertNoOrphanToolResult(s.kept);
+ // Verify the tc1 pair really did end up in kept, not old.
+ assertTrue("a3(tc1) must be in kept", containsToolCallId(s.kept, "tc1"));
+ assertTrue("tool_result(tc1) must be in kept", containsToolCallIdAsResult(s.kept, "tc1"));
+ }
+
+ @Test
+ public void computeSplit_keepCountExceedsAvailableUsers() {
+ // Only 2 user msgs but we ask to keep 4 — boundary walks to the top, old=[], kept=all.
+ List<ChatCompletionMessageParam> history = alternatingHistory(4); // u0,a0,u1,a1
+ CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4);
+ assertEquals(0, s.old.size());
+ assertEquals(4, s.kept.size());
+ }
+
+ // ----- estimateTailAfterLastAssistant -----
+
+ @Test
+ public void estimateTail_afterLastAssistant() {
+ // u(200 chars) a(50) u(100) → last assistant at index 1, tail is the final user = 100 chars
+ // → 100/4 = 25 tokens
+ List<ChatCompletionMessageParam> msgs =
+ Arrays.asList(
+ userMsg(repeat('x', 200)),
+ asstMsg(asstPlain(repeat('y', 50))),
+ userMsg(repeat('z', 100)));
+ assertEquals(25L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs));
+ }
+
+ @Test
+ public void estimateTail_noAssistantMeansEverythingIsTail() {
+ List<ChatCompletionMessageParam> msgs =
+ Arrays.asList(userMsg(repeat('x', 400)), userMsg(repeat('y', 400)));
+ assertEquals(200L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs));
+ }
+
+ @Test
+ public void estimateTail_emptyReturnsZero() {
+ assertEquals(0L, CompactionMiddleware.estimateTailAfterLastAssistant(Collections.emptyList()));
+ }
+
+ // ----- beforeLlmCall pre-summarizer gating -----
+
+ @Test
+ public void belowThresholdReturnsNull() {
+ seedSimpleHistory(6);
+ ctx.addTokenUsage(1000, 0, 1000);
+ CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L);
+
+ assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages()));
+ // Nothing was mutated.
+ assertEquals(6, memory.size());
+ }
+
+ @Test
+ public void aboveThresholdButHistoryTooShortReturnsNull() {
+ // Threshold crossed (60k cumulative) but history has only 2 user turns → computeSplit
+ // can't satisfy KEEP_RECENT_TURNS=4 and keeps everything, leaving split.old empty; so
+ // beforeLlmCall bails out before ever invoking the summarizer.
+ memory.addUserMessage("u0");
+ memory.addAssistantMessage(asstPlain("a0"));
+ memory.addUserMessage("u1");
+ ctx.addTokenUsage(60_000, 0, 60_000);
+ CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L);
+
+ assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages()));
+ assertEquals(3, memory.size());
+ }
+
+ @Test
+ public void triggerUsesLastCallTotalNotCumulative() {
+ // Two consecutive calls with total_tokens below threshold. The middleware must key on the
+ // *last* call's total (prompt + completion), not the session cumulative — otherwise a session
+ // that has accumulated large cumulative cost but then compacted would misfire. Using total
+ // (not just prompt) also covers the last assistant message — e.g. a tool_call's completion
+ // tokens — which is part of the next prompt but sits beyond the tail estimator's window.
+ seedSimpleHistory(6);
+ CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L);
+
+ ctx.addTokenUsage(4_000, 1_000, 5_000);
+ assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages()));
+
+ ctx.addTokenUsage(8_000, 2_000, 10_000);
+ assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages()));
+
+ assertEquals(10_000L, memory.getLastTotalTokens());
+ assertEquals(15_000L, memory.getCumulativeTotalTokens());
+ }
+
+ // ----- helpers -----
+
+ private void seedSimpleHistory(int n) {
+ for (int i = 0; i < n; i++) {
+ if (i % 2 == 0) {
+ memory.addUserMessage("u" + i);
+ } else {
+ memory.addAssistantMessage(asstPlain("a" + i));
+ }
+ }
+ }
+
+ private static List<ChatCompletionMessageParam> alternatingHistory(int n) {
+ List<ChatCompletionMessageParam> out = new ArrayList<>(n);
+ for (int i = 0; i < n; i++) {
+ if (i % 2 == 0) {
+ out.add(userMsg("u" + i));
+ } else {
+ out.add(asstMsg(asstPlain("a" + i)));
+ }
+ }
+ return out;
+ }
+
+ private static ChatCompletionMessageParam userMsg(String text) {
+ return ChatCompletionMessageParam.ofUser(
+ ChatCompletionUserMessageParam.builder().content(text).build());
+ }
+
+ private static ChatCompletionMessageParam asstMsg(ChatCompletionAssistantMessageParam p) {
+ return ChatCompletionMessageParam.ofAssistant(p);
+ }
+
+ private static ChatCompletionMessageParam toolMsg(String toolCallId, String content) {
+ return ChatCompletionMessageParam.ofTool(
+ ChatCompletionToolMessageParam.builder().toolCallId(toolCallId).content(content).build());
+ }
+
+ private static ChatCompletionAssistantMessageParam asstPlain(String text) {
+ return ChatCompletionAssistantMessageParam.builder().content(text).build();
+ }
+
+ private static ChatCompletionAssistantMessageParam asstWithToolCall(
+ String text, String toolCallId, String toolName, String args) {
+ List<ChatCompletionMessageToolCall> calls = new ArrayList<>();
+ calls.add(
+ ChatCompletionMessageToolCall.ofFunction(
+ ChatCompletionMessageFunctionToolCall.builder()
+ .id(toolCallId)
+ .function(
+ ChatCompletionMessageFunctionToolCall.Function.builder()
+ .name(toolName)
+ .arguments(args)
+ .build())
+ .build()));
+ return ChatCompletionAssistantMessageParam.builder().content(text).toolCalls(calls).build();
+ }
+
+ private static String repeat(char c, int n) {
+ char[] arr = new char[n];
+ Arrays.fill(arr, c);
+ return new String(arr);
+ }
+
+ private static boolean containsToolCallId(List<ChatCompletionMessageParam> msgs, String id) {
+ for (ChatCompletionMessageParam m : msgs) {
+ if (m.isAssistant()) {
+ List<ChatCompletionMessageToolCall> calls = m.asAssistant().toolCalls().orElse(null);
+ if (calls == null) continue;
+ for (ChatCompletionMessageToolCall tc : calls) {
+ if (tc.isFunction() && id.equals(tc.asFunction().id())) return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private static boolean containsToolCallIdAsResult(
+ List<ChatCompletionMessageParam> msgs, String id) {
+ for (ChatCompletionMessageParam m : msgs) {
+ if (m.isTool() && id.equals(m.asTool().toolCallId())) return true;
+ }
+ return false;
+ }
+
+ private static void assertNoOrphanToolResult(List<ChatCompletionMessageParam> msgs) {
+ Set<String> issued = new HashSet<>();
+ for (ChatCompletionMessageParam m : msgs) {
+ if (m.isAssistant()) {
+ m.asAssistant()
+ .toolCalls()
+ .ifPresent(
+ calls -> {
+ for (ChatCompletionMessageToolCall tc : calls) {
+ if (tc.isFunction()) issued.add(tc.asFunction().id());
+ }
+ });
+ }
+ }
+ for (ChatCompletionMessageParam m : msgs) {
+ if (m.isTool()) {
+ assertTrue(
+ "tool_result id=" + m.asTool().toolCallId() + " has no matching tool_call",
+ issued.contains(m.asTool().toolCallId()));
+ }
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java
new file mode 100644
index 00000000000..cb107b775f5
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.runtime.middleware;
+
+import static org.junit.Assert.*;
+
+import java.util.Collections;
+import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext;
+import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode;
+import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory;
+import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool;
+import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class ToolResultOffloadMiddlewareTest {
+
+ private ToolResultOffloadMiddleware mw;
+ private AgentRunContext ctxWithSession;
+ private AgentRunContext ctxNoSession;
+
+ @Before
+ public void setUp() {
+ mw = new ToolResultOffloadMiddleware();
+ ctxWithSession =
+ new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, "sess-1");
+ ctxNoSession = new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, null);
+ }
+
+ @After
+ public void tearDown() {
+ mw.onStop();
+ }
+
+ @Test
+ public void underThresholdPassesThrough() {
+ String small = "row1\nrow2\nrow3\n";
+ String out =
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small);
+ assertNull(out);
+ }
+
+ @Test
+ public void overLineThresholdTriggersOffload() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n');
+ String out =
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString());
+
+ assertNotNull(out);
+ assertTrue(out, out.contains("Tool output truncated"));
+ assertTrue(out, out.contains("Saved to:"));
+ assertTrue(out, out.contains(ReadToolOutputTool.NAME));
+ assertTrue(out, out.contains(GrepToolOutputTool.NAME));
+ assertTrue(out, out.contains("row0"));
+ assertTrue(out, out.contains("row599"));
+ }
+
+ @Test
+ public void overByteThresholdTriggersOffload() {
+ // 60 lines of ~1 KB each = ~60 KB — over the byte threshold but well under the line threshold.
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 60; i++) {
+ for (int j = 0; j < 1024; j++) sb.append('a');
+ sb.append('\n');
+ }
+ String out =
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString());
+
+ assertNotNull("byte threshold should trigger", out);
+ assertTrue(out, out.contains("Tool output truncated"));
+ }
+
+ @Test
+ public void retrievalToolsAreExemptFromGate() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n');
+ assertNull(
+ mw.afterToolCall(
+ ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString()));
+ assertNull(
+ mw.afterToolCall(
+ ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString()));
+ }
+
+ @Test
+ public void missingSessionIdPassesThrough() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n');
+ String out =
+ mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString());
+ assertNull("without sessionId, cannot offload safely — pass through", out);
+ }
+
+ @Test
+ public void onSessionCloseClearsCounterAndFiles() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n');
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString());
+ assertEquals(1, mw.trackedSessions());
+
+ mw.onSessionClose("sess-1");
+ assertEquals(0, mw.trackedSessions());
+ }
+
+ @Test
+ public void multipleOffloadsReuseSameSessionDir() {
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n');
+ String out1 =
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString());
+ String out2 =
+ mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString());
+ // Both previews reference the same session dir, different file names.
+ assertNotEquals(extractPath(out1), extractPath(out2));
+ assertTrue(extractPath(out1).contains("sess-1"));
+ assertTrue(extractPath(out2).contains("sess-1"));
+ }
+
+ private static String extractPath(String preview) {
+ int i = preview.indexOf("Saved to:");
+ int eol = preview.indexOf('\n', i);
+ return preview.substring(i + "Saved to:".length(), eol).trim();
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java
index c3ceb1a4dc0..7f790e1238e 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java
@@ -66,7 +66,7 @@ public Class<DummyArgs> argsType() {
}
@Override
- public String execute(DummyArgs args) {
+ public String execute(DummyArgs args, ToolContext ctx) {
return "result_" + idx;
}
});
@@ -118,7 +118,7 @@ public Class<DummyArgs> argsType() {
}
@Override
- public String execute(DummyArgs args) {
+ public String execute(DummyArgs args, ToolContext ctx) {
return "existing_result";
}
});
@@ -159,7 +159,7 @@ public Class<DummyArgs> argsType() {
}
@Override
- public String execute(DummyArgs args) {
+ public String execute(DummyArgs args, ToolContext ctx) {
return "dynamic";
}
});
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java
index 8e0a5cd01b0..777017c4438 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java
@@ -109,7 +109,7 @@ public Class<ToolRegistryThreadSafetyTest.DummyArgs> argsType() {
}
@Override
- public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) {
+ public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) {
try {
Thread.sleep(60_000);
} catch (InterruptedException e) {
@@ -147,7 +147,7 @@ public Class<ToolRegistryThreadSafetyTest.DummyArgs> argsType() {
}
@Override
- public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) {
+ public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) {
throw new RuntimeException("intentional failure");
}
});
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java
index c46fc2501bc..5384bd937d9 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java
@@ -24,6 +24,7 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel;
import org.junit.After;
import org.junit.Before;
@@ -60,7 +61,7 @@ public void testRiskLevelDestructive() {
public void testInsert() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "INSERT INTO t VALUES (9999, 'hello')";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue(result.contains("1 row(s) affected"));
}
@@ -68,21 +69,21 @@ public void testInsert() {
public void testUpdate() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "UPDATE t SET v = 'updated' WHERE id = 1";
- assertTrue(tool.execute(args).contains("1 row(s) affected"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected"));
}
@Test
public void testDelete() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "DELETE FROM t WHERE id = 1";
- assertTrue(tool.execute(args).contains("1 row(s) affected"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected"));
}
@Test
public void testCreateTable() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)";
- assertTrue(tool.execute(args).contains("executed successfully"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("executed successfully"));
}
@Test
@@ -90,7 +91,7 @@ public void testAlsoAcceptsSelect() {
// Mutation tool does not enforce read-only check; SELECT works fine here.
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT v FROM t WHERE id = 1";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertFalse(result.startsWith("Error:"));
}
@@ -98,18 +99,18 @@ public void testAlsoAcceptsSelect() {
public void testRejectsEmptyAndNullSql() {
SqlQueryArgs emptyArgs = new SqlQueryArgs();
emptyArgs.sql = "";
- assertTrue(tool.execute(emptyArgs).startsWith("Error:"));
+ assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:"));
SqlQueryArgs nullArgs = new SqlQueryArgs();
nullArgs.sql = null;
- assertTrue(tool.execute(nullArgs).startsWith("Error:"));
+ assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testInvalidSqlReturnsError() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "INSERT INTO nonexistent_table VALUES (1)";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
// --- Helpers ---
diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java
index 3c6579cf9a7..d1015ede63d 100644
--- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java
@@ -24,6 +24,7 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
+import org.apache.kyuubi.engine.dataagent.tool.ToolContext;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -56,7 +57,7 @@ public void tearDown() {
public void testRejectsInsert() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "INSERT INTO large_table VALUES (9999, 'x')";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue(result.startsWith("Error:"));
assertTrue(result.contains("read-only"));
assertTrue(result.contains("run_mutation_query"));
@@ -66,28 +67,28 @@ public void testRejectsInsert() {
public void testRejectsUpdate() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testRejectsDelete() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "DELETE FROM large_table WHERE id = 1";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testRejectsCreateTable() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "CREATE TABLE x (id INT)";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testAllowsSelect() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT id FROM large_table LIMIT 100";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertFalse(result.startsWith("Error:"));
assertTrue(result.contains("[100 row(s) returned]"));
}
@@ -96,7 +97,7 @@ public void testAllowsSelect() {
public void testAllowsCte() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertFalse(result.startsWith("Error:"));
assertTrue(result.contains("row(s)"));
}
@@ -107,7 +108,7 @@ public void testAllowsCte() {
public void testRespectsLimitInSql() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT id FROM large_table LIMIT 5";
- assertTrue(tool.execute(args).contains("[5 row(s) returned]"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[5 row(s) returned]"));
}
@Test
@@ -116,7 +117,7 @@ public void testNoClientSideCapWhenLimitOmitted() {
// Cap discipline is delegated to the LLM via the system prompt.
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT id FROM large_table";
- assertTrue(tool.execute(args).contains("[1500 row(s) returned]"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[1500 row(s) returned]"));
}
// --- Zero-row result ---
@@ -125,7 +126,7 @@ public void testNoClientSideCapWhenLimitOmitted() {
public void testZeroRowsResult() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT id FROM large_table WHERE id < 0";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertFalse(result.startsWith("Error:"));
assertTrue(result.contains("[0 row(s) returned]"));
}
@@ -136,15 +137,15 @@ public void testZeroRowsResult() {
public void testSelectWithLeadingBlockComment() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "/* get count */ SELECT COUNT(*) FROM large_table";
- assertFalse(tool.execute(args).startsWith("Error:"));
+ assertFalse(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testRejectsMutationHiddenBehindComment() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "-- looks innocent\nDROP TABLE large_table";
- assertTrue(tool.execute(args).startsWith("Error:"));
- assertTrue(tool.execute(args).contains("read-only"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).contains("read-only"));
}
// --- Edge cases ---
@@ -153,25 +154,25 @@ public void testRejectsMutationHiddenBehindComment() {
public void testRejectsEmptyAndNullSql() {
SqlQueryArgs emptyArgs = new SqlQueryArgs();
emptyArgs.sql = "";
- assertTrue(tool.execute(emptyArgs).startsWith("Error:"));
+ assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:"));
SqlQueryArgs nullArgs = new SqlQueryArgs();
nullArgs.sql = null;
- assertTrue(tool.execute(nullArgs).startsWith("Error:"));
+ assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testRejectsWhitespaceOnlySql() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = " \t\n ";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
public void testInvalidSqlReturnsError() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT * FROM nonexistent_table";
- assertTrue(tool.execute(args).startsWith("Error:"));
+ assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
// --- Output formatting ---
@@ -188,7 +189,7 @@ public void testNullValuesRenderedAsNULL() {
}
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue(result.contains("NULL"));
assertTrue(result.contains("Alice"));
}
@@ -204,7 +205,7 @@ public void testPipeCharacterEscapedInOutput() {
}
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT val FROM pipe_test";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c"));
}
@@ -214,7 +215,7 @@ public void testPipeCharacterEscapedInOutput() {
public void testExtractRootCauseFromNestedExceptions() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT * FROM this_table_does_not_exist_at_all";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue(result.startsWith("Error:"));
assertTrue(result.contains("this_table_does_not_exist_at_all"));
}
@@ -223,7 +224,7 @@ public void testExtractRootCauseFromNestedExceptions() {
public void testErrorMessageIsConcise() {
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELEC INVALID SYNTAX HERE !!!";
- String result = tool.execute(args);
+ String result = tool.execute(args, ToolContext.EMPTY);
assertTrue(result.startsWith("Error:"));
long newlines = result.chars().filter(c -> c == '\n').count();
assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2);
@@ -236,7 +237,7 @@ public void testCustomQueryTimeout() {
RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5);
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT COUNT(*) FROM large_table";
- assertFalse(customTool.execute(args).startsWith("Error:"));
+ assertFalse(customTool.execute(args, ToolContext.EMPTY).startsWith("Error:"));
}
@Test
@@ -314,7 +315,7 @@ public boolean isWrapperFor(Class<?> iface) {
RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1);
SqlQueryArgs args = new SqlQueryArgs();
args.sql = "SELECT * FROM large_table";
- String result = timeoutTool.execute(args);
+ String result = timeoutTool.execute(args, ToolContext.EMPTY);
assertTrue("Expected error on timeout", result.startsWith("Error:"));
assertTrue("Expected timeout message", result.contains("timed out"));
}
diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala
new file mode 100644
index 00000000000..5fd95bb0fdc
--- /dev/null
+++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.dataagent.operation
+
+import java.sql.DriverManager
+
+import scala.collection.JavaConverters._
+
+import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+
+import org.apache.kyuubi.config.KyuubiConf._
+import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine
+import org.apache.kyuubi.operation.HiveJDBCTestHelper
+
+/**
+ * End-to-end test that forces CompactionMiddleware to fire inside the engine and verifies
+ * two things simultaneously:
+ * 1. The `compaction` SSE event reaches the JDBC client (wiring works).
+ * 2. The agent still answers correctly *after* compaction, proving the summary preserved
+ * the facts the follow-up question depends on.
+ *
+ * The trigger threshold is set extremely low (500 tokens) so that the schema dump plus the
+ * first turn's completion already blow past it, forcing compaction before turn 2 or 3.
+ *
+ * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL.
+ */
+class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine {
+
+ private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "")
+ private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "")
+ private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "")
+ private val dbPath = {
+ val tmp = System.getProperty("java.io.tmpdir")
+ val uid = java.util.UUID.randomUUID()
+ s"$tmp/dataagent_compaction_e2e_$uid.db"
+ }
+
+ override def withKyuubiConf: Map[String, String] = Map(
+ ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE",
+ ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey,
+ ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl,
+ ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName,
+ ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10",
+ ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE",
+ // Force compaction to fire aggressively -- any realistic prompt + one LLM round-trip
+ // will exceed 500 tokens, so compaction must trigger by turn 2 or 3.
+ ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS.key -> "500",
+ ENGINE_DATA_AGENT_JDBC_URL.key -> s"jdbc:sqlite:$dbPath")
+
+ override protected def jdbcUrl: String = jdbcConnectionUrl
+
+ private val enabled: Boolean = apiKey.nonEmpty && apiUrl.nonEmpty
+
+ override def beforeAll(): Unit = {
+ if (enabled) {
+ setupTestDatabase()
+ super.beforeAll()
+ }
+ }
+
+ override def afterAll(): Unit = {
+ if (enabled) {
+ super.afterAll()
+ new java.io.File(dbPath).delete()
+ }
+ }
+
+ private def setupTestDatabase(): Unit = {
+ new java.io.File(dbPath).delete()
+ val conn = DriverManager.getConnection(s"jdbc:sqlite:$dbPath")
+ try {
+ val stmt = conn.createStatement()
+ stmt.execute(
+ """
+ |CREATE TABLE employees (
+ | id INTEGER PRIMARY KEY,
+ | name TEXT NOT NULL,
+ | department TEXT NOT NULL,
+ | salary REAL NOT NULL
+ |)""".stripMargin)
+ // 6 employees across 3 departments; Frank is unambiguously the top earner.
+ stmt.execute("INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 25000)")
+ stmt.execute("INSERT INTO employees VALUES (2, 'Bob', 'Engineering', 30000)")
+ stmt.execute("INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 20000)")
+ stmt.execute("INSERT INTO employees VALUES (4, 'Diana', 'Sales', 22000)")
+ stmt.execute("INSERT INTO employees VALUES (5, 'Eve', 'Marketing', 18000)")
+ stmt.execute("INSERT INTO employees VALUES (6, 'Frank', 'Engineering', 35000)")
+ } finally {
+ conn.close()
+ }
+ }
+
+ private val mapper = new ObjectMapper()
+
+ private def drainReply(rs: java.sql.ResultSet): String = {
+ val sb = new StringBuilder
+ while (rs.next()) {
+ sb.append(rs.getString("reply"))
+ }
+ sb.toString()
+ }
+
+ private def parseEvents(stream: String): Seq[JsonNode] = {
+ val parser = mapper.getFactory.createParser(stream)
+ try mapper.readValues(parser, classOf[JsonNode]).asScala.toList
+ finally parser.close()
+ }
+
+ private def extractAnswer(events: Seq[JsonNode]): String = {
+ val sb = new StringBuilder
+ events.foreach { node =>
+ if ("content_delta" == node.path("type").asText()) {
+ sb.append(node.path("text").asText(""))
+ }
+ }
+ sb.toString()
+ }
+
+ private val strictFormatHint =
+ "Respond with ONLY the answer, no explanation, no markdown, no punctuation."
+
+ test("E2E: compaction fires mid-conversation and preserves facts across turns") {
+ assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping")
+
+ // CompactionMiddleware.KEEP_RECENT_TURNS is hardcoded to 4. computeSplit needs a
+ // non-empty 'old' slice, so at least 5 distinct user turns must accumulate before
+ // compaction can fire -- turns 1..(N-4) become the old slice, 4 most recent are
+ // kept verbatim. Turn 5 is the observable trigger point.
+ //
+ // Turn 5 is phrased to force a fresh SQL query rather than relying on recall,
+ // because summary quality varies across LLMs (some drop the top-earner fact).
+ // Correctness of the final answer then validates that the post-compaction history
+ // still gives the agent enough context to pick the right tool and query -- which is
+ // the compaction contract we actually care about: mechanism fires, agent recovers.
+ withJdbcStatement() { stmt =>
+ Seq(
+ "List every department that appears in the employees table.",
+ "How many employees work in Engineering?",
+ "What salaries do Sales employees earn?",
+ "Who works in Marketing?").zipWithIndex.foreach { case (q, i) =>
+ val events = parseEvents(drainReply(stmt.executeQuery(q)))
+ info(s"Turn ${i + 1} answer: ${extractAnswer(events)}")
+ }
+
+ // Turn 5 -- explicitly instruct the agent to re-query so the answer does not
+ // depend on summary fidelity, only on the agent still functioning after
+ // compaction rewrote history.
+ val events5 = parseEvents(drainReply(stmt.executeQuery(
+ "Run a SELECT against the employees table to find the single employee with" +
+ " the highest salary. Report ONLY that employee's name." +
+ s" $strictFormatHint")))
+ val answer5 = extractAnswer(events5)
+ info(s"Turn 5 answer: $answer5")
+
+ val compactionEvents =
+ events5.filter(_.path("type").asText() == "compaction")
+ assert(
+ compactionEvents.nonEmpty,
+ "Expected at least one compaction event in turn 5 (trigger=500 tokens, 5 turns)")
+
+ // Sanity-check event shape -- field names must match ExecuteStatement's SSE encoder.
+ val c = compactionEvents.head
+ assert(
+ c.has("summarized") && c.get("summarized").asInt() > 0,
+ s"compaction event should carry a positive summarized count: $c")
+ assert(c.has("kept") && c.get("kept").asInt() >= 0, s"compaction event missing kept: $c")
+ assert(
+ c.has("triggerTokens") && c.get("triggerTokens").asLong() == 500L,
+ s"compaction event should echo configured trigger: $c")
+
+ // Turn 5 was told to SELECT fresh; Frank (35000) is unambiguously the top earner.
+ // If we don't get "Frank", either the agent failed to re-query after compaction
+ // (real bug in post-compaction history) or the tool layer is broken.
+ assert(
+ answer5.contains("Frank"),
+ s"Turn 5 should identify Frank as the top earner after re-querying; the agent" +
+ s" must remain functional post-compaction. Got: $answer5")
+ }
+ }
+}
diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala
index 977c655dabd..cfb83ecc9b2 100644
--- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala
+++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala
@@ -19,6 +19,10 @@ package org.apache.kyuubi.engine.dataagent.operation
import java.sql.DriverManager
+import scala.collection.JavaConverters._
+
+import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+
import org.apache.kyuubi.config.KyuubiConf._
import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine
import org.apache.kyuubi.operation.HiveJDBCTestHelper
@@ -34,7 +38,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine {
private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "")
private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "")
- private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "gpt-4o")
+ private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "")
private val dbPath =
s"${System.getProperty("java.io.tmpdir")}/dataagent_e2e_test_${java.util.UUID.randomUUID()}.db"
@@ -107,33 +111,71 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine {
new java.io.File(dbPath).delete()
}
+ private val mapper = new ObjectMapper()
+
+ private def drainReply(rs: java.sql.ResultSet): String = {
+ val sb = new StringBuilder
+ while (rs.next()) {
+ sb.append(rs.getString("reply"))
+ }
+ val stream = sb.toString()
+ info(s"Agent event stream: $stream")
+ stream
+ }
+
+ /**
+ * The JDBC `reply` column is a concatenated stream of SSE events
+ * (`agent_start`, `tool_call`, `tool_result`, `content_delta`, ...). Only
+ * `content_delta.text` is actual model output - this pulls those out and
+ * joins them to recover the final natural-language answer.
+ */
+ private def extractAnswer(eventStream: String): String = {
+ val parser = mapper.getFactory.createParser(eventStream)
+ val sb = new StringBuilder
+ try {
+ mapper.readValues(parser, classOf[JsonNode]).asScala.foreach { node =>
+ if ("content_delta" == node.path("type").asText()) {
+ sb.append(node.path("text").asText(""))
+ }
+ }
+ } finally {
+ parser.close()
+ }
+ sb.toString()
+ }
+
+ private val strictFormatHint =
+ "Respond with ONLY the answer, no explanation, no markdown, no punctuation."
+
test("E2E: agent answers data question through full Kyuubi pipeline") {
assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests")
- // scalastyle:off println
withJdbcStatement() { stmt =>
- // Ask a question that requires schema exploration + SQL execution
- val result = stmt.executeQuery(
- "Which department has the highest average salary?")
-
- val sb = new StringBuilder
- while (result.next()) {
- val chunk = result.getString("reply")
- sb.append(chunk)
- print(chunk) // real-time output for debugging
- }
- println()
-
- val reply = sb.toString()
-
- // The agent should have:
- // 1. Explored the schema (mentioned table names or columns)
- // 2. Executed SQL (the reply should contain actual data)
- // 3. Answered with "Engineering" (avg salary 30000)
- assert(reply.nonEmpty, "Agent should return a non-empty response")
- assert(
- reply.toLowerCase.contains("engineering") || reply.contains("30000"),
- s"Expected the answer to mention 'Engineering' or '30000', got: ${reply.take(500)}")
+ val stream = drainReply(
+ stmt.executeQuery(
+ s"Which department has the highest average salary? $strictFormatHint"))
+ assert(extractAnswer(stream) == "Engineering")
+ }
+ }
+
+ test("E2E: agent resolves follow-up question using prior conversation context") {
+ assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests")
+ // Two executeQuery calls on the same Statement share the JDBC session, which means
+ // the provider reuses the same ConversationMemory across turns. Turn 2 uses the
+ // demonstrative "that department" - it can only be answered correctly if Turn 1's
+ // answer (Engineering) is carried over in the agent's conversation history.
+ withJdbcStatement() { stmt =>
+ val stream1 = drainReply(
+ stmt.executeQuery(
+ s"Which department has the highest average salary? $strictFormatHint"))
+ assert(extractAnswer(stream1) == "Engineering")
+
+ // Engineering has 3 employees (Alice, Bob, Frank). If memory is not shared
+ // the agent cannot resolve "that department" and cannot produce the exact
+ // integer 3 - nothing in Turn 2's prompt points to Engineering.
+ val stream2 = drainReply(
+ stmt.executeQuery(
+ s"How many employees work in that department? $strictFormatHint"))
+ assert(extractAnswer(stream2) == "3")
}
- // scalastyle:on println
}
}
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
index 215a76e26d4..31d308e9c67 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
@@ -3848,6 +3848,21 @@ object KyuubiConf {
.checkValue(_ > 0, "must be positive number")
.createWithDefault(100)
+ val ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS: ConfigEntry[Long] =
+ buildConf("kyuubi.engine.data.agent.compaction.trigger.tokens")
+ .doc("The prompt-token threshold above which the Data Agent's compaction middleware " +
+ "summarizes older conversation history into a compact message. The check is made each " +
+ "turn as " +
+ "real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; " +
+ "when this predicted prompt size reaches the configured value, older messages are " +
+ "replaced by a single summary message while the most recent exchanges are kept verbatim. " +
+ "Set to a very large value (e.g., 9223372036854775807) to effectively " +
+ "disable compaction.")
+ .version("1.12.0")
+ .longConf
+ .checkValue(_ > 0, "must be positive number")
+ .createWithDefault(128000L)
+
val ENGINE_DATA_AGENT_QUERY_TIMEOUT: ConfigEntry[Long] =
buildConf("kyuubi.engine.data.agent.query.timeout")
.doc("The JDBC query execution timeout for the Data Agent SQL tools. " +