diff --git a/docs/configuration/settings.md b/docs/configuration/settings.md index 3e34882b28f..a2910abfca6 100644 --- a/docs/configuration/settings.md +++ b/docs/configuration/settings.md @@ -144,6 +144,7 @@ You can configure the Kyuubi properties in `$KYUUBI_HOME/conf/kyuubi-defaults.co | kyuubi.engine.chat.provider | ECHO | The provider for the Chat engine. Candidates: | string | 1.8.0 | | kyuubi.engine.connection.url.use.hostname | true | (deprecated) When true, the engine registers with hostname to zookeeper. When Spark runs on K8s with cluster mode, set to false to ensure that server can connect to engine | boolean | 1.3.0 | | kyuubi.engine.data.agent.approval.mode | NORMAL | Default approval mode for tool execution in the Data Agent engine. Candidates: | string | 1.12.0 | +| kyuubi.engine.data.agent.compaction.trigger.tokens | 128000 | The prompt-token threshold above which the Data Agent's compaction middleware summarizes older conversation history into a compact message. The check is made each turn as real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; when this predicted prompt size reaches the configured value, older messages are replaced by a single summary message while the most recent exchanges are kept verbatim. Set to a very large value (e.g., 9223372036854775807) to effectively disable compaction. | long | 1.12.0 | | kyuubi.engine.data.agent.extra.classpath | <undefined> | The extra classpath for the Data Agent engine, for configuring the location of the LLM SDK and etc. | string | 1.12.0 | | kyuubi.engine.data.agent.java.options | <undefined> | The extra Java options for the Data Agent engine | string | 1.12.0 | | kyuubi.engine.data.agent.jdbc.url | <undefined> | The JDBC URL for the Data Agent engine to connect to the target database. If not set, the Data Agent will connect back to Kyuubi server via ZooKeeper service discovery. 
| string | 1.12.0 | diff --git a/externals/kyuubi-data-agent-engine/pom.xml b/externals/kyuubi-data-agent-engine/pom.xml index c34d049360c..43a5008dda9 100644 --- a/externals/kyuubi-data-agent-engine/pom.xml +++ b/externals/kyuubi-data-agent-engine/pom.xml @@ -50,45 +50,63 @@ ${project.version} + com.openai openai-java + ${openai.sdk.version} + com.github.victools jsonschema-generator + ${victools.jsonschema.version} - com.github.victools jsonschema-module-jackson + ${victools.jsonschema.version} - + - org.apache.kyuubi - kyuubi-common_${scala.binary.version} - ${project.version} - test-jar - test + org.xerial + sqlite-jdbc + ${sqlite.version} + - org.xerial - sqlite-jdbc + com.mysql + mysql-connector-j test + - org.testcontainers - testcontainers-mysql + io.trino + trino-jdbc + + + + + com.zaxxer + HikariCP + + + + + org.apache.kyuubi + kyuubi-common_${scala.binary.version} + ${project.version} + test-jar test - com.mysql - mysql-connector-j + org.testcontainers + testcontainers-mysql test diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java index c3be1dad61a..c771ad222aa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialect.java @@ -17,6 +17,12 @@ package org.apache.kyuubi.engine.dataagent.datasource; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SparkDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.SqliteDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.TrinoDialect; + /** * SQL dialect 
abstraction for datasource-specific SQL generation. * diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java similarity index 92% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java index 3ea22ed54e3..d8c4512de03 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/GenericDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/GenericDialect.java @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** * Fallback dialect for JDBC subprotocols that have no dedicated implementation. 
Carries the diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java index 98747ffa30c..350789a6a87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/MysqlDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/MysqlDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** MySQL dialect. Uses backtick quoting for identifiers. 
*/ public final class MysqlDialect implements JdbcDialect { - static final MysqlDialect INSTANCE = new MysqlDialect(); + public static final MysqlDialect INSTANCE = new MysqlDialect(); private MysqlDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java index 3adb43fa398..34e20034bfb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SparkDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SparkDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Spark SQL dialect. Uses backtick quoting for identifiers. 
*/ public final class SparkDialect implements JdbcDialect { - static final SparkDialect INSTANCE = new SparkDialect(); + public static final SparkDialect INSTANCE = new SparkDialect(); private SparkDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java index a53255a9c67..eb98ca8edfa 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/SqliteDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/SqliteDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** SQLite dialect. Uses double-quote quoting for identifiers. 
*/ public final class SqliteDialect implements JdbcDialect { - static final SqliteDialect INSTANCE = new SqliteDialect(); + public static final SqliteDialect INSTANCE = new SqliteDialect(); private SqliteDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java similarity index 85% rename from externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java rename to externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java index edacf2f87e2..75fbd4bb242 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/TrinoDialect.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/datasource/dialect/TrinoDialect.java @@ -15,12 +15,14 @@ * limitations under the License. */ -package org.apache.kyuubi.engine.dataagent.datasource; +package org.apache.kyuubi.engine.dataagent.datasource.dialect; + +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; /** Trino SQL dialect. Uses double-quote quoting for identifiers. 
*/ public final class TrinoDialect implements JdbcDialect { - static final TrinoDialect INSTANCE = new TrinoDialect(); + public static final TrinoDialect INSTANCE = new TrinoDialect(); private TrinoDialect() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java index f4e40b2fae8..26ad8be77fb 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/ProviderRunRequest.java @@ -17,13 +17,23 @@ package org.apache.kyuubi.engine.dataagent.provider; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * User-facing request parameters for a provider-level agent invocation. Only contains fields from * the caller (question, model override, etc.). Adding new per-request options does not require * changing the {@link DataAgentProvider} interface. + * + *

The approval mode is accepted as a raw string (natural for config-driven callers) and parsed + * into {@link ApprovalMode} by {@link #getApprovalMode()}. Unrecognised values fall back to {@link + * ApprovalMode#NORMAL} with a warning. */ public class ProviderRunRequest { + private static final Logger LOG = LoggerFactory.getLogger(ProviderRunRequest.class); + private final String question; private String modelName; private String approvalMode; @@ -45,8 +55,20 @@ public ProviderRunRequest modelName(String modelName) { return this; } - public String getApprovalMode() { - return approvalMode; + /** + * Resolved approval mode. Returns {@link ApprovalMode#NORMAL} when the caller did not set one or + * supplied an unknown value. + */ + public ApprovalMode getApprovalMode() { + if (approvalMode == null || approvalMode.isEmpty()) { + return ApprovalMode.NORMAL; + } + try { + return ApprovalMode.valueOf(approvalMode.toUpperCase()); + } catch (IllegalArgumentException e) { + LOG.warn("Unknown approval mode '{}', using default NORMAL", approvalMode); + return ApprovalMode.NORMAL; + } } public ProviderRunRequest approvalMode(String approvalMode) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java new file mode 100644 index 00000000000..bcd647b9326 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/provider/openai/OpenAiProvider.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.provider.openai; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.zaxxer.hikari.HikariDataSource; +import java.time.Duration; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.KyuubiReservedKeys; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.AgentInvocation; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.ReactAgent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; 
+import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An OpenAI-compatible provider that wires up the full ReactAgent with streaming LLM, tools, and + * middleware pipeline. Uses the official OpenAI Java SDK. + * + *

The ReactAgent, DataSource, and ToolRegistry are shared across all sessions within this engine + * instance. Each session only maintains its own {@link ConversationMemory}. This works because each + * engine is bound to one user + one datasource, so all sessions within the engine naturally share + * the same data connection. + */ +public class OpenAiProvider implements DataAgentProvider { + + private static final Logger LOG = LoggerFactory.getLogger(OpenAiProvider.class); + + private final ReactAgent agent; + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + private final OpenAIClient client; + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + public OpenAiProvider(KyuubiConf conf) { + String apiKey = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_KEY()); + String baseUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_API_URL()); + String modelName = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_LLM_MODEL()); + + int maxIterations = ConfUtils.intConf(conf, KyuubiConf.ENGINE_DATA_AGENT_MAX_ITERATIONS()); + long compactionTriggerTokens = + ConfUtils.longConf(conf, KyuubiConf.ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS()); + int queryTimeoutSeconds = + (int) ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_QUERY_TIMEOUT()); + long toolCallTimeoutSeconds = + ConfUtils.millisAsSeconds(conf, KyuubiConf.ENGINE_DATA_AGENT_TOOL_CALL_TIMEOUT()); + + this.client = + OpenAIOkHttpClient.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .maxRetries(3) + .timeout(Duration.ofSeconds(180)) + .build(); + + this.toolRegistry = new ToolRegistry(toolCallTimeoutSeconds); + + SystemPromptBuilder promptBuilder = SystemPromptBuilder.create(); + this.dataSource = attachJdbcDataSource(conf, toolRegistry, promptBuilder, queryTimeoutSeconds); + + this.agent = + ReactAgent.builder() + .client(client) + .modelName(modelName) + .toolRegistry(toolRegistry) + .addMiddleware(new 
ToolResultOffloadMiddleware()) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new CompactionMiddleware(client, modelName, compactionTriggerTokens)) + .addMiddleware(new ApprovalMiddleware()) + .maxIterations(maxIterations) + .systemPrompt(promptBuilder.build()) + .build(); + } + + /** + * Register JDBC-backed SQL tools if a JDBC URL is configured. Returns the created {@link + * DataSource} so the provider can close it on shutdown, or {@code null} when no JDBC is wired. + */ + private static DataSource attachJdbcDataSource( + KyuubiConf conf, + ToolRegistry registry, + SystemPromptBuilder promptBuilder, + int queryTimeoutSeconds) { + String jdbcUrl = ConfUtils.optionalString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + if (jdbcUrl == null) { + return null; + } + LOG.info("Data Agent JDBC URL configured ({})", jdbcUrl.replaceAll("//.*@", "//@")); + + String sessionUser = + ConfUtils.optionalString(conf, KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY()); + + DataSource ds = DataSourceFactory.create(jdbcUrl, sessionUser); + registry.register(new RunSelectQueryTool(ds, queryTimeoutSeconds)); + registry.register(new RunMutationQueryTool(ds, queryTimeoutSeconds)); + promptBuilder.datasource(JdbcDialect.fromUrl(jdbcUrl).datasourceName()); + return ds; + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new ConversationMemory()); + LOG.info("Opened Data Agent session {} for user {}", sessionId, user); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + ConversationMemory memory = sessions.get(sessionId); + if (memory == null) { + throw new IllegalStateException("No open Data Agent session for id=" + sessionId); + } + + AgentInvocation invocation = + new AgentInvocation(request.getQuestion()) + .modelName(request.getModelName()) + .approvalMode(request.getApprovalMode()) + .sessionId(sessionId); + agent.run(invocation, memory, onEvent); + } + + @Override + public 
boolean resolveApproval(String requestId, boolean approved) { + return agent.resolveApproval(requestId, approved); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + agent.closeSession(sessionId); + LOG.info("Closed Data Agent session {}", sessionId); + } + + @Override + public void stop() { + agent.stop(); + toolRegistry.close(); + if (dataSource instanceof HikariDataSource) { + ((HikariDataSource) dataSource).close(); + LOG.info("Closed Data Agent connection pool"); + } + client.close(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java new file mode 100644 index 00000000000..0695d556058 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentInvocation.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.Objects; + +/** + * User-facing request parameters for a single agent invocation. 
Only contains fields that come from + * the caller (question, model override, etc.). Framework-level concerns like memory and event + * consumer are separate method parameters. + * + *

Adding new per-request options (e.g. temperature, maxTokens) does not require changing the + * {@code ReactAgent.run()} signature. + */ +public class AgentInvocation { + + private final String userInput; + private String modelName; + private ApprovalMode approvalMode = ApprovalMode.NORMAL; + private String sessionId; + + public AgentInvocation(String userInput) { + this.userInput = Objects.requireNonNull(userInput, "userInput must not be null"); + } + + public String getUserInput() { + return userInput; + } + + public String getModelName() { + return modelName; + } + + public AgentInvocation modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public AgentInvocation approvalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + return this; + } + + public String getSessionId() { + return sessionId; + } + + /** Upstream session id, propagated into {@link AgentRunContext#getSessionId()}. */ + public AgentInvocation sessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java new file mode 100644 index 00000000000..e7c92df8033 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/AgentRunContext.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; + +/** + * Mutable context passed through the middleware pipeline and agent loop. Tracks the current state + * of agent execution including iteration count, token usage, and custom middleware state. + */ +public class AgentRunContext { + + private final ConversationMemory memory; + private final String sessionId; + private Consumer eventEmitter; + private int iteration; + private long promptTokens; + private long completionTokens; + private long totalTokens; + private ApprovalMode approvalMode; + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode) { + this(memory, approvalMode, null); + } + + public AgentRunContext(ConversationMemory memory, ApprovalMode approvalMode, String sessionId) { + this.memory = memory; + this.iteration = 0; + this.approvalMode = approvalMode; + this.sessionId = sessionId; + } + + public ConversationMemory getMemory() { + return memory; + } + + /** + * The upstream session identifier this run belongs to. Threaded down from {@code + * DataAgentProvider.run(sessionId, ...)}. May be {@code null} in unit tests that do not exercise + * session-scoped middleware. 
+ */ + public String getSessionId() { + return sessionId; + } + + public int getIteration() { + return iteration; + } + + public void setIteration(int iteration) { + this.iteration = iteration; + } + + public long getPromptTokens() { + return promptTokens; + } + + public long getCompletionTokens() { + return completionTokens; + } + + public long getTotalTokens() { + return totalTokens; + } + + /** + * Record one LLM call's usage. Updates both the per-run counters on this context and the + * session-level cumulative on the underlying {@link ConversationMemory}, so middlewares that need + * a session-wide picture can read it directly from memory without keeping their own bookkeeping. + */ + public void addTokenUsage(long prompt, long completion, long total) { + this.promptTokens += prompt; + this.completionTokens += completion; + this.totalTokens += total; + memory.addCumulativeTokens(prompt, completion, total); + } + + public ApprovalMode getApprovalMode() { + return approvalMode; + } + + public void setApprovalMode(ApprovalMode approvalMode) { + this.approvalMode = approvalMode; + } + + public void setEventEmitter(Consumer eventEmitter) { + this.eventEmitter = eventEmitter; + } + + /** Emit an event through the agent's event pipeline. Available for use by middlewares. */ + public void emit(AgentEvent event) { + if (eventEmitter != null) { + eventEmitter.accept(event); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java new file mode 100644 index 00000000000..57bc20bc2bb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ApprovalMode.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +/** Approval modes for tool execution in the Data Agent engine. */ +public enum ApprovalMode { + /** All tools require explicit user approval. */ + STRICT, + /** Only non-readonly tools require approval. */ + NORMAL, + /** All tools are auto-approved. */ + AUTO_APPROVE +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java new file mode 100644 index 00000000000..0bfae26ec27 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemory.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Manages conversation history for a Data Agent session. Ensures tool result messages are never + * orphaned from their corresponding AI messages. + * + *

Each instance is session-scoped and accessed sequentially within a single ReAct loop — no + * synchronization is needed. Cross-session concurrency is handled by the provider's session map. + */ +public class ConversationMemory { + + /** + * System prompt prepended to every LLM call built from this memory. Rebuilt by the provider on + * each invocation (datasource/tool metadata can change between turns), so it lives outside the + * {@link #messages} list rather than being inserted as the first entry. May be {@code null} until + * the first {@link #setSystemPrompt} call — {@link #buildLlmMessages()} will simply omit the + * system slot in that case. + */ + private String systemPrompt; + + /** + * The raw content of the most recent user message added via {@link #addUserMessage}. Cached + * separately from {@link #messages} so middleware and callers can recover the current turn's + * question after compaction rewrites history. Not used by the LLM call path itself. + */ + private String lastUserInput; + + /** + * The ordered conversation history: user / assistant / tool messages, in the order the LLM will + * see them. The system prompt is intentionally NOT stored here (see {@link #systemPrompt}). + * Mutated in place by {@link #replaceHistory} during compaction; otherwise append-only. + */ + private final List messages = new ArrayList<>(); + + /** + * Session-level running total of {@code prompt_tokens} reported by every LLM call on this + * conversation (across ReAct turns). Intended for billing, quota, and observability — not used by + * any runtime decision. Updated via {@link #addCumulativeTokens}. + */ + private long cumulativePromptTokens; + + /** + * Session-level running total of {@code completion_tokens}. See {@link #cumulativePromptTokens}. 
+ */ + private long cumulativeCompletionTokens; + + /** + * Session-level running total of {@code total_tokens} (prompt + completion as reported by the + * provider — not necessarily the sum of the two counters above, since providers may count + * cached/reasoning tokens differently). See {@link #cumulativePromptTokens}. + */ + private long cumulativeTotalTokens; + + /** + * The {@code total_tokens} reported by the single most recent LLM call, or {@code 0} if no call + * has completed yet. Distinct from the cumulative counters: this is a snapshot, overwritten every + * call. Used by {@link + * org.apache.kyuubi.engine.dataagent.runtime.middleware.CompactionMiddleware} to estimate the + * next prompt size (the last response becomes part of the next prompt, so the next call's prompt + * is at least {@code lastTotalTokens}). Persists across ReAct turns until the next call + * overwrites it. + */ + private long lastTotalTokens; + + public ConversationMemory() {} + + public String getSystemPrompt() { + return systemPrompt; + } + + public void setSystemPrompt(String prompt) { + this.systemPrompt = prompt; + } + + public void addUserMessage(String content) { + this.lastUserInput = content; + messages.add( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(content).build())); + } + + public String getLastUserInput() { + return lastUserInput; + } + + public void addAssistantMessage(ChatCompletionAssistantMessageParam message) { + messages.add(ChatCompletionMessageParam.ofAssistant(message)); + } + + public void addToolResult(String toolCallId, String content) { + messages.add( + ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder() + .toolCallId(toolCallId) + .content(content) + .build())); + } + + /** + * Build the full message list for LLM API invocation: [system prompt] + conversation history. + * + *

No windowing is applied — callers are responsible for managing context length (e.g. via a + * token-based truncation strategy). + * + * @see #getHistory() for history-only access without system prompt + */ + public List buildLlmMessages() { + List result = new ArrayList<>(messages.size() + 1); + if (systemPrompt != null) { + result.add( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())); + } + result.addAll(messages); + return Collections.unmodifiableList(result); + } + + /** + * Returns the conversation history (user, assistant, tool messages) without the system prompt. + * Useful for middleware that needs to inspect or compact history. + */ + public List getHistory() { + return Collections.unmodifiableList(new ArrayList<>(messages)); + } + + /** + * Replace the conversation history with a compacted version. Useful for context-length management + * strategies (e.g., summarizing older messages). + * + *

Also clears {@link #lastTotalTokens}: the prior snapshot referred to a prompt whose bulk we + * just discarded, so it no longer describes anything in memory. Leaving it stale would keep the + * compaction trigger armed until the next successful LLM call overwrites it — fine on the happy + * path, but if that call fails the next turn would re-enter compaction against already-compacted + * history. Zeroing means "unknown, wait for the next real usage report". Cumulative totals are + * intentionally preserved (session-level accounting, must not regress on internal compaction). + */ + public void replaceHistory(List compacted) { + messages.clear(); + messages.addAll(compacted); + this.lastTotalTokens = 0; + } + + public void clear() { + messages.clear(); + } + + public int size() { + return messages.size(); + } + + public long getCumulativePromptTokens() { + return cumulativePromptTokens; + } + + public long getCumulativeCompletionTokens() { + return cumulativeCompletionTokens; + } + + public long getCumulativeTotalTokens() { + return cumulativeTotalTokens; + } + + public long getLastTotalTokens() { + return lastTotalTokens; + } + + /** Add one LLM call's usage to the session cumulative. Intended for {@link AgentRunContext}. 
*/ + public void addCumulativeTokens(long prompt, long completion, long total) { + this.cumulativePromptTokens += prompt; + this.cumulativeCompletionTokens += completion; + this.cumulativeTotalTokens += total; + this.lastTotalTokens = total; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java new file mode 100644 index 00000000000..520cd963ef8 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ReactAgent.java @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.core.http.StreamResponse; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionChunk; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionStreamOptions; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.AgentMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** 
+ * ReAct (Reasoning + Acting) agent loop using the OpenAI official Java SDK. Iterates through LLM + * reasoning, tool execution, and result verification until the agent produces a final answer or + * hits the iteration limit. + * + *

Emits {@link AgentEvent}s via the provided consumer for real-time token-level streaming. + */ +public class ReactAgent { + + private static final Logger LOG = LoggerFactory.getLogger(ReactAgent.class); + private static final ObjectMapper JSON = new ObjectMapper(); + + private final OpenAIClient client; + private final String defaultModelName; + private final ToolRegistry toolRegistry; + private final List middlewares; + private final ApprovalMiddleware approvalMiddleware; + private final int maxIterations; + private final String systemPrompt; + + public ReactAgent( + OpenAIClient client, + String modelName, + ToolRegistry toolRegistry, + List middlewares, + int maxIterations, + String systemPrompt) { + this.client = client; + this.defaultModelName = modelName; + this.toolRegistry = toolRegistry; + this.middlewares = middlewares; + this.approvalMiddleware = findApprovalMiddleware(middlewares); + this.maxIterations = maxIterations; + this.systemPrompt = systemPrompt; + } + + private static ApprovalMiddleware findApprovalMiddleware(List middlewares) { + for (AgentMiddleware mw : middlewares) { + if (mw instanceof ApprovalMiddleware) return (ApprovalMiddleware) mw; + } + return null; + } + + /** Resolve a pending approval request. Returns false if no pending request matches. */ + public boolean resolveApproval(String requestId, boolean approved) { + if (approvalMiddleware == null) return false; + return approvalMiddleware.resolve(requestId, approved); + } + + /** Fan out session-close to every middleware. Errors in one middleware don't block others. */ + public void closeSession(String sessionId) { + for (AgentMiddleware mw : middlewares) { + try { + mw.onSessionClose(sessionId); + } catch (Exception e) { + LOG.warn("Middleware onSessionClose error", e); + } + } + } + + /** Fan out engine-stop to every middleware. Errors in one middleware don't block others. 
*/ + public void stop() { + for (AgentMiddleware mw : middlewares) { + try { + mw.onStop(); + } catch (Exception e) { + LOG.warn("Middleware onStop error", e); + } + } + } + + /** + * Run the ReAct loop for the given request. + * + * @param request user-facing parameters (question, model override, etc.) + * @param memory the conversation memory (may contain prior context) + * @param eventConsumer callback for each agent event (token-level streaming) + */ + public void run( + AgentInvocation request, ConversationMemory memory, Consumer eventConsumer) { + String userInput = request.getUserInput(); + ApprovalMode approvalMode = request.getApprovalMode(); + String modelNameOverride = request.getModelName(); + + String effectiveModel = + (modelNameOverride != null && !modelNameOverride.isEmpty()) + ? modelNameOverride + : defaultModelName; + + // System prompt is immutable for the lifetime of this agent — only set it on first run + // to avoid redundant overwrites on multi-turn conversations. + if (memory.getSystemPrompt() == null) { + memory.setSystemPrompt(systemPrompt); + } + memory.addUserMessage(userInput); + + AgentRunContext ctx = new AgentRunContext(memory, approvalMode, request.getSessionId()); + ctx.setEventEmitter(event -> emit(ctx, event, eventConsumer)); + dispatchAgentStart(ctx); + emit(ctx, new AgentStart(), eventConsumer); + + try { + for (int step = 1; step <= maxIterations; step++) { + ctx.setIteration(step); + emit(ctx, new StepStart(step), eventConsumer); + + List messages = + resolveMessagesForCall(ctx, memory.buildLlmMessages(), eventConsumer); + if (messages == null) { + // Middleware asked us to skip — AgentError + AgentFinish have already been emitted. 
+ return; + } + + StreamResult result = streamLlmResponse(ctx, messages, effectiveModel, eventConsumer); + if (result.isEmpty()) { + emit(ctx, new AgentError("LLM returned empty response"), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + if (!result.content.isEmpty()) { + emit(ctx, new ContentComplete(result.content), eventConsumer); + } + ChatCompletionAssistantMessageParam assistantMsg = buildAssistantMessage(result); + memory.addAssistantMessage(assistantMsg); + dispatchAfterLlmCall(ctx, assistantMsg); + + if (result.toolCalls == null || result.toolCalls.isEmpty()) { + // No tool calls — agent is done. + emit(ctx, new StepEnd(step), eventConsumer); + emitFinish(ctx, eventConsumer); + return; + } + + executeToolCalls(ctx, memory, result.toolCalls, eventConsumer); + emit(ctx, new StepEnd(step), eventConsumer); + } + + emit( + ctx, new AgentError("Reached maximum iterations (" + maxIterations + ")"), eventConsumer); + emitFinish(ctx, eventConsumer); + + } catch (Exception e) { + LOG.error("Agent run failed", e); + emit( + ctx, new AgentError(e.getClass().getSimpleName() + ": " + e.getMessage()), eventConsumer); + emitFinish(ctx, eventConsumer); + } finally { + dispatchAgentFinish(ctx); + } + } + + /** + * Run {@code beforeLlmCall} middleware against {@code messages}. Returns the messages to send, + * possibly rewritten by middleware, or {@code null} if middleware aborted the call (in which case + * this method has already emitted the terminal events). 
+ */ + private List resolveMessagesForCall( + AgentRunContext ctx, + List messages, + Consumer eventConsumer) { + AgentMiddleware.LlmCallAction action = dispatchBeforeLlmCall(ctx, messages); + if (action instanceof AgentMiddleware.LlmSkip) { + String reason = ((AgentMiddleware.LlmSkip) action).reason(); + emit(ctx, new AgentError("LLM call skipped by middleware: " + reason), eventConsumer); + emitFinish(ctx, eventConsumer); + return null; + } + if (action instanceof AgentMiddleware.LlmModifyMessages) { + return ((AgentMiddleware.LlmModifyMessages) action).messages(); + } + return messages; + } + + private static ChatCompletionAssistantMessageParam buildAssistantMessage(StreamResult result) { + ChatCompletionAssistantMessageParam.Builder b = ChatCompletionAssistantMessageParam.builder(); + if (!result.content.isEmpty()) { + b.content(result.content); + } + if (result.toolCalls != null && !result.toolCalls.isEmpty()) { + b.toolCalls(result.toolCalls); + } + return b.build(); + } + + /** + * Execute the assistant's tool calls in 3 phases: + * + *

    + *
  1. Serial: run {@code beforeToolCall} middleware, emit {@link ToolCall} events, and collect + * the calls that survived approval. + *
  2. Concurrent: fan out to {@link ToolRegistry#submitTool}, which always returns a future + * that completes normally — timeouts and execution errors surface as error strings. + *
  3. Serial: join futures in order, run {@code afterToolCall}, and record results to memory. + *
+ */ + private void executeToolCalls( + AgentRunContext ctx, + ConversationMemory memory, + List toolCalls, + Consumer eventConsumer) { + List approved = new ArrayList<>(); + for (ChatCompletionMessageToolCall toolCall : toolCalls) { + ChatCompletionMessageFunctionToolCall fnCall = toolCall.asFunction(); + String toolName = fnCall.function().name(); + Map toolArgs; + try { + toolArgs = parseToolArgs(fnCall.function().arguments()); + } catch (IllegalArgumentException e) { + // Malformed JSON from the LLM: record an error tool_result (preserves the + // assistant/tool_result pairing the next API call needs) and let the loop self-correct. + String err = "Tool call failed: " + e.getMessage(); + memory.addToolResult(fnCall.id(), err); + emit(ctx, new ToolResult(fnCall.id(), toolName, err, true), eventConsumer); + continue; + } + + AgentMiddleware.ToolCallDenial denial = + dispatchBeforeToolCall(ctx, fnCall.id(), toolName, toolArgs); + if (denial != null) { + String denied = "Tool call denied: " + denial.reason(); + memory.addToolResult(fnCall.id(), denied); + emit(ctx, new ToolResult(fnCall.id(), toolName, denied, true), eventConsumer); + continue; + } + + emit(ctx, new ToolCall(fnCall.id(), toolName, toolArgs), eventConsumer); + approved.add(new ToolCallEntry(fnCall, toolName, toolArgs)); + } + + ToolContext toolCtx = new ToolContext(ctx.getSessionId()); + List> futures = new ArrayList<>(approved.size()); + for (ToolCallEntry entry : approved) { + futures.add( + toolRegistry.submitTool(entry.toolName, entry.fnCall.function().arguments(), toolCtx)); + } + + for (int i = 0; i < approved.size(); i++) { + ToolCallEntry entry = approved.get(i); + String output = futures.get(i).join(); + String modified = dispatchAfterToolCall(ctx, entry.toolName, entry.toolArgs, output); + if (modified != null) { + output = modified; + } + memory.addToolResult(entry.fnCall.id(), output); + emit(ctx, new ToolResult(entry.fnCall.id(), entry.toolName, output, false), eventConsumer); + } + } + 
+ /** Result of a streaming LLM call, assembled from chunks. */ + private static class StreamResult { + final String content; + final List toolCalls; + + StreamResult(String content, List toolCalls) { + this.content = content; + this.toolCalls = toolCalls; + } + + boolean isEmpty() { + return content.isEmpty() && (toolCalls == null || toolCalls.isEmpty()); + } + } + + /** Holds an approved tool call's parsed metadata for the 3-phase execution pipeline. */ + private static class ToolCallEntry { + final ChatCompletionMessageFunctionToolCall fnCall; + final String toolName; + final Map toolArgs; + + ToolCallEntry( + ChatCompletionMessageFunctionToolCall fnCall, + String toolName, + Map toolArgs) { + this.fnCall = fnCall; + this.toolName = toolName; + this.toolArgs = toolArgs; + } + } + + /** + * Stream LLM response, emitting ContentDelta for each text chunk. Assembles tool calls directly + * from streamed chunks — no non-streaming fallback. Exceptions propagate to the top-level handler + * in {@link #run}. + */ + private StreamResult streamLlmResponse( + AgentRunContext ctx, + List messages, + String effectiveModel, + Consumer eventConsumer) { + ChatCompletionCreateParams.Builder paramsBuilder = + ChatCompletionCreateParams.builder() + .model(effectiveModel) + .streamOptions(ChatCompletionStreamOptions.builder().includeUsage(true).build()); + for (ChatCompletionMessageParam msg : messages) { + paramsBuilder.addMessage(msg); + } + toolRegistry.addToolsTo(paramsBuilder); + + LOG.info("LLM request: model={}", effectiveModel); + StreamAccumulator acc = new StreamAccumulator(); + try (StreamResponse stream = + client.chat().completions().createStreaming(paramsBuilder.build())) { + stream.stream().forEach(chunk -> consumeChunk(ctx, chunk, acc, eventConsumer)); + } + return new StreamResult(acc.content.toString(), acc.buildToolCalls()); + } + + /** Fold one streaming chunk into {@code acc}, emitting per-token {@link ContentDelta}s. 
*/ + private void consumeChunk( + AgentRunContext ctx, + ChatCompletionChunk chunk, + StreamAccumulator acc, + Consumer eventConsumer) { + if (!acc.serverModelLogged) { + LOG.info("LLM response: server-echoed model={}", chunk.model()); + acc.serverModelLogged = true; + } + chunk + .usage() + .ifPresent(u -> ctx.addTokenUsage(u.promptTokens(), u.completionTokens(), u.totalTokens())); + + for (ChatCompletionChunk.Choice c : chunk.choices()) { + c.delta() + .content() + .ifPresent( + text -> { + acc.content.append(text); + emit(ctx, new ContentDelta(text), eventConsumer); + }); + c.delta().toolCalls().ifPresent(acc::mergeToolCallDeltas); + } + } + + /** + * Mutable accumulator for a single streaming LLM turn. Tool call fields are keyed by the chunk's + * {@code index} because provider SDKs may deliver a single logical call across multiple chunks + * and only surface the {@code id}/{@code name} on the first one. + */ + private static final class StreamAccumulator { + final StringBuilder content = new StringBuilder(); + final Map toolCallIds = new HashMap<>(); + final Map toolCallNames = new HashMap<>(); + final Map toolCallArgs = new HashMap<>(); + boolean serverModelLogged = false; + + void mergeToolCallDeltas(List deltas) { + for (ChatCompletionChunk.Choice.Delta.ToolCall tc : deltas) { + int idx = (int) tc.index(); + tc.id().ifPresent(id -> toolCallIds.put(idx, id)); + tc.function() + .ifPresent( + fn -> { + fn.name().ifPresent(name -> toolCallNames.put(idx, name)); + fn.arguments() + .ifPresent( + args -> + toolCallArgs + .computeIfAbsent(idx, k -> new StringBuilder()) + .append(args)); + }); + } + } + + /** + * Materialize accumulated deltas into SDK tool-call objects. Returns {@code null} (not an empty + * list) if no tool calls were seen, matching the existing {@link StreamResult} contract. 
+ */ + List buildToolCalls() { + if (toolCallIds.isEmpty()) return null; + List out = new ArrayList<>(toolCallIds.size()); + for (Map.Entry e : toolCallIds.entrySet()) { + int idx = e.getKey(); + String id = (e.getValue() == null || e.getValue().isEmpty()) ? synthId() : e.getValue(); + String args = toolCallArgs.containsKey(idx) ? toolCallArgs.get(idx).toString() : "{}"; + out.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(id) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolCallNames.getOrDefault(idx, "")) + .arguments(args) + .build()) + .build())); + } + return out; + } + + /** + * Synthesize an id for tool calls whose id never arrived on the stream (some OpenAI-compatible + * providers omit it). The id has to be stable within a turn and unique across turns so the + * assistant/tool_result pairing downstream holds. + */ + private static String synthId() { + return "local_" + java.util.UUID.randomUUID().toString().replace("-", "").substring(0, 24); + } + } + + private static Map parseToolArgs(String json) { + if (json == null || json.isEmpty()) { + return new HashMap<>(); + } + try { + return JSON.readValue(json, new TypeReference>() {}); + } catch (java.io.IOException e) { + throw new IllegalArgumentException("Malformed tool-call arguments JSON: " + json, e); + } + } + + // --- Middleware dispatch methods --- + // + // Middlewares are internal framework code. If one throws, the agent run fails via the + // top-level catch in run() — we do not wrap individual dispatch calls in try/catch. 
+ + private void emitFinish(AgentRunContext ctx, Consumer eventConsumer) { + emit( + ctx, + new AgentFinish( + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()), + eventConsumer); + } + + private void emit(AgentRunContext ctx, AgentEvent event, Consumer consumer) { + AgentEvent filtered = event; + for (AgentMiddleware mw : middlewares) { + filtered = mw.onEvent(ctx, filtered); + if (filtered == null) return; + } + consumer.accept(filtered); + } + + private void dispatchAgentStart(AgentRunContext ctx) { + for (AgentMiddleware mw : middlewares) { + mw.onAgentStart(ctx); + } + } + + private void dispatchAgentFinish(AgentRunContext ctx) { + // Runs even when the agent body threw, so swallow here to ensure every middleware's cleanup + // gets a chance to run; otherwise we'd leak session state in later middlewares. + for (int i = middlewares.size() - 1; i >= 0; i--) { + try { + middlewares.get(i).onAgentFinish(ctx); + } catch (Exception e) { + LOG.warn("Middleware onAgentFinish error", e); + } + } + } + + private AgentMiddleware.LlmCallAction dispatchBeforeLlmCall( + AgentRunContext ctx, List messages) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, messages); + if (action != null) return action; + } + return null; + } + + private void dispatchAfterLlmCall( + AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + for (int i = middlewares.size() - 1; i >= 0; i--) { + middlewares.get(i).afterLlmCall(ctx, response); + } + } + + private AgentMiddleware.ToolCallDenial dispatchBeforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + for (AgentMiddleware mw : middlewares) { + AgentMiddleware.ToolCallDenial denial = + mw.beforeToolCall(ctx, toolCallId, toolName, toolArgs); + if (denial != null) return denial; + } + return null; + } + + private String dispatchAfterToolCall( + AgentRunContext ctx, String toolName, Map 
toolArgs, String result) { + String modified = null; + for (int i = middlewares.size() - 1; i >= 0; i--) { + String mwResult = + middlewares + .get(i) + .afterToolCall(ctx, toolName, toolArgs, modified != null ? modified : result); + if (mwResult != null) { + modified = mwResult; + } + } + return modified; + } + + // --- Builder --- + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private OpenAIClient client; + private String modelName; + private ToolRegistry toolRegistry = new ToolRegistry(ToolRegistry.DEFAULT_TIMEOUT_SECONDS); + private final List middlewares = new ArrayList<>(); + private int maxIterations = 20; + private String systemPrompt; + + public Builder client(OpenAIClient client) { + this.client = client; + return this; + } + + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + + public Builder toolRegistry(ToolRegistry toolRegistry) { + this.toolRegistry = toolRegistry; + return this; + } + + public Builder addMiddleware(AgentMiddleware middleware) { + this.middlewares.add(middleware); + return this; + } + + public Builder maxIterations(int maxIterations) { + if (maxIterations < 1) { + throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations); + } + this.maxIterations = maxIterations; + return this; + } + + public Builder systemPrompt(String systemPrompt) { + this.systemPrompt = systemPrompt; + return this; + } + + public ReactAgent build() { + if (client == null) throw new IllegalStateException("client is required"); + if (modelName == null) throw new IllegalStateException("modelName is required"); + if (toolRegistry == null) throw new IllegalStateException("toolRegistry is required"); + for (AgentMiddleware mw : middlewares) { + mw.onRegister(toolRegistry); + } + return new ReactAgent( + client, modelName, toolRegistry, middlewares, maxIterations, systemPrompt); + } + } +} diff --git 
a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java new file mode 100644 index 00000000000..d4a8a97e97d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStore.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; +import java.util.stream.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Per-session temp-file store for large tool outputs that the input gate has offloaded from + * conversation history. + * + *

Layout: {@code //tool_.txt}, where {@code } is a + * per-engine randomly-named temp directory (see {@link #create()}). UTF-8 text. + * + *

Isolation: every {@code read}/{@code grep} call requires the current session id and + * validates that the caller-supplied path resolves under {@code /}, not merely + * under {@code }. A session cannot read another session's offloaded output even if it somehow + * obtained the absolute path. The per-engine random root additionally isolates different engine + * processes sharing the same host. + * + *

Failures (traversal, missing file, session-id mismatch, IO) are reported as error strings + * rather than thrown — the tool must keep the agent loop alive. + */ +public class ToolOutputStore implements AutoCloseable { + + private static final Logger LOG = LoggerFactory.getLogger(ToolOutputStore.class); + private static final String ROOT_PREFIX = "kyuubi-data-agent-"; + + private final Path root; + + /** + * Create a store backed by a fresh, per-engine random temp directory under {@code + * java.io.tmpdir}. + */ + public static ToolOutputStore create() { + try { + return new ToolOutputStore(Files.createTempDirectory(ROOT_PREFIX)); + } catch (IOException e) { + throw new IllegalStateException("Failed to create ToolOutputStore temp root", e); + } + } + + private ToolOutputStore(Path root) { + try { + Files.createDirectories(root); + this.root = root.toRealPath(); + } catch (IOException e) { + throw new IllegalStateException("Failed to initialize ToolOutputStore root: " + root, e); + } + } + + public Path getRoot() { + return root; + } + + /** Write {@code content} to {@code //tool_.txt}. */ + public Path write(String sessionId, String toolCallId, String content) throws IOException { + Path dir = root.resolve(safeSegment(sessionId)); + Files.createDirectories(dir); + Path file = dir.resolve("tool_" + safeSegment(toolCallId) + ".txt"); + Files.write(file, content.getBytes(StandardCharsets.UTF_8)); + return file; + } + + /** + * Read a line window. Returns a human-readable block including a 1-based {@code [lines X-Y of Z + * total]} header, or an error string on traversal / IO failure / cross-session access. 
+ */ + public String read(String sessionId, String pathStr, long offset, int limit) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (offset < 0) offset = 0; + if (limit <= 0) limit = 1; + + List taken = new ArrayList<>(limit); + long totalLines = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + if (totalLines >= offset && taken.size() < limit) { + taken.add(line); + } + totalLines++; + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + + long fromLine = offset + 1; // 1-based + long toLineExclusive = Math.min(offset + limit, totalLines); + StringBuilder sb = new StringBuilder(); + sb.append("[lines ") + .append(fromLine) + .append("-") + .append(toLineExclusive) + .append(" of ") + .append(totalLines) + .append(" total]\n"); + for (String line : taken) { + sb.append(line).append('\n'); + } + return sb.toString(); + } + + /** + * Stream-grep the file. Returns at most {@code maxMatches} matches as {@code lineNo:content}, one + * per line; or an error string on traversal / regex / IO failure / cross-session access. 
+ */ + public String grep(String sessionId, String pathStr, String patternStr, int maxMatches) { + Path file = validatePath(sessionId, pathStr); + if (file == null) { + return "Error: path is outside this session's tool-output directory or does not exist: " + + pathStr; + } + if (patternStr == null || patternStr.isEmpty()) { + return "Error: 'pattern' parameter is required."; + } + if (maxMatches <= 0) maxMatches = 50; + + Pattern pattern; + try { + pattern = Pattern.compile(patternStr); + } catch (PatternSyntaxException e) { + return "Error: invalid regex pattern: " + e.getMessage(); + } + + StringBuilder sb = new StringBuilder(); + int matches = 0; + long lineNo = 0; + try (BufferedReader br = Files.newBufferedReader(file, StandardCharsets.UTF_8)) { + String line; + while ((line = br.readLine()) != null) { + lineNo++; + if (pattern.matcher(line).find()) { + sb.append(lineNo).append(':').append(line).append('\n'); + matches++; + if (matches >= maxMatches) break; + } + } + } catch (IOException e) { + return "Error reading " + pathStr + ": " + e.getMessage(); + } + if (matches == 0) { + return "[no matches for pattern: " + patternStr + "]"; + } + return "[" + matches + " match" + (matches == 1 ? "" : "es") + "]\n" + sb; + } + + /** Recursively delete the session's subtree. Safe to call on missing sessions. */ + public void cleanupSession(String sessionId) { + if (sessionId == null) return; + Path dir = root.resolve(safeSegment(sessionId)); + deleteTree(dir); + } + + /** Delete everything below (and including) the root. Idempotent; safe to call multiple times. 
*/ + @Override + public void close() { + deleteTree(root); + } + + private static void deleteTree(Path dir) { + if (!Files.exists(dir)) return; + try (Stream stream = Files.walk(dir)) { + stream.sorted(Comparator.reverseOrder()).forEach(ToolOutputStore::deleteQuietly); + } catch (IOException e) { + LOG.warn("Failed to clean up dir {}", dir, e); + } + } + + private static void deleteQuietly(Path p) { + try { + Files.deleteIfExists(p); + } catch (IOException e) { + LOG.debug("Failed to delete {}", p, e); + } + } + + /** + * Resolve {@code pathStr} and return it only if (a) it exists as a regular file and (b) the real + * path is under {@code /}. Returns null on any violation — including a null or + * empty session id, since without one we cannot scope the check. + */ + private Path validatePath(String sessionId, String pathStr) { + if (pathStr == null || pathStr.isEmpty()) return null; + if (sessionId == null || sessionId.isEmpty()) return null; + Path sessionRoot = root.resolve(safeSegment(sessionId)); + try { + Path real = Paths.get(pathStr).toRealPath(); + if (!real.startsWith(sessionRoot)) return null; + if (!Files.isRegularFile(real)) return null; + return real; + } catch (IOException | SecurityException e) { + return null; + } + } + + /** Strip anything that could escape a single path segment. 
*/ + private static String safeSegment(String raw) { + if (raw == null || raw.isEmpty()) return "_"; + StringBuilder sb = new StringBuilder(raw.length()); + for (int i = 0; i < raw.length(); i++) { + char c = raw.charAt(i); + if (Character.isLetterOrDigit(c) || c == '-' || c == '_' || c == '.') { + sb.append(c); + } else { + sb.append('_'); + } + } + return sb.toString(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java new file mode 100644 index 00000000000..26eb6024e98 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/Compaction.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.event; + +/** + * Emitted by {@code CompactionMiddleware} after it has summarized a prefix of the conversation + * history and replaced it in memory. 
Purely observational — the LLM call that immediately follows + * uses the already-compacted history, so consumers just see this as a side-channel notice that + * compaction happened. The summary text itself is intentionally not included: it can be large and + * would bloat the event stream; operators who need it can read the middleware log. + */ +public final class Compaction extends AgentEvent { + private final int summarizedCount; + private final int keptCount; + private final long triggerTokens; + private final long observedTokens; + + public Compaction(int summarizedCount, int keptCount, long triggerTokens, long observedTokens) { + super(EventType.COMPACTION); + this.summarizedCount = summarizedCount; + this.keptCount = keptCount; + this.triggerTokens = triggerTokens; + this.observedTokens = observedTokens; + } + + public int summarizedCount() { + return summarizedCount; + } + + public int keptCount() { + return keptCount; + } + + public long triggerTokens() { + return triggerTokens; + } + + public long observedTokens() { + return observedTokens; + } + + @Override + public String toString() { + return "Compaction{summarized=" + + summarizedCount + + ", kept=" + + keptCount + + ", triggerTokens=" + + triggerTokens + + ", observedTokens=" + + observedTokens + + "}"; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java index 937422e2bf5..d58e5de2ee7 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventType.java @@ -21,6 +21,8 @@ * Enumerates the types of events emitted by the ReAct agent loop. 
Each value maps to a * corresponding {@link AgentEvent} subclass and carries an SSE event name used for wire * serialization. + * + * @see org.apache.kyuubi.engine.dataagent.runtime.ReactAgent */ public enum EventType { @@ -51,6 +53,9 @@ public enum EventType { /** The agent requires user approval before executing a tool. */ APPROVAL_REQUEST("approval_request"), + /** The conversation history was compacted by the compaction middleware. */ + COMPACTION("compaction"), + /** The agent has finished its analysis. */ AGENT_FINISH("agent_finish"); diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java new file mode 100644 index 00000000000..c934bb0882a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/AgentMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; + +/** + * Middleware interface for the Data Agent ReAct loop. Middlewares are executed in onion-model + * order: before_* hooks run first-to-last, after_* hooks run last-to-first. + * + *

All hooks have default no-op implementations. Override only what you need. + */ +public interface AgentMiddleware { + + /** + * Called once when the middleware is wired into the agent. Register companion tools that are part + * of the middleware's contract, or capture a reference to the registry for later use. Dispatched + * by {@code ReactAgent.Builder.build} before the agent accepts any requests. + */ + default void onRegister(ToolRegistry registry) {} + + /** Called when the agent starts processing a user query. Runs first-to-last. */ + default void onAgentStart(AgentRunContext ctx) {} + + /** Called when the agent finishes. Runs last-to-first (cleanup order). */ + default void onAgentFinish(AgentRunContext ctx) {} + + /** + * Called before each LLM invocation. Return non-null to skip or modify the LLM call. Runs + * first-to-last. + * + * @return {@code null} to proceed normally, {@link LlmSkip} to abort, or {@link + * LlmModifyMessages} to replace the message list for this call. + */ + default LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + return null; + } + + /** Called after each LLM invocation. Runs last-to-first. */ + default void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) {} + + /** Called before each tool execution. Return non-null to deny the call. Runs first-to-last. */ + default ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + return null; + } + + /** + * Called after each tool execution. Runs last-to-first. + * + *

Returns {@code String} (not {@code void}) so that middlewares can intercept and transform + * the tool result before it is fed back to the LLM — e.g. for data masking, output truncation, or + * injecting metadata. Return {@code null} to keep the original result unchanged; return a + * non-null value to replace it. + */ + default String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + return null; + } + + /** + * Called for every event before it is emitted. Return null to suppress the event. Runs + * first-to-last. + */ + default AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + return event; + } + + /** + * Called when a session is closed. Clean up per-session state (scratch files, pending tasks, + * counters). Idempotent. Dispatched by {@code ReactAgent.closeSession}. + */ + default void onSessionClose(String sessionId) {} + + /** + * Called when the engine is stopping. Release global resources and unblock any threads still + * waiting on this middleware. Dispatched by {@code ReactAgent.stop}. + */ + default void onStop() {} + + /** + * Base type for {@code beforeLlmCall} return values. Subtypes: {@link LlmSkip} to abort the LLM + * call, {@link LlmModifyMessages} to replace the message list for this call. + */ + abstract class LlmCallAction { + private LlmCallAction() {} + } + + /** Returned from {@code beforeLlmCall} to skip the LLM call and abort the agent loop. */ + class LlmSkip extends LlmCallAction { + private final String reason; + + public LlmSkip(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } + + /** + * Returned from {@code beforeLlmCall} to replace the message list for this LLM invocation. The + * agent loop continues normally with the modified messages. 
+ */ + class LlmModifyMessages extends LlmCallAction { + private final List messages; + + public LlmModifyMessages(List messages) { + this.messages = messages; + } + + public List messages() { + return messages; + } + } + + /** Returned from {@code beforeToolCall} to deny a tool call. Non-null means denied. */ + class ToolCallDenial { + private final String reason; + + public ToolCallDenial(String reason) { + this.reason = reason; + } + + public String reason() { + return reason; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java new file mode 100644 index 00000000000..92d25b47b9d --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddleware.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that enforces human-in-the-loop approval for tool calls based on the {@link + * ApprovalMode} and the tool's {@link ToolRiskLevel}. + * + *

When approval is required, an {@link ApprovalRequest} event is emitted to the client via + * {@link AgentRunContext#emit}, and the agent thread blocks until the client responds via {@link + * #resolve} or the timeout expires. + */ +public class ApprovalMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ApprovalMiddleware.class); + + private static final long DEFAULT_TIMEOUT_SECONDS = 300; // 5 minutes + + private final long timeoutSeconds; + private final ConcurrentHashMap> pending = + new ConcurrentHashMap<>(); + private ToolRegistry toolRegistry; + + public ApprovalMiddleware() { + this(DEFAULT_TIMEOUT_SECONDS); + } + + public ApprovalMiddleware(long timeoutSeconds) { + this.timeoutSeconds = timeoutSeconds; + } + + @Override + public void onRegister(ToolRegistry registry) { + this.toolRegistry = registry; + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + ToolRiskLevel riskLevel = toolRegistry.getRiskLevel(toolName); + + if (shouldAutoApprove(ctx.getApprovalMode(), riskLevel)) { + return null; + } + + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + pending.put(requestId, future); + + ctx.emit(new ApprovalRequest(requestId, toolCallId, toolName, toolArgs, riskLevel)); + LOG.info("Approval requested for tool '{}' (requestId={})", toolName, requestId); + + try { + boolean approved = future.get(timeoutSeconds, TimeUnit.SECONDS); + if (!approved) { + LOG.info("Tool '{}' denied by user (requestId={})", toolName, requestId); + return new ToolCallDenial("User denied execution of " + toolName); + } + LOG.info("Tool '{}' approved by user (requestId={})", toolName, requestId); + return null; + } catch (TimeoutException e) { + // Complete the future so that a late resolve() call is a harmless no-op + // instead of completing a dangling future. 
+ future.completeExceptionally(e); + LOG.warn("Approval timed out for tool '{}' (requestId={})", toolName, requestId); + return new ToolCallDenial("Approval timed out after " + timeoutSeconds + "s for " + toolName); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return new ToolCallDenial("Approval interrupted for " + toolName); + } catch (Exception e) { + LOG.error("Unexpected error waiting for approval", e); + return new ToolCallDenial("Approval error: " + e.getMessage()); + } finally { + pending.remove(requestId); + } + } + + /** + * Resolve a pending approval request. Called by the external approval channel (e.g. a Kyuubi + * operation or REST endpoint). + * + * @param requestId the request ID from the {@link ApprovalRequest} event + * @param approved true to approve, false to deny + * @return true if the request was found and resolved, false if not found (already timed out or + * invalid ID) + */ + public boolean resolve(String requestId, boolean approved) { + CompletableFuture future = pending.get(requestId); + if (future != null) { + return future.complete(approved); + } + LOG.warn("No pending approval found for requestId={}", requestId); + return false; + } + + /** + * Cancel all pending approval requests to unblock any waiting agent threads. Invoked as part of + * engine shutdown via {@code ReactAgent.stop}. 
+ */ + @Override + public void onStop() { + InterruptedException ex = new InterruptedException("Session closed"); + pending.forEachKey( + Long.MAX_VALUE, + key -> { + CompletableFuture future = pending.remove(key); + if (future != null) { + future.completeExceptionally(ex); + } + }); + } + + private static boolean shouldAutoApprove(ApprovalMode mode, ToolRiskLevel riskLevel) { + if (mode == ApprovalMode.AUTO_APPROVE) { + return true; + } + if (mode == ApprovalMode.NORMAL && riskLevel == ToolRiskLevel.SAFE) { + return true; + } + // STRICT: all tools require approval + return false; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java new file mode 100644 index 00000000000..acab7ae8b75 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddleware.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.client.OpenAIClient; +import com.openai.models.chat.completions.ChatCompletion; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionSystemMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.Compaction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Middleware that compacts conversation history when the prompt grows large. + * + *

Trigger formula: + * + *

+ *   predicted_this_turn_prompt_tokens
+ *       = last_llm_call_total_tokens            // prompt + completion of the previous call
+ *       + estimate_new_tail(messages)           // chars / 4, for content appended after the
+ *                                               // last assistant message (tool results, new user)
+ * 
+ * + * History (already sent to the LLM) must use real token counts — we read them straight off {@link + * ConversationMemory#getLastTotalTokens()}, which the provider updates after every call. We use + * total (prompt + completion) rather than just prompt because the previous call's completion + * — e.g. an assistant {@code tool_call} message — is already appended to history and will be + * tokenized into the next prompt; the tail estimator starts strictly after the last + * assistant message, so without completion we would miss it. Only content appended since that + * assistant message is estimated, which catches spikes before the next API report arrives. + * + *

Post-compaction persistence: the summary + kept tail replace {@link + * ConversationMemory}'s history via {@link ConversationMemory#replaceHistory}. All subsequent turns + * in this session read the compacted history — we do not re-summarize each turn. Because the next + * LLM call uses the compacted messages, its reported {@code prompt_tokens} will be small, naturally + * preventing retriggering. + * + *

Thread safety / shared instance: Instances of this middleware are shared across all + * sessions inside a provider (see {@code OpenAiProvider} javadoc). All per-session state + * (cumulative and last-call totals) lives on {@link ConversationMemory}, so this middleware itself + * is stateless across sessions and requires no per-session cleanup. + * + *

Tool-call pair invariant: The split never separates an assistant message bearing {@code + * tool_calls} from the {@code tool_result} messages that satisfy those calls — an orphan + * tool_result is rejected by the OpenAI API with HTTP 400. + * + *

Failure handling: summarizer failures propagate out of {@code beforeLlmCall} to the + * agent's top-level catch, which surfaces an {@code AgentError} event. We don't silently skip + * compaction — a broken summarizer is a real problem the operator needs to see. + * + *

Disabling: to effectively turn compaction off, construct with a very large {@code + * triggerPromptTokens} (e.g., {@link Long#MAX_VALUE}). + * + *

TODO: support a separate (cheaper) summarization model distinct from the main agent model. + */ +public class CompactionMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(CompactionMiddleware.class); + + /** Number of recent user turns (and their assistant/tool companions) to preserve verbatim. */ + private static final int KEEP_RECENT_TURNS = 4; + + private static final String COMPACTION_SYSTEM_PROMPT = + "SYSTEM OPERATION — this is an automated context compaction step, NOT a user message.\n" + + "\n" + + "You are summarizing a conversation between a user and a Data Agent (a ReAct agent" + + " that executes SQL and analytics tools against Kyuubi). The user has NOT asked you" + + " to summarize anything. Do not address the user. Do not ask questions. Produce only" + + " the summary in the schema below.\n" + + "\n" + + "Goal: produce a dense, structured summary the agent can resume from without losing" + + " critical context. Preserve concrete details verbatim — file paths, table names," + + " schema definitions, SQL snippets, column names, error messages.\n" + + "\n" + + "Output EXACTLY these 8 sections, in order, as markdown headers:\n" + + "\n" + + "1. ## User Intent\n" + + " The user's original request, restated in full. Preserve the literal phrasing of" + + " the ask. Include follow-up refinements.\n" + + "\n" + + "2. ## Key Concepts\n" + + " Domain terms, data sources, tables, schemas, SQL dialects, and business logic" + + " the agent has been reasoning about.\n" + + "\n" + + "3. ## Files and Code\n" + + " File paths, query text, DDL, or code artifacts referenced. Include verbatim SQL" + + " snippets that produced meaningful results.\n" + + "\n" + + "4. ## Errors and Recoveries\n" + + " Errors encountered (SQL syntax, permission, timeout, tool failures), what was" + + " tried, and what resolved them. Preserve error messages verbatim.\n" + + "\n" + + "5. 
## Pending Work\n" + + " Tasks the agent identified but has not completed yet.\n" + + "\n" + + "6. ## Current State\n" + + " Where the agent is right now — what question is open, what data has been" + + " retrieved, what hypothesis is being tested.\n" + + "\n" + + "7. ## Next Step\n" + + " The immediate next action the agent should take when resuming.\n" + + "\n" + + "8. ## Tool Usage Summary\n" + + " Which tools were called, how many times, and notable results.\n" + + "\n" + + "CRITICAL:\n" + + "- DO NOT ask the user about this summary.\n" + + "- DO NOT mention that compaction occurred in any future assistant response.\n" + + "- DO NOT invent details not present in the conversation.\n" + + "- DO NOT output anything outside the 8 sections.\n"; + + private final OpenAIClient client; + private final String summarizerModel; + private final long triggerPromptTokens; + + public CompactionMiddleware( + OpenAIClient client, String summarizerModel, long triggerPromptTokens) { + this.client = client; + this.summarizerModel = summarizerModel; + this.triggerPromptTokens = triggerPromptTokens; + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + ConversationMemory mem = ctx.getMemory(); + // 1) Real token count of the previous LLM call (prompt + completion, i.e. everything through + // the last assistant message, which is now part of history). 0 on the first call. + long lastTotal = mem.getLastTotalTokens(); + // 2) Estimated tokens appended to the tail after the last assistant (tool_results, new user). + long newTailEstimate = estimateTailAfterLastAssistant(messages); + + if (lastTotal + newTailEstimate < triggerPromptTokens) { + return null; + } + + List history = mem.getHistory(); + + // 3) Split history into old (to summarize) and kept (recent tail), never orphaning a + // tool_result. 
+ Split split = computeSplit(history, KEEP_RECENT_TURNS); + if (split.old.isEmpty()) { + return null; + } + + String summary = summarize(mem.getSystemPrompt(), split.old); + + // 4) Build the compacted history and persist into ConversationMemory. + List compacted = new ArrayList<>(1 + split.kept.size()); + compacted.add(wrapSummaryAsUserMessage(summary)); + compacted.addAll(split.kept); + mem.replaceHistory(compacted); + + LOG.info( + "Compacted {} old msgs into 1 summary; kept {} tail msgs (lastTotal={}, newTail~={})", + split.old.size(), + split.kept.size(), + lastTotal, + newTailEstimate); + + ctx.emit( + new Compaction( + split.old.size(), split.kept.size(), triggerPromptTokens, lastTotal + newTailEstimate)); + + return new LlmModifyMessages(mem.buildLlmMessages()); + } + + /** Call the LLM to produce a summary of {@code oldMessages}. Failures propagate. */ + private String summarize(String agentSystemPrompt, List oldMessages) { + String systemPrompt = COMPACTION_SYSTEM_PROMPT; + if (agentSystemPrompt != null && !agentSystemPrompt.isEmpty()) { + systemPrompt = + systemPrompt + + "\n---\nFor context, the agent's own system prompt is:\n" + + agentSystemPrompt; + } + + String rendered = renderAsText(oldMessages); + + ChatCompletionCreateParams params = + ChatCompletionCreateParams.builder() + .model(summarizerModel) + .temperature(0.0) + .addMessage( + ChatCompletionMessageParam.ofSystem( + ChatCompletionSystemMessageParam.builder().content(systemPrompt).build())) + .addMessage( + ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(rendered).build())) + .build(); + + ChatCompletion response = client.chat().completions().create(params); + return response.choices().get(0).message().content().get(); + } + + // ----- helpers ----- + + /** Sum of content characters in messages after the last assistant, using ~4 chars per token. 
*/ + static long estimateTailAfterLastAssistant(List messages) { + int lastAssistantIdx = -1; + for (int i = messages.size() - 1; i >= 0; i--) { + if (messages.get(i).isAssistant()) { + lastAssistantIdx = i; + break; + } + } + long totalChars = 0; + for (int i = lastAssistantIdx + 1; i < messages.size(); i++) { + totalChars += contentCharCount(messages.get(i)); + } + return totalChars / 4; + } + + private static long contentCharCount(ChatCompletionMessageParam msg) { + if (msg.isUser()) { + return msg.asUser().content().text().map(String::length).orElse(0); + } + if (msg.isTool()) { + return msg.asTool().content().text().map(String::length).orElse(0); + } + if (msg.isAssistant()) { + return msg.asAssistant().content().flatMap(c -> c.text()).map(String::length).orElse(0); + } + if (msg.isSystem()) { + return msg.asSystem().content().text().map(String::length).orElse(0); + } + return 0; + } + + /** + * Render a list of messages as plain text for the summarizer's user turn. Tool calls and tool + * results are rendered as tagged text so the summarizer LLM doesn't try to continue them as live + * agent state. 
+ */ + static String renderAsText(List messages) { + StringBuilder sb = new StringBuilder(); + for (ChatCompletionMessageParam msg : messages) { + if (sb.length() > 0) sb.append("\n\n"); + if (msg.isUser()) { + sb.append("USER: ").append(extractUserContent(msg)); + } else if (msg.isAssistant()) { + sb.append("ASSISTANT: ").append(extractAssistantContent(msg)); + msg.asAssistant() + .toolCalls() + .ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) { + sb.append("\n[tool_call: ") + .append(tc.asFunction().function().name()) + .append("(") + .append(tc.asFunction().function().arguments()) + .append(") id=") + .append(tc.asFunction().id()) + .append("]"); + } + } + }); + } else if (msg.isTool()) { + sb.append("[tool_result id=") + .append(msg.asTool().toolCallId()) + .append("]: ") + .append(extractToolContent(msg)); + } else if (msg.isSystem()) { + // system prompt should not appear in oldMessages, but render defensively + sb.append("SYSTEM: ").append(extractSystemContent(msg)); + } + } + return sb.toString(); + } + + private static String extractUserContent(ChatCompletionMessageParam msg) { + return msg.asUser().content().text().orElse("[non-text content]"); + } + + private static String extractAssistantContent(ChatCompletionMessageParam msg) { + return msg.asAssistant().content().map(c -> c.text().orElse("[non-text content]")).orElse(""); + } + + private static String extractToolContent(ChatCompletionMessageParam msg) { + return msg.asTool().content().text().orElse("[non-text content]"); + } + + private static String extractSystemContent(ChatCompletionMessageParam msg) { + return msg.asSystem().content().text().orElse("[non-text content]"); + } + + /** Result of splitting the history into a summarizable prefix and a kept tail. 
*/ + static final class Split { + + final List old; + final List kept; + + Split(List old, List kept) { + this.old = old; + this.kept = kept; + } + } + + /** + * Split the history at a boundary that preserves the last {@code keepRecentTurns} user messages, + * with adjustments so that no assistant-tool_use is separated from its tool_results. + */ + static Split computeSplit(List history, int keepRecentTurns) { + if (history.size() <= 2) { + return new Split(new ArrayList<>(), new ArrayList<>(history)); + } + + // Walk from the tail, count user boundaries. If the history does not contain enough user + // messages to satisfy keepRecentTurns, keep everything (splitIdx = 0); the empty-old check + // in beforeLlmCall will then skip this turn gracefully. + int userBoundariesFound = 0; + int splitIdx = 0; + for (int i = history.size() - 1; i >= 0; i--) { + if (history.get(i).isUser()) { + userBoundariesFound++; + if (userBoundariesFound == keepRecentTurns) { + splitIdx = i; + break; + } + } + } + + // Protect tool-call / tool-result pairing: never split between an assistant that issued + // tool_calls and the tool_results that satisfy them. + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (prev.isTool()) { + splitIdx--; + continue; + } + if (prev.isAssistant()) { + boolean hasToolCalls = prev.asAssistant().toolCalls().map(List::size).orElse(0) > 0; + if (hasToolCalls) { + splitIdx--; + continue; + } + } + break; + } + + // Also guard against the edge case: if kept contains a tool_result whose tool_call id is + // defined only in old, pull that assistant (and its siblings) into kept too. 
+ Set keptCallIds = collectToolCallIds(history.subList(splitIdx, history.size())); + if (!keptCallIds.isEmpty()) { + while (splitIdx > 0) { + ChatCompletionMessageParam prev = history.get(splitIdx - 1); + if (!prev.isAssistant()) break; + List calls = prev.asAssistant().toolCalls().orElse(null); + if (calls == null || calls.isEmpty()) break; + boolean satisfiesKept = false; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && keptCallIds.contains(tc.asFunction().id())) { + satisfiesKept = true; + break; + } + } + if (!satisfiesKept) break; + splitIdx--; + } + } + + List oldPart = new ArrayList<>(history.subList(0, splitIdx)); + List keptPart = + new ArrayList<>(history.subList(splitIdx, history.size())); + return new Split(oldPart, keptPart); + } + + private static Set collectToolCallIds(List slice) { + Set ids = new HashSet<>(); + for (ChatCompletionMessageParam m : slice) { + if (m.isTool()) { + ids.add(m.asTool().toolCallId()); + } + } + return ids; + } + + private static ChatCompletionMessageParam wrapSummaryAsUserMessage(String summary) { + String body = "\n" + summary + "\n"; + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(body).build()); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java new file mode 100644 index 00000000000..e0a5c2364eb --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/LoggingMiddleware.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import java.util.Map; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentError; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +/** + * Logging middleware that prints agent lifecycle events for debugging and observability. + * + *

Picks up {@code operationId} and {@code sessionId} from SLF4J MDC (set by ExecuteStatement) to + * tag every log line with the Kyuubi operation/session context. + * + *

Log structure: + * + *

+ *   [op:abcd1234] START user_input="..."
+ *   [op:abcd1234] Step 1
+ *   [op:abcd1234] LLM call: step=1, messages=3
+ *   [op:abcd1234] LLM response: step=1, content="...(truncated)", tool_calls=1
+ *   [op:abcd1234] Tool call: sql_query {sql=SELECT ...}
+ *   [op:abcd1234] Tool result: sql_query -> "| col1 | col2 |...(truncated)"
+ *   [op:abcd1234] FINISH steps=2, tokens=1234
+ * 
+ */ +public class LoggingMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger("DataAgent"); + + private static final int MAX_PREVIEW_LENGTH = 500; + + private static String prefix() { + String sessionId = MDC.get("sessionId"); + String opId = MDC.get("operationId"); + StringBuilder sb = new StringBuilder(); + if (sessionId != null) { + sb.append("[s:").append(shortId(sessionId)).append("]"); + } + if (opId != null) { + sb.append("[op:").append(shortId(opId)).append("]"); + } + if (sb.length() > 0) { + sb.append(" "); + } + return sb.toString(); + } + + /** + * Take the first segment of a UUID (before the first dash). e.g. "327d8c5b-91ef-..." → "327d8c5b" + */ + private static String shortId(String id) { + int dash = id.indexOf('-'); + return dash > 0 ? id.substring(0, dash) : id; + } + + @Override + public void onAgentStart(AgentRunContext ctx) { + LOG.debug("{}START user_input=\"{}\"", prefix(), truncate(ctx.getMemory().getLastUserInput())); + } + + @Override + public void onAgentFinish(AgentRunContext ctx) { + LOG.info( + "{}FINISH steps={}, prompt_tokens={}, completion_tokens={}, total_tokens={}", + prefix(), + ctx.getIteration(), + ctx.getPromptTokens(), + ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public LlmCallAction beforeLlmCall( + AgentRunContext ctx, List messages) { + LOG.info("{}LLM call: step={}, messages={}", prefix(), ctx.getIteration(), messages.size()); + return null; + } + + @Override + public void afterLlmCall(AgentRunContext ctx, ChatCompletionAssistantMessageParam response) { + String content = response.content().map(Object::toString).orElse(""); + int toolCallCount = response.toolCalls().map(List::size).orElse(0); + LOG.info( + "{}LLM response: step={}, content=\"{}\", tool_calls={}, " + + "usage(cumulative): prompt={}, completion={}, total={}", + prefix(), + ctx.getIteration(), + truncate(content), + toolCallCount, + ctx.getPromptTokens(), + 
ctx.getCompletionTokens(), + ctx.getTotalTokens()); + } + + @Override + public ToolCallDenial beforeToolCall( + AgentRunContext ctx, String toolCallId, String toolName, Map toolArgs) { + LOG.info("{}Tool call: id={}, name={}", prefix(), toolCallId, toolName); + LOG.debug("{}Tool args: {}", prefix(), toolArgs); + return null; + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + LOG.info("{}Tool result: {} -> \"{}\"", prefix(), toolName, truncate(result)); + return null; + } + + @Override + public AgentEvent onEvent(AgentRunContext ctx, AgentEvent event) { + switch (event.eventType()) { + case STEP_START: + LOG.info("{}Step {}", prefix(), ((StepStart) event).stepNumber()); + break; + case ERROR: + LOG.error("{}ERROR: {}", prefix(), ((AgentError) event).message()); + break; + case TOOL_RESULT: + ToolResult tr = (ToolResult) event; + if (tr.isError()) { + LOG.warn("{}Tool error: {} -> \"{}\"", prefix(), tr.toolName(), truncate(tr.output())); + } + break; + default: + break; + } + return event; + } + + private static String truncate(String s) { + if (s == null) return ""; + return s.length() <= MAX_PREVIEW_LENGTH ? s : s.substring(0, MAX_PREVIEW_LENGTH) + "..."; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java new file mode 100644 index 00000000000..87aad9f3255 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddleware.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Input-gate middleware that offloads oversized tool outputs to per-session temp files and replaces + * the in-memory tool result with a small head + tail preview plus a retrieval hint. The cheapest + * and highest-ROI defense against context rot. + * + *

Trigger: {@code result.lines > MAX_LINES} OR {@code result.bytes > MAX_BYTES}, first to + * trip wins. Thresholds are hardcoded — the ReAct loop can compensate for suboptimal defaults by + * calling the retrieval tools more aggressively. + * + *

Exempt tools: {@link ReadToolOutputTool} and {@link GrepToolOutputTool} never go + * through the gate — the agent would otherwise recursively re-offload its own retrieval output. + * + *

Session lifecycle: a monotonic counter per session is used to name temp files. {@link + * #onSessionClose(String)} wipes the counter and the per-session temp dir; call it from the + * provider's {@code close(sessionId)} hook (not from {@link #onAgentFinish}, which fires on every + * turn and would invalidate paths the LLM still needs to reference). + */ +public class ToolResultOffloadMiddleware implements AgentMiddleware { + + private static final Logger LOG = LoggerFactory.getLogger(ToolResultOffloadMiddleware.class); + + static final int MAX_LINES = 500; + static final int MAX_BYTES = 50 * 1024; + static final int PREVIEW_HEAD_LINES = 20; + static final int PREVIEW_TAIL_LINES = 20; + + private static final Set EXEMPT_TOOLS = + new HashSet<>(Arrays.asList(ReadToolOutputTool.NAME, GrepToolOutputTool.NAME)); + + private final ToolOutputStore store = ToolOutputStore.create(); + private final ConcurrentHashMap counters = new ConcurrentHashMap<>(); + + /** + * Register the companion retrieval tools so the LLM can reach back into offloaded files. Paired + * with the preview hint emitted from {@link #afterToolCall}; skipping this registration would + * leave the LLM dangling on file paths it can never read. + */ + @Override + public void onRegister(ToolRegistry registry) { + registry.register(new ReadToolOutputTool(store)); + registry.register(new GrepToolOutputTool(store)); + } + + @Override + public String afterToolCall( + AgentRunContext ctx, String toolName, Map toolArgs, String result) { + if (result.isEmpty()) return null; + if (EXEMPT_TOOLS.contains(toolName)) return null; + + int bytes = result.getBytes(StandardCharsets.UTF_8).length; + int lines = countLines(result); + if (lines <= MAX_LINES && bytes <= MAX_BYTES) { + return null; + } + + // AgentRunContext.sessionId is null in unit-test constructions that don't exercise offload. + // In production the provider always threads it through, so treat null as "skip offload". 
+ String sessionId = ctx.getSessionId(); + if (sessionId == null) return null; + + long n = counters.computeIfAbsent(sessionId, k -> new AtomicLong()).incrementAndGet(); + String toolCallId = toolName + "_" + n; + + Path file; + try { + file = store.write(sessionId, toolCallId, result); + } catch (IOException e) { + LOG.warn( + "Tool output offload failed for tool={} session={}; passing through full output", + toolName, + sessionId, + e); + return null; + } + + LOG.info( + "Offloaded tool={} session={} ({} lines / {} bytes) -> {}", + toolName, + sessionId, + lines, + bytes, + file.getFileName()); + return buildPreview(result, lines, bytes, file); + } + + /** Clean up counter and temp dir for a closed session. Idempotent. */ + @Override + public void onSessionClose(String sessionId) { + if (sessionId == null) return; + counters.remove(sessionId); + store.cleanupSession(sessionId); + } + + /** Engine-wide shutdown: drop the temp root. */ + @Override + public void onStop() { + store.close(); + } + + static int countLines(String s) { + if (s.isEmpty()) return 0; + int count = 1; + for (int i = 0; i < s.length(); i++) { + if (s.charAt(i) == '\n') count++; + } + // Trailing newline means the last "line" is empty, but still counted. Good enough for + // gating decisions — we're not trying to match `wc -l` exactly. 
+ return count; + } + + static String buildPreview(String full, int lines, int bytes, Path file) { + String[] split = full.split("\n", -1); + int headEnd = Math.min(PREVIEW_HEAD_LINES, split.length); + int tailStart = Math.max(headEnd, split.length - PREVIEW_TAIL_LINES); + + StringBuilder sb = new StringBuilder(); + sb.append("[Tool output truncated: ") + .append(lines) + .append(" lines, ") + .append(humanBytes(bytes)) + .append("]\n") + .append("Saved to: ") + .append(file.toString()) + .append("\n\n--- First ") + .append(headEnd) + .append(" lines ---\n"); + for (int i = 0; i < headEnd; i++) { + sb.append(split[i]).append('\n'); + } + if (tailStart < split.length) { + int tailCount = split.length - tailStart; + sb.append("--- Last ").append(tailCount).append(" lines ---\n"); + for (int i = tailStart; i < split.length; i++) { + sb.append(split[i]).append('\n'); + } + } + sb.append("\nUse ") + .append(ReadToolOutputTool.NAME) + .append("(path, offset, limit) to read windows, or ") + .append(GrepToolOutputTool.NAME) + .append("(path, pattern, max_matches) to search."); + return sb.toString(); + } + + private static String humanBytes(long bytes) { + if (bytes < 1024) return bytes + " B"; + if (bytes < 1024 * 1024) return String.format("%.1f KB", bytes / 1024.0); + return String.format("%.1f MB", bytes / (1024.0 * 1024.0)); + } + + /** Visible for testing. 
*/ + int trackedSessions() { + return counters.size(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java index 297a7c8d74a..b19f6787c87 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/AgentTool.java @@ -50,7 +50,10 @@ default ToolRiskLevel riskLevel() { * Execute the tool with the given deserialized arguments. * * @param args the deserialized arguments from the LLM's tool call + * @param ctx per-invocation context (session id, etc.); never null — use {@link + * ToolContext#EMPTY} for calls without a session. Tools that are session-agnostic may ignore + * it. * @return the result string to feed back to the LLM */ - String execute(T args); + String execute(T args, ToolContext ctx); } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java new file mode 100644 index 00000000000..2625f3bab59 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool; + +/** + * Per-invocation context handed to {@link AgentTool#execute(Object, ToolContext)}. Today it carries + * just the session id so session-scoped tools (e.g. the offloaded tool-output retrievers) can + * restrict their filesystem view; extend here when a tool needs user/approval/etc. + */ +public final class ToolContext { + + /** Sentinel for call sites that have no session to attribute — tests, direct CLI use. */ + public static final ToolContext EMPTY = new ToolContext(null); + + private final String sessionId; + + public ToolContext(String sessionId) { + this.sessionId = sessionId; + } + + /** Upstream session id, or {@code null} when invoked outside a session. 
*/ + public String sessionId() { + return sessionId; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java index a403c66b58d..3a11bab567c 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistry.java @@ -26,15 +26,16 @@ import com.openai.models.chat.completions.ChatCompletionTool; import java.util.LinkedHashMap; import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.SynchronousQueue; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,15 +69,20 @@ public class ToolRegistry implements AutoCloseable { */ private static final int MAX_POOL_SIZE = 8; + /** Default wall-clock cap for a single tool call, used when no explicit value is configured. 
*/ + public static final long DEFAULT_TIMEOUT_SECONDS = 300; + private final Map> tools = new LinkedHashMap<>(); private volatile Map cachedSpecs; private final long toolCallTimeoutSeconds; private final ExecutorService executor; + private final ScheduledExecutorService timeoutScheduler; /** - * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} call, sourced from - * {@code kyuubi.engine.data.agent.tool.call.timeout}. When the timeout fires, the thread is - * interrupted and a descriptive error is returned to the LLM. + * @param toolCallTimeoutSeconds wall-clock cap on every {@link #executeTool} / {@link + * #submitTool} call, sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. When + * the timeout fires, the worker thread is interrupted and a descriptive error is returned to + * the LLM. */ public ToolRegistry(long toolCallTimeoutSeconds) { this.toolCallTimeoutSeconds = toolCallTimeoutSeconds; @@ -93,12 +99,20 @@ public ToolRegistry(long toolCallTimeoutSeconds) { t.setDaemon(true); return t; }); + this.timeoutScheduler = + Executors.newSingleThreadScheduledExecutor( + r -> { + Thread t = new Thread(r, "tool-call-timeout"); + t.setDaemon(true); + return t; + }); } - /** Shut down the worker pool. Idempotent. */ + /** Shut down the worker pool and the timeout scheduler. Idempotent. */ @Override public void close() { executor.shutdownNow(); + timeoutScheduler.shutdownNow(); } /** Register a tool. Keyed by {@link AgentTool#name()}. */ @@ -138,61 +152,100 @@ private synchronized Map ensureSpecs() { } /** - * Execute a tool call: deserialize the JSON args, then delegate to the tool, with a wall-clock - * timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. If the tool does not - * finish within the timeout, the worker thread is interrupted and a descriptive error is returned - * to the LLM so it can react (e.g. simplify the query, retry with LIMIT). + * Synchronous entry point. 
Blocks until the tool finishes, times out, or the registry rejects the + * submission. Errors are surfaced as strings (never as exceptions) so the LLM can observe and + * react to them. * * @param toolName the function name from the LLM response * @param argsJson the raw JSON arguments string * @return the result string, or an error message */ - @SuppressWarnings("unchecked") public String executeTool(String toolName, String argsJson) { - AgentTool tool; + return submitTool(toolName, argsJson, ToolContext.EMPTY).join(); + } + + /** Synchronous entry point with an explicit {@link ToolContext}. */ + public String executeTool(String toolName, String argsJson, ToolContext ctx) { + return submitTool(toolName, argsJson, ctx).join(); + } + + public CompletableFuture submitTool(String toolName, String argsJson) { + return submitTool(toolName, argsJson, ToolContext.EMPTY); + } + + /** + * Asynchronous entry point. Deserialize args, run the tool on the worker pool, and apply a + * wall-clock timeout sourced from {@code kyuubi.engine.data.agent.tool.call.timeout}. The + * returned future is guaranteed to complete normally — timeouts, pool saturation, unknown tool, + * and execution failures are all translated into error strings. Callers can therefore use {@code + * .join()} / {@code .get()} without handling {@link java.util.concurrent.TimeoutException} or + * {@link java.util.concurrent.ExecutionException}. 
+ * + * @param toolName the function name from the LLM response + * @param argsJson the raw JSON arguments string + */ + @SuppressWarnings("unchecked") + public CompletableFuture submitTool(String toolName, String argsJson, ToolContext ctx) { + AgentTool tool; synchronized (this) { - tool = tools.get(toolName); + tool = (AgentTool) tools.get(toolName); } if (tool == null) { - return "Error: unknown tool '" + toolName + "'"; + return CompletableFuture.completedFuture("Error: unknown tool '" + toolName + "'"); } - return executeWithTimeout((AgentTool) tool, argsJson); - } + ToolContext toolCtx = ctx != null ? ctx : ToolContext.EMPTY; - private String executeWithTimeout(AgentTool tool, String argsJson) { - Callable task = - () -> { - T args = JSON.readValue(argsJson, tool.argsType()); - return tool.execute(args); - }; - Future future; + CompletableFuture result = new CompletableFuture<>(); + Future submitted; try { - future = executor.submit(task); + submitted = + executor.submit( + () -> { + try { + Object args = JSON.readValue(argsJson, tool.argsType()); + String out = tool.execute(args, toolCtx); + // When the timeout handler interrupts us, the tool may still unwind cleanly and + // produce a stale return value — don't race the scheduler's timeout message with + // it. Let the timeout path be the single authority for the final result. + if (!Thread.currentThread().isInterrupted()) { + result.complete(out); + } + } catch (Exception e) { + result.complete("Error executing " + toolName + ": " + e.getMessage()); + } + }); } catch (RejectedExecutionException e) { - LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", tool.name(), MAX_POOL_SIZE); - return "Error: tool call '" - + tool.name() - + "' rejected — server is handling too many concurrent tool calls. 
" - + "Retry in a moment."; - } - try { - return future.get(toolCallTimeoutSeconds, TimeUnit.SECONDS); - } catch (TimeoutException e) { - future.cancel(true); - LOG.warn("Tool call '{}' timed out after {} seconds", tool.name(), toolCallTimeoutSeconds); - return "Error: tool call '" - + tool.name() - + "' timed out after " - + toolCallTimeoutSeconds - + " seconds. " - + "Try simplifying the query or adding filters to reduce execution time."; - } catch (ExecutionException e) { - Throwable cause = e.getCause() != null ? e.getCause() : e; - return "Error executing " + tool.name() + ": " + cause.getMessage(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return "Error: tool call '" + tool.name() + "' was interrupted."; + LOG.warn("Tool call '{}' rejected — worker pool saturated at {}", toolName, MAX_POOL_SIZE); + return CompletableFuture.completedFuture( + "Error: tool call '" + + toolName + + "' rejected — server is handling too many concurrent tool calls. " + + "Retry in a moment."); } + + ScheduledFuture timer = + timeoutScheduler.schedule( + () -> { + if (!result.isDone()) { + // cancel(true) interrupts the worker thread directly — the inner task's + // catch-all will see the interrupt and call result.complete(...), but the + // timeout message below wins because complete() is idempotent on first-winner. + submitted.cancel(true); + LOG.warn( + "Tool call '{}' timed out after {} seconds", toolName, toolCallTimeoutSeconds); + result.complete( + "Error: tool call '" + + toolName + + "' timed out after " + + toolCallTimeoutSeconds + + " seconds. 
" + + "Try simplifying the query or adding filters to reduce execution time."); + } + }, + toolCallTimeoutSeconds, + TimeUnit.SECONDS); + result.whenComplete((r, e) -> timer.cancel(false)); + return result; } private static ChatCompletionTool buildChatCompletionTool(AgentTool tool) { diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java new file mode 100644 index 00000000000..b05a199090a --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputArgs.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link GrepToolOutputTool}. 
*/ +public class GrepToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Java regex pattern to search for. Matches are returned as ':'.") + public String pattern; + + @JsonProperty("max_matches") + @JsonPropertyDescription("Maximum number of matching lines to return. Defaults to 50.") + public Integer maxMatches; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java new file mode 100644 index 00000000000..f94fb21b315 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/GrepToolOutputTool.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Regex-search a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. + */ +public class GrepToolOutputTool implements AgentTool { + + public static final String NAME = "grep_tool_output"; + private static final int DEFAULT_MAX_MATCHES = 50; + + private final ToolOutputStore store; + + public GrepToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Regex-search a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Cheaper than read_tool_output when you know what you're looking for. " + + "Returns matching lines as ':'."; + } + + @Override + public Class argsType() { + return GrepToolOutputArgs.class; + } + + @Override + public String execute(GrepToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: grep_tool_output requires a session context."; + } + int max = args.maxMatches != null ? 
args.maxMatches : DEFAULT_MAX_MATCHES; + return store.grep(ctx.sessionId(), args.path, args.pattern, max); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java new file mode 100644 index 00000000000..458fbfa4f66 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputArgs.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** Args for {@link ReadToolOutputTool}. */ +public class ReadToolOutputArgs { + + @JsonProperty(required = true) + @JsonPropertyDescription( + "Absolute path to the offloaded tool-output file, as reported by the truncation notice.") + public String path; + + @JsonPropertyDescription("0-based line offset into the file. 
Defaults to 0.") + public Integer offset; + + @JsonPropertyDescription( + "Number of lines to return starting at 'offset'. Defaults to 200; capped at 500.") + public Integer limit; +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java new file mode 100644 index 00000000000..93e837998f3 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/output/ReadToolOutputTool.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.tool.output; + +import org.apache.kyuubi.engine.dataagent.runtime.ToolOutputStore; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; + +/** + * Read a line window from a previously offloaded tool-output file. Companion to {@code + * ToolResultOffloadMiddleware}. 
+ */ +public class ReadToolOutputTool implements AgentTool { + + public static final String NAME = "read_tool_output"; + private static final int DEFAULT_LIMIT = 200; + private static final int MAX_LIMIT = 500; + + private final ToolOutputStore store; + + public ReadToolOutputTool(ToolOutputStore store) { + this.store = store; + } + + @Override + public String name() { + return NAME; + } + + @Override + public String description() { + return "Read a line window from a previously offloaded tool-output file " + + "(the path is supplied in the truncation notice of a prior tool result). " + + "Returns '[lines X-Y of Z total]' header followed by the requested window. " + + "Use when a prior tool's output was truncated and you need to inspect more of it."; + } + + @Override + public Class argsType() { + return ReadToolOutputArgs.class; + } + + @Override + public String execute(ReadToolOutputArgs args, ToolContext ctx) { + if (args == null || args.path == null || args.path.isEmpty()) { + return "Error: 'path' parameter is required."; + } + if (ctx == null || ctx.sessionId() == null) { + return "Error: read_tool_output requires a session context."; + } + int offset = args.offset != null ? args.offset : 0; + int limit = args.limit != null ? 
args.limit : DEFAULT_LIMIT; + if (limit > MAX_LIMIT) limit = MAX_LIMIT; + return store.read(ctx.sessionId(), args.path, offset, limit); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java index 06b12f2be72..88838ca477a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { return SqlExecutor.execute(dataSource, args.sql, queryTimeoutSeconds); } } diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java index 0c57cc049ed..9136b0aa903 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryTool.java @@ -19,6 +19,7 @@ import javax.sql.DataSource; import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; /** @@ -69,7 +70,7 @@ public Class argsType() { } @Override - public String 
execute(SqlQueryArgs args) { + public String execute(SqlQueryArgs args, ToolContext ctx) { String sql = args.sql; if (sql == null || sql.trim().isEmpty()) { return "Error: 'sql' parameter is required."; diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java index d2cac9ae518..52b83182add 100644 --- a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/tool/sql/SqlReadOnlyChecker.java @@ -53,6 +53,7 @@ final class SqlReadOnlyChecker { *
  • {@code USE} — switch session catalog/database; not data-mutating *
  • {@code LIST} — Spark {@code LIST FILE} / {@code LIST JAR} inspection *
  • {@code HELP} — some engines expose interactive help + *
  • {@code PRAGMA} — SQLite schema/metadata inspection * */ private static final Set READ_ONLY_KEYWORDS = @@ -70,7 +71,8 @@ final class SqlReadOnlyChecker { "EXPLAIN", "USE", "LIST", - "HELP"))); + "HELP", + "PRAGMA"))); private SqlReadOnlyChecker() {} diff --git a/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java new file mode 100644 index 00000000000..f4366094670 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/main/java/org/apache/kyuubi/engine/dataagent/util/ConfUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.util; + +import org.apache.kyuubi.config.ConfigEntry; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.config.OptionalConfigEntry; + +/** Small helpers for reading typed values out of {@link KyuubiConf}. */ +public final class ConfUtils { + + private ConfUtils() {} + + /** Return the string value, or throw if the entry is not set. 
*/ + public static String requireString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + if (opt.isEmpty()) { + throw new IllegalArgumentException(key.key() + " is required"); + } + return opt.get(); + } + + /** Return the string value, or {@code null} if the entry is not set. */ + public static String optionalString(KyuubiConf conf, OptionalConfigEntry key) { + scala.Option opt = conf.get(key); + return opt.isDefined() ? opt.get() : null; + } + + /** Return the value for a raw key, or {@code null} if not set. */ + public static String optionalString(KyuubiConf conf, String key) { + scala.Option opt = conf.getOption(key); + return opt.isDefined() ? opt.get() : null; + } + + public static int intConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).intValue(); + } + + public static long longConf(KyuubiConf conf, ConfigEntry key) { + return ((Number) conf.get(key)).longValue(); + } + + /** Read a millisecond-valued entry and return it as whole seconds. 
*/ + public static long millisAsSeconds(KyuubiConf conf, ConfigEntry key) { + return longConf(conf, key) / 1000L; + } +} diff --git a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala index 04d0defa2dc..3d902677a0a 100644 --- a/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala +++ b/externals/kyuubi-data-agent-engine/src/main/scala/org/apache/kyuubi/engine/dataagent/operation/ExecuteStatement.scala @@ -25,7 +25,7 @@ import org.slf4j.MDC import org.apache.kyuubi.{KyuubiSQLException, Logging} import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.engine.dataagent.provider.{DataAgentProvider, ProviderRunRequest} -import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} +import org.apache.kyuubi.engine.dataagent.runtime.event.{AgentError, AgentEvent, AgentFinish, ApprovalRequest, Compaction, ContentDelta, EventType, StepEnd, StepStart, ToolCall, ToolResult} import org.apache.kyuubi.operation.OperationState import org.apache.kyuubi.operation.log.OperationLog import org.apache.kyuubi.session.Session @@ -117,7 +117,7 @@ class ExecuteStatement( n.put("type", sseType) n.put("id", toolCall.toolCallId()) n.put("name", toolCall.toolName()) - n.set("args", JSON.valueToTree(toolCall.toolArgs())) + n.set[ObjectNode]("args", JSON.valueToTree(toolCall.toolArgs())) })) case EventType.TOOL_RESULT => val toolResult = event.asInstanceOf[ToolResult] @@ -148,6 +148,15 @@ class ExecuteStatement( n.set("args", JSON.valueToTree(req.toolArgs())) n.put("riskLevel", req.riskLevel().name()) })) + case EventType.COMPACTION => + val c = event.asInstanceOf[Compaction] + incrementalIter.append(Array(toJson { n => + 
n.put("type", sseType) + n.put("summarized", c.summarizedCount()) + n.put("kept", c.keptCount()) + n.put("triggerTokens", c.triggerTokens()) + n.put("observedTokens", c.observedTokens()) + })) case EventType.AGENT_FINISH => val finish = event.asInstanceOf[AgentFinish] incrementalIter.append(Array(toJson { n => diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java index e728ff871e7..c43942a8f35 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/datasource/JdbcDialectTest.java @@ -19,6 +19,7 @@ import static org.junit.Assert.*; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.GenericDialect; import org.junit.Test; public class JdbcDialectTest { diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java index 4e713b9cca3..cc45ebdc7e7 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/mysql/DialectTest.java @@ -20,8 +20,9 @@ import static org.junit.Assert.*; import org.apache.kyuubi.engine.dataagent.datasource.JdbcDialect; -import org.apache.kyuubi.engine.dataagent.datasource.MysqlDialect; +import org.apache.kyuubi.engine.dataagent.datasource.dialect.MysqlDialect; import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; import 
org.apache.kyuubi.engine.dataagent.tool.sql.SqlQueryArgs; import org.junit.BeforeClass; @@ -66,7 +67,7 @@ public void testBacktickQuotingWithReservedWord() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT " + quotedCol + " FROM " + quotedTable + " WHERE id = 1"; - String result = selectTool.execute(args); + String result = selectTool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("value1")); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java new file mode 100644 index 00000000000..2756b3e2087 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/provider/mock/MockLlmProvider.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.provider.mock; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import javax.sql.DataSource; +import org.apache.kyuubi.config.KyuubiConf; +import org.apache.kyuubi.engine.dataagent.datasource.DataSourceFactory; +import org.apache.kyuubi.engine.dataagent.provider.DataAgentProvider; +import org.apache.kyuubi.engine.dataagent.provider.ProviderRunRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentFinish; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentComplete; +import org.apache.kyuubi.engine.dataagent.runtime.event.ContentDelta; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepEnd; +import org.apache.kyuubi.engine.dataagent.runtime.event.StepStart; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolCall; +import org.apache.kyuubi.engine.dataagent.runtime.event.ToolResult; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.apache.kyuubi.engine.dataagent.util.ConfUtils; + +/** + * A mock LLM provider for testing the full tool-call pipeline without a real LLM. Simulates the + * ReAct loop: extracts SQL from the user question, executes it via SqlQueryTool, and returns the + * result as a formatted answer. + * + *

    Recognizes two patterns: + * + *

      + *
    • Questions containing SQL keywords (SELECT, SHOW, DESCRIBE) — extracts and executes the SQL + *
    • All other questions — returns a canned response without tool calls + *
    + */ +public class MockLlmProvider implements DataAgentProvider { + + private static final Pattern SQL_PATTERN = + Pattern.compile( + "(SELECT\\b.+|SHOW\\b.+|DESCRIBE\\b.+)", Pattern.CASE_INSENSITIVE | Pattern.DOTALL); + + /** + * Simple natural-language-to-SQL mappings so tests can use human-readable questions instead of + * raw SQL. Checked before the regex pattern — if a question matches a key (case-insensitive + * prefix), the mapped SQL is executed. + */ + private static final Map NL_TO_SQL = new java.util.LinkedHashMap<>(); + + static { + NL_TO_SQL.put( + "list all employee names and departments", + "SELECT name, department FROM employees ORDER BY id"); + NL_TO_SQL.put( + "how many employees in each department", + "SELECT department, COUNT(*) as cnt FROM employees GROUP BY department"); + NL_TO_SQL.put("count the total number of employees", "SELECT COUNT(*) FROM employees"); + } + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final ToolRegistry toolRegistry; + private final DataSource dataSource; + + public MockLlmProvider(KyuubiConf conf) { + String jdbcUrl = ConfUtils.requireString(conf, KyuubiConf.ENGINE_DATA_AGENT_JDBC_URL()); + this.dataSource = DataSourceFactory.create(jdbcUrl); + this.toolRegistry = new ToolRegistry(30); + this.toolRegistry.register(new RunSelectQueryTool(dataSource, 0)); + } + + @Override + public void open(String sessionId, String user) { + sessions.put(sessionId, new Object()); + } + + @Override + public void run(String sessionId, ProviderRunRequest request, Consumer onEvent) { + String question = request.getQuestion(); + onEvent.accept(new AgentStart()); + + // Trigger an error for testing the error path in ExecuteStatement + if (question.trim().equalsIgnoreCase("__error__")) { + throw new RuntimeException("MockLlmProvider simulated failure"); + } + + // First check natural-language mappings, then fall back to SQL pattern extraction + String sql = resolveToSql(question); + if (sql != null) { 
+ runWithToolCall(sql, onEvent); + } else { + runWithoutToolCall(question, onEvent); + } + } + + private void runWithToolCall(String sql, Consumer onEvent) { + // Step 1: LLM "decides" to call sql_query tool + onEvent.accept(new StepStart(1)); + String toolCallId = "mock_call_" + System.nanoTime(); + Map toolArgs = new HashMap<>(); + toolArgs.put("sql", sql); + onEvent.accept(new ToolCall(toolCallId, "run_select_query", toolArgs)); + + // Execute the tool + String toolOutput = + toolRegistry.executeTool("run_select_query", "{\"sql\":\"" + escapeJson(sql) + "\"}"); + onEvent.accept(new ToolResult(toolCallId, "run_select_query", toolOutput, false)); + onEvent.accept(new StepEnd(1)); + + // Step 2: LLM "summarizes" the result + onEvent.accept(new StepStart(2)); + String answer = "Based on the query result:\n\n" + toolOutput; + for (String token : answer.split("(?<=\\n)")) { + onEvent.accept(new ContentDelta(token)); + } + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(2)); + onEvent.accept(new AgentFinish(2, 100, 50, 150)); + } + + private void runWithoutToolCall(String question, Consumer onEvent) { + onEvent.accept(new StepStart(1)); + String answer = "[MockLLM] No SQL detected in: " + question; + onEvent.accept(new ContentDelta(answer)); + onEvent.accept(new ContentComplete(answer)); + onEvent.accept(new StepEnd(1)); + onEvent.accept(new AgentFinish(1, 50, 20, 70)); + } + + @Override + public void close(String sessionId) { + sessions.remove(sessionId); + } + + @Override + public void stop() { + if (dataSource instanceof com.zaxxer.hikari.HikariDataSource) { + ((com.zaxxer.hikari.HikariDataSource) dataSource).close(); + } + } + + /** + * Resolve a user question to SQL. Checks NL_TO_SQL mappings first, then falls back to regex + * extraction of raw SQL from the input. 
+ */ + private static String resolveToSql(String question) { + String lower = question.toLowerCase().trim(); + for (Map.Entry entry : NL_TO_SQL.entrySet()) { + if (lower.startsWith(entry.getKey())) { + return entry.getValue(); + } + } + Matcher matcher = SQL_PATTERN.matcher(question); + if (matcher.find()) { + return matcher.group(1).trim(); + } + return null; + } + + private static String escapeJson(String s) { + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java new file mode 100644 index 00000000000..d8eb96cfbc1 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ConversationMemoryTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package org.apache.kyuubi.engine.dataagent.runtime;

import static org.junit.Assert.*;

import com.openai.models.chat.completions.ChatCompletionMessageParam;
import com.openai.models.chat.completions.ChatCompletionUserMessageParam;
import java.util.Collections;
import org.junit.Test;

/** Unit tests for {@link ConversationMemory} token bookkeeping around history compaction. */
public class ConversationMemoryTest {

  /**
   * replaceHistory (the compaction path) must reset the per-call token counter while leaving the
   * lifetime cumulative counters untouched: the next LLM call starts from a fresh, smaller
   * prompt, but the tokens already spent are still real.
   */
  @Test
  public void testReplaceHistoryClearsLastTotalTokensButKeepsCumulative() {
    ConversationMemory memory = new ConversationMemory();
    // Two simulated LLM calls: (prompt, completion, total) = (100, 50, 150), then (200, 80, 280).
    memory.addCumulativeTokens(100, 50, 150);
    memory.addCumulativeTokens(200, 80, 280);
    // last* reflects only the most recent call; cumulative* sums across both.
    assertEquals(280L, memory.getLastTotalTokens());
    assertEquals(300L, memory.getCumulativePromptTokens());
    assertEquals(430L, memory.getCumulativeTotalTokens());

    // Compaction replaces the whole history with a single summary message.
    ChatCompletionMessageParam summary =
        ChatCompletionMessageParam.ofUser(
            ChatCompletionUserMessageParam.builder().content("summary").build());
    memory.replaceHistory(Collections.singletonList(summary));

    assertEquals("lastTotalTokens reset after compaction", 0L, memory.getLastTotalTokens());
    assertEquals("cumulative totals preserved", 300L, memory.getCumulativePromptTokens());
    assertEquals("cumulative totals preserved", 430L, memory.getCumulativeTotalTokens());
  }
}
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import java.io.File; +import java.sql.Connection; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.apache.kyuubi.engine.dataagent.prompt.SystemPromptBuilder; +import org.apache.kyuubi.engine.dataagent.runtime.event.*; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ApprovalMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.LoggingMiddleware; +import org.apache.kyuubi.engine.dataagent.runtime.middleware.ToolResultOffloadMiddleware; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunMutationQueryTool; +import org.apache.kyuubi.engine.dataagent.tool.sql.RunSelectQueryTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.sqlite.SQLiteDataSource; + +/** + * Live integration test with a real LLM and real SQLite database. 
Exercises the full ReAct loop: + * LLM reasoning -> tool calls -> result verification. Requires DATA_AGENT_LLM_API_KEY and + * DATA_AGENT_LLM_API_URL environment variables. Works with any OpenAI-compatible LLM service. + */ +public class ReactAgentLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = + System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", "gpt-4o"); + + private static final String SYSTEM_PROMPT = + SystemPromptBuilder.create().datasource("sqlite").build(); + + private final List tempFiles = new ArrayList<>(); + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @After + public void tearDown() { + tempFiles.forEach(File::delete); + } + + @Test + public void testPlainTextStreamingWithoutTools() { + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(new ToolRegistry(30)) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(3) + .systemPrompt("You are a helpful assistant. 
Answer concisely in 1-2 sentences.") + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run(new AgentInvocation("What is Apache Kyuubi?"), memory, events::add); + + List deltas = + events.stream() + .filter(e -> e instanceof ContentDelta) + .map(e -> ((ContentDelta) e).text()) + .collect(Collectors.toList()); + + assertTrue("Expected multiple ContentDelta events", deltas.size() > 1); + assertFalse("Streamed text should not be empty", String.join("", deltas).isEmpty()); + assertTrue(events.stream().anyMatch(e -> e instanceof StepStart)); + assertTrue(events.stream().anyMatch(e -> e instanceof ContentComplete)); + assertTrue(events.get(events.size() - 1) instanceof AgentFinish); + assertEquals(2, memory.getHistory().size()); // user + assistant + } + + @Test + public void testFullReActLoopWithSchemaInspectAndSqlQuery() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + agent.run( + new AgentInvocation( + "What is the total revenue by product category? 
Which category has the highest revenue?"), + memory, + events::add); + + // Verify tool calls happened + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + assertFalse("Agent should have called at least one tool", toolCalls.isEmpty()); + assertFalse("Agent should have received tool results", toolResults.isEmpty()); + assertTrue("Tool calls should not error", toolResults.stream().noneMatch(ToolResult::isError)); + + // Verify SQL query was called + assertTrue( + "Agent should execute at least one SQL query", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // Verify final answer mentions "Electronics" (highest revenue) + List completions = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .collect(Collectors.toList()); + String lastAnswer = completions.get(completions.size() - 1); + assertTrue( + "Final answer should mention Electronics, got: " + lastAnswer, + lastAnswer.toLowerCase().contains("electronics")); + + // Verify agent finished successfully + AgentEvent last = events.get(events.size() - 1); + assertTrue(last instanceof AgentFinish); + assertTrue("Should take multiple steps", ((AgentFinish) last).totalSteps() > 1); + } + + @Test + public void testMultiTurnConversationWithToolUse() { + SQLiteDataSource ds = createSalesDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + // Shared memory across turns + ConversationMemory memory = new ConversationMemory(); + + // 
Turn 1 + List events1 = new CopyOnWriteArrayList<>(); + agent.run(new AgentInvocation("How many orders are there in total?"), memory, events1::add); + + assertTrue( + "Turn 1 should query the database", + events1.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events1.get(events1.size() - 1) instanceof AgentFinish); + + // Turn 2: follow-up relying on conversation context + List events2 = new CopyOnWriteArrayList<>(); + agent.run( + new AgentInvocation("Now show me only orders above 500 dollars."), memory, events2::add); + + assertTrue( + "Turn 2 should also query the database", + events2.stream() + .anyMatch( + e -> + e instanceof ToolCall && "run_select_query".equals(((ToolCall) e).toolName()))); + assertTrue(events2.get(events2.size() - 1) instanceof AgentFinish); + + // Verify memory accumulated across both turns + assertTrue( + "Memory should contain messages from both turns, got " + memory.getHistory().size(), + memory.getHistory().size() > 4); + } + + @Test + public void testToolOutputOffloadThenGrep() throws Exception { + // Large result forces ToolResultOffloadMiddleware to truncate the tool output and + // emit a preview hint telling the LLM to use grep_tool_output / read_tool_output. + // A correct answer proves the LLM read the hint and drove the retrieval itself. + SQLiteDataSource ds = createNeedleInHaystackDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(new ToolResultOffloadMiddleware()) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Explicit workflow: full-table scan first, then retrieval tool. 
Without this hint a + // capable LLM just issues SELECT note FROM events WHERE tag='NEEDLE' and skips the + // offload path entirely -- smart behavior, but defeats the purpose of this test. + agent.run( + new AgentInvocation( + "Step 1: issue exactly this query: SELECT id, tag, note FROM events (no" + + " WHERE clause, return every row). Step 2: the result will be" + + " truncated; call the grep_tool_output tool with pattern 'NEEDLE' on" + + " the saved output file to find the matching row. Step 3: respond with" + + " ONLY the note text from that row, nothing else. Do NOT add a WHERE" + + " clause. Do NOT issue any other SQL.") + // Offload middleware requires a non-null session id -- without it the offload + // path skips entirely (see ToolResultOffloadMiddleware.afterToolCall). + .sessionId("offload-live-" + java.util.UUID.randomUUID()), + memory, + events::add); + + List toolCalls = + events.stream() + .filter(e -> e instanceof ToolCall) + .map(e -> (ToolCall) e) + .collect(Collectors.toList()); + List toolResults = + events.stream() + .filter(e -> e instanceof ToolResult) + .map(e -> (ToolResult) e) + .collect(Collectors.toList()); + + // Dump the tool trace -- diagnoses whether the LLM followed the workflow, added a + // WHERE clause despite instructions, or deviated in some other way. + for (ToolCall tc : toolCalls) { + System.out.println("[ToolCall] " + tc.toolName() + " args=" + tc.toolArgs()); + } + for (ToolResult tr : toolResults) { + String out = tr.output(); + String preview = + out.length() > 300 + ? out.substring(0, 300) + "...(+" + (out.length() - 300) + " chars)" + : out; + System.out.println("[ToolResult] " + tr.toolName() + " -> " + preview); + } + + assertTrue( + "Agent should have run a select query first", + toolCalls.stream().anyMatch(tc -> "run_select_query".equals(tc.toolName()))); + + // The SELECT returned 800 rows, which must trip the offload threshold. 
+ assertTrue( + "Expected at least one offload preview marker in tool results", + toolResults.stream().anyMatch(tr -> tr.output().contains("Tool output truncated"))); + + assertTrue( + "Agent should have used grep_tool_output or read_tool_output after seeing" + + " the offload preview; actual tool calls: " + + toolCalls.stream().map(ToolCall::toolName).collect(Collectors.toList()), + toolCalls.stream() + .anyMatch( + tc -> + "grep_tool_output".equals(tc.toolName()) + || "read_tool_output".equals(tc.toolName()))); + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue( + "Final answer should contain the needle note 'the-answer-is-42', got: " + finalAnswer, + finalAnswer.contains("the-answer-is-42")); + } + + @Test + public void testApprovalApproveFlow() throws Exception { + // Real LLM picks run_mutation_query (DESTRUCTIVE) -> ApprovalMiddleware pauses -> + // background thread resolves(approved=true) -> mutation actually runs in SQLite. + SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + // Auto-approve any approval request that shows up. 
+ ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, true)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Increment the 'hits' counter in the counters table by 1, then tell me its" + + " new value. Respond with ONLY the new value, no explanation."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + List approvals = + events.stream() + .filter(e -> e instanceof ApprovalRequest) + .map(e -> (ApprovalRequest) e) + .collect(Collectors.toList()); + assertFalse("Expected at least one ApprovalRequest", approvals.isEmpty()); + assertTrue( + "ApprovalRequest should target run_mutation_query", + approvals.stream().anyMatch(a -> "run_mutation_query".equals(a.toolName()))); + + // Mutation must have actually executed — check the DB directly. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT value FROM counters WHERE name='hits'")) { + assertTrue(rs.next()); + assertEquals("Counter should be 1 after approved mutation", 1, rs.getInt(1)); + } + + String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse(""); + assertTrue("Final answer should mention 1, got: " + finalAnswer, finalAnswer.contains("1")); + } + + @Test + public void testApprovalDenyFlow() throws Exception { + // Same setup as the approve test, but the approval listener denies. The mutation + // must NOT run, and the LLM must surface the denial to the user naturally. 
+ SQLiteDataSource ds = createCountersDatabase(); + ToolRegistry registry = new ToolRegistry(30); + registry.register(new RunSelectQueryTool(ds, 0)); + registry.register(new RunMutationQueryTool(ds, 0)); + + ApprovalMiddleware approval = new ApprovalMiddleware(30); + + ReactAgent agent = + ReactAgent.builder() + .client(client) + .modelName(MODEL_NAME) + .toolRegistry(registry) + .addMiddleware(new LoggingMiddleware()) + .addMiddleware(approval) + .maxIterations(10) + .systemPrompt(SYSTEM_PROMPT) + .build(); + + List events = new CopyOnWriteArrayList<>(); + ConversationMemory memory = new ConversationMemory(); + + ExecutorService approver = Executors.newSingleThreadExecutor(); + Consumer listener = + event -> { + events.add(event); + if (event instanceof ApprovalRequest) { + String rid = ((ApprovalRequest) event).requestId(); + approver.submit(() -> approval.resolve(rid, false)); + } + }; + + try { + agent.run( + new AgentInvocation( + "Delete all rows from the counters table. If you cannot, explain why."), + memory, + listener); + } finally { + approver.shutdown(); + approver.awaitTermination(5, TimeUnit.SECONDS); + } + + assertTrue( + "Expected at least one ApprovalRequest", + events.stream().anyMatch(e -> e instanceof ApprovalRequest)); + + // DB must be untouched. + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement(); + java.sql.ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM counters")) { + assertTrue(rs.next()); + assertTrue("Counters rows must survive denied mutation", rs.getInt(1) > 0); + } + + // LLM should tell the user the operation didn't go through. Loose lexical check to + // absorb model wording drift. 
+ String finalAnswer = + events.stream() + .filter(e -> e instanceof ContentComplete) + .map(e -> ((ContentComplete) e).fullText()) + .reduce((a, b) -> b) + .orElse("") + .toLowerCase(); + assertTrue( + "Final answer should indicate the deletion was refused/denied/not executed, got: " + + finalAnswer, + finalAnswer.contains("den") + || finalAnswer.contains("reject") + || finalAnswer.contains("not ") + || finalAnswer.contains("refus") + || finalAnswer.contains("unable") + || finalAnswer.contains("could not")); + } + + // --- Helpers --- + + private SQLiteDataSource createSalesDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE products (" + + "id INTEGER PRIMARY KEY, name TEXT NOT NULL, " + + "category TEXT NOT NULL, price REAL NOT NULL)"); + stmt.execute( + "INSERT INTO products VALUES " + + "(1, 'Laptop', 'Electronics', 999.99), " + + "(2, 'Headphones', 'Electronics', 199.99), " + + "(3, 'T-Shirt', 'Clothing', 29.99), " + + "(4, 'Jeans', 'Clothing', 59.99), " + + "(5, 'Novel', 'Books', 14.99), " + + "(6, 'Textbook', 'Books', 89.99)"); + stmt.execute( + "CREATE TABLE orders (" + + "id INTEGER PRIMARY KEY, product_id INTEGER NOT NULL, " + + "customer_name TEXT NOT NULL, quantity INTEGER NOT NULL, " + + "order_date TEXT NOT NULL, " + + "FOREIGN KEY (product_id) REFERENCES products(id))"); + stmt.execute( + "INSERT INTO orders VALUES " + + "(1, 1, 'Alice', 1, '2024-01-15'), " + + "(2, 2, 'Bob', 2, '2024-01-20'), " + + "(3, 3, 'Charlie', 3, '2024-02-01'), " + + "(4, 4, 'Alice', 1, '2024-02-10'), " + + "(5, 5, 'Bob', 5, '2024-02-15'), " + + "(6, 1, 'Diana', 1, '2024-03-01'), " + + "(7, 6, 'Charlie', 2, '2024-03-05'), " + + "(8, 2, 'Diana', 1, '2024-03-10')"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createNeedleInHaystackDatabase() { + SQLiteDataSource ds = createDataSource(); + 
try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute( + "CREATE TABLE events (" + + "id INTEGER PRIMARY KEY, tag TEXT NOT NULL, note TEXT NOT NULL)"); + conn.setAutoCommit(false); + try (java.sql.PreparedStatement ps = + conn.prepareStatement("INSERT INTO events VALUES (?, ?, ?)")) { + // 800 filler rows, guaranteed to blow past ToolResultOffloadMiddleware's 500-line + // threshold when the LLM issues a SELECT *. Exactly one row carries the NEEDLE tag. + int needleId = 573; + for (int i = 1; i <= 800; i++) { + ps.setInt(1, i); + if (i == needleId) { + ps.setString(2, "NEEDLE"); + ps.setString(3, "the-answer-is-42"); + } else { + ps.setString(2, "FILLER"); + ps.setString(3, "filler-note-" + i); + } + ps.addBatch(); + } + ps.executeBatch(); + } + conn.commit(); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createCountersDatabase() { + SQLiteDataSource ds = createDataSource(); + try (Connection conn = ds.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE TABLE counters (name TEXT PRIMARY KEY, value INTEGER NOT NULL)"); + stmt.execute("INSERT INTO counters VALUES ('hits', 0)"); + } catch (Exception e) { + throw new RuntimeException(e); + } + return ds; + } + + private SQLiteDataSource createDataSource() { + try { + File tmpFile = File.createTempFile("kyuubi-agent-live-", ".db"); + tmpFile.deleteOnExit(); + tempFiles.add(tmpFile); + SQLiteDataSource ds = new SQLiteDataSource(); + ds.setUrl("jdbc:sqlite:" + tmpFile.getAbsolutePath()); + return ds; + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java new file mode 100644 index 00000000000..e3eda5ff5e4 --- /dev/null 
+++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/ToolOutputStoreTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime; + +import static org.junit.Assert.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Before; +import org.junit.Test; + +public class ToolOutputStoreTest { + + private ToolOutputStore store; + + @Before + public void setUp() throws IOException { + store = ToolOutputStore.create(); + } + + @Test + public void writeAndReadWindow() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 100; i++) sb.append("row").append(i).append('\n'); + Path p = store.write("sess1", "call1", sb.toString()); + assertTrue(Files.exists(p)); + + String out = store.read("sess1", p.toString(), 10, 5); + assertTrue(out, out.contains("lines 11-15 of")); + assertTrue(out, out.contains("row11")); + assertTrue(out, out.contains("row15")); + assertFalse(out, out.contains("row16")); + assertFalse(out, out.contains("row10")); + } + + @Test + public void grepReturnsMatchingLinesWithLineNumbers() throws IOException { + 
String content = "apple\nbanana\ncherry\napple pie\ndate\n"; + Path p = store.write("sess1", "call1", content); + + String out = store.grep("sess1", p.toString(), "apple", 10); + assertTrue(out, out.contains("1:apple")); + assertTrue(out, out.contains("4:apple pie")); + assertFalse(out, out.contains("banana")); + } + + @Test + public void grepRespectsMaxMatches() throws IOException { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 20; i++) sb.append("hit\n"); + Path p = store.write("sess1", "call1", sb.toString()); + + String out = store.grep("sess1", p.toString(), "hit", 3); + assertTrue(out, out.contains("[3 matches]")); + assertTrue(out, out.contains("1:hit")); + assertTrue(out, out.contains("3:hit")); + assertFalse("should stop after 3 matches", out.contains("4:hit")); + } + + @Test + public void grepInvalidRegexReturnsError() throws IOException { + Path p = store.write("sess1", "call1", "x\n"); + String out = store.grep("sess1", p.toString(), "[", 10); + assertTrue(out, out.startsWith("Error:")); + } + + @Test + public void readRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "top secret\n"); + assertTrue(Files.exists(victim)); + + String out = store.read("attacker", victim.toString(), 0, 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("top secret")); + } + + @Test + public void grepRejectsCrossSessionPath() throws IOException { + Path victim = store.write("victim", "secret_call", "api_key=xyz\n"); + String out = store.grep("attacker", victim.toString(), "api_key", 10); + assertTrue(out, out.startsWith("Error:")); + assertFalse(out, out.contains("xyz")); + } + + @Test + public void cleanupSessionRemovesSubtree() throws IOException { + Path p1 = store.write("sessA", "call1", "a\n"); + Path p2 = store.write("sessA", "call2", "b\n"); + Path p3 = store.write("sessB", "call1", "c\n"); + assertTrue(Files.exists(p1)); + assertTrue(Files.exists(p2)); + 
assertTrue(Files.exists(p3)); + + store.cleanupSession("sessA"); + + assertFalse(Files.exists(p1)); + assertFalse(Files.exists(p2)); + assertTrue("other sessions untouched", Files.exists(p3)); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java index 50d22da416d..b6a5f093b61 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/event/EventTest.java @@ -113,6 +113,7 @@ public void testEventTypeSseNames() { assertEquals("step_end", EventType.STEP_END.sseEventName()); assertEquals("error", EventType.ERROR.sseEventName()); assertEquals("approval_request", EventType.APPROVAL_REQUEST.sseEventName()); + assertEquals("compaction", EventType.COMPACTION.sseEventName()); assertEquals("agent_finish", EventType.AGENT_FINISH.sseEventName()); } @@ -123,6 +124,6 @@ public void testAllEventTypesHaveUniqueSseNames() { for (EventType type : values) { assertTrue("Duplicate SSE name: " + type.sseEventName(), names.add(type.sseEventName())); } - assertEquals(10, values.length); + assertEquals(11, values.length); } } diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java new file mode 100644 index 00000000000..a84bbc25948 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ApprovalMiddlewareTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.runtime.event.AgentEvent; +import org.apache.kyuubi.engine.dataagent.runtime.event.ApprovalRequest; +import org.apache.kyuubi.engine.dataagent.runtime.event.EventType; +import org.apache.kyuubi.engine.dataagent.tool.AgentTool; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; +import org.apache.kyuubi.engine.dataagent.tool.ToolRegistry; +import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; +import org.junit.Before; +import org.junit.Test; + +public class ApprovalMiddlewareTest { + + private ToolRegistry registry; + private List emittedEvents; + + @Before + public void setUp() { + registry = new ToolRegistry(30); + 
registry.register(safeTool("safe_tool")); + registry.register(destructiveTool("dangerous_tool")); + emittedEvents = Collections.synchronizedList(new ArrayList<>()); + } + + // --- Auto-approve mode: all tools pass --- + + @Test + public void testAutoApproveModeSkipsAllApproval() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.AUTO_APPROVE); + + assertNull(mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + assertNull(mw.beforeToolCall(ctx, "tc2", "safe_tool", Collections.emptyMap())); + assertTrue("No approval events should be emitted", emittedEvents.isEmpty()); + } + + // --- Normal mode: safe auto-approved, destructive needs approval --- + + @Test + public void testNormalModeAutoApprovesSafeTool() { + ApprovalMiddleware mw = newApprovalMiddleware(); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + assertNull(mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + assertTrue(emittedEvents.isEmpty()); + } + + @Test + public void testNormalModeRequiresApprovalForDestructiveTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + // Capture the emitted event to get the requestId + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + // Wait for the approval request event + assertTrue("Approval event should be emitted", eventEmitted.await(2, TimeUnit.SECONDS)); + assertEquals(1, emittedEvents.size()); + assertEquals(EventType.APPROVAL_REQUEST, emittedEvents.get(0).eventType()); + + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("dangerous_tool", req.toolName()); + 
assertEquals(ToolRiskLevel.DESTRUCTIVE, req.riskLevel()); + + // Approve + assertTrue(mw.resolve(req.requestId(), true)); + assertNull("Approved tool should return null (no denial)", future.get(2, TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + @Test + public void testDeniedToolReturnsToolCallDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.NORMAL); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit( + () -> mw.beforeToolCall(ctx, "tc1", "dangerous_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + + // Deny + assertTrue(mw.resolve(req.requestId(), false)); + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull(denial); + assertTrue(denial.reason().contains("denied")); + } finally { + exec.shutdownNow(); + } + } + + // --- Strict mode: all tools need approval --- + + @Test + public void testStrictModeRequiresApprovalForSafeTool() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(5); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch eventEmitted = new CountDownLatch(1); + ctx.setEventEmitter( + event -> { + emittedEvents.add(event); + eventEmitted.countDown(); + }); + + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + assertTrue(eventEmitted.await(2, TimeUnit.SECONDS)); + ApprovalRequest req = (ApprovalRequest) emittedEvents.get(0); + assertEquals("safe_tool", req.toolName()); + + assertTrue(mw.resolve(req.requestId(), true)); + assertNull(future.get(2, 
TimeUnit.SECONDS)); + } finally { + exec.shutdownNow(); + } + } + + // --- Timeout --- + + @Test + public void testApprovalTimeoutReturnsDenial() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(1); // 1 second timeout + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + Future future = + exec.submit(() -> mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap())); + + // Don't resolve — let it time out + AgentMiddleware.ToolCallDenial denial = future.get(5, TimeUnit.SECONDS); + assertNotNull("Timeout should produce a denial", denial); + assertTrue(denial.reason().contains("timed out")); + } finally { + exec.shutdownNow(); + } + } + + // --- Cancel all --- + + @Test + public void testOnStopUnblocksPendingRequests() throws Exception { + ApprovalMiddleware mw = newApprovalMiddleware(30); + AgentRunContext ctx = makeContext(ApprovalMode.STRICT); + ctx.setEventEmitter(emittedEvents::add); + + ExecutorService exec = Executors.newSingleThreadExecutor(); + try { + CountDownLatch started = new CountDownLatch(1); + Future future = + exec.submit( + () -> { + started.countDown(); + return mw.beforeToolCall(ctx, "tc1", "safe_tool", Collections.emptyMap()); + }); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); // let the thread enter the blocking wait + + mw.onStop(); + + AgentMiddleware.ToolCallDenial denial = future.get(2, TimeUnit.SECONDS); + assertNotNull("onStop should unblock with a denial", denial); + } finally { + exec.shutdownNow(); + } + } + + // --- Helpers --- + + private ApprovalMiddleware newApprovalMiddleware() { + ApprovalMiddleware mw = new ApprovalMiddleware(); + mw.onRegister(registry); + return mw; + } + + private ApprovalMiddleware newApprovalMiddleware(long timeoutSeconds) { + ApprovalMiddleware mw = new ApprovalMiddleware(timeoutSeconds); + mw.onRegister(registry); + return mw; + } + 
+ private AgentRunContext makeContext(ApprovalMode mode) { + AgentRunContext ctx = new AgentRunContext(new ConversationMemory(), mode); + ctx.setEventEmitter(emittedEvents::add); + return ctx; + } + + private static AgentTool safeTool(String name) { + return new DummyTool(name, ToolRiskLevel.SAFE); + } + + private static AgentTool destructiveTool(String name) { + return new DummyTool(name, ToolRiskLevel.DESTRUCTIVE); + } + + public static class DummyArgs { + public String value; + } + + private static class DummyTool implements AgentTool { + private final String name; + private final ToolRiskLevel riskLevel; + + DummyTool(String name, ToolRiskLevel riskLevel) { + this.name = name; + this.riskLevel = riskLevel; + } + + @Override + public String name() { + return name; + } + + @Override + public String description() { + return "dummy tool"; + } + + @Override + public ToolRiskLevel riskLevel() { + return riskLevel; + } + + @Override + public Class argsType() { + return DummyArgs.class; + } + + @Override + public String execute(DummyArgs args, ToolContext ctx) { + return "ok"; + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java new file mode 100644 index 00000000000..1c4eeb7c5ad --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareLiveTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; +import static org.junit.Assume.assumeTrue; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import java.util.List; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Live integration test for {@link CompactionMiddleware}: exercises the full compaction path + * against a real OpenAI-compatible LLM. Requires {@code DATA_AGENT_LLM_API_KEY} and {@code + * DATA_AGENT_LLM_API_URL} environment variables; skipped otherwise. 
+ */ +public class CompactionMiddlewareLiveTest { + + private static final String API_KEY = System.getenv().getOrDefault("DATA_AGENT_LLM_API_KEY", ""); + private static final String BASE_URL = System.getenv().getOrDefault("DATA_AGENT_LLM_API_URL", ""); + private static final String MODEL_NAME = System.getenv().getOrDefault("DATA_AGENT_LLM_MODEL", ""); + + private OpenAIClient client; + + @Before + public void setUp() { + assumeTrue("DATA_AGENT_LLM_API_KEY not set, skipping live tests", !API_KEY.isEmpty()); + assumeTrue("DATA_AGENT_LLM_API_URL not set, skipping live tests", !BASE_URL.isEmpty()); + client = OpenAIOkHttpClient.builder().apiKey(API_KEY).baseUrl(BASE_URL).build(); + } + + @Test + public void compactsHistoryWhenThresholdCrossed() { + // Seed a realistic ReAct-style history so the summarizer has something non-trivial to + // summarize. ~20 alternating user/assistant turns. + ConversationMemory memory = new ConversationMemory(); + memory.setSystemPrompt( + "You are a data agent. You previously helped the user investigate the orders table."); + for (int i = 0; i < 10; i++) { + memory.addUserMessage( + "Follow-up question " + i + ": what about the column customer_id in orders?"); + memory.addAssistantMessage( + ChatCompletionAssistantMessageParam.builder() + .content( + "Assistant turn " + + i + + ": the orders table has a customer_id BIGINT column referencing customers.id.") + .build()); + } + int originalSize = memory.size(); + + AgentRunContext ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + // Simulate the previous LLM call having reported a large prompt_tokens so the next + // beforeLlmCall trips the threshold. 
+ ctx.addTokenUsage(60_000, 0, 60_000); + + CompactionMiddleware mw = new CompactionMiddleware(client, MODEL_NAME, /* trigger */ 50_000L); + + AgentMiddleware.LlmCallAction action = mw.beforeLlmCall(ctx, memory.buildLlmMessages()); + + assertNotNull("expected compaction to fire", action); + assertTrue(action instanceof AgentMiddleware.LlmModifyMessages); + + // History got rewritten: [summary user msg] + kept tail. + List hist = memory.getHistory(); + assertTrue(hist.size() < originalSize); + assertTrue(hist.get(0).isUser()); + String first = hist.get(0).asUser().content().text().orElse(""); + assertTrue( + "summary message should be wrapped in <summary>", + first.contains("<summary>")); + + // The LLM was told to emit 8 markdown sections; sanity-check a couple show up. + assertTrue( + "summary should contain '## User Intent' section", + first.contains("## User Intent") || first.contains("User Intent")); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java new file mode 100644 index 00000000000..d4c0cc581db --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/CompactionMiddlewareTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionAssistantMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageFunctionToolCall; +import com.openai.models.chat.completions.ChatCompletionMessageParam; +import com.openai.models.chat.completions.ChatCompletionMessageToolCall; +import com.openai.models.chat.completions.ChatCompletionToolMessageParam; +import com.openai.models.chat.completions.ChatCompletionUserMessageParam; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests that exercise only the deterministic, LLM-free parts of {@link CompactionMiddleware}: + * + *
      + * <ul>
    <li>Static helpers {@code computeSplit} and {@code estimateTailAfterLastAssistant}. + *
    <li>Pre-summarizer gating paths in {@code beforeLlmCall} (below threshold, empty old after + * split) that never reach the LLM call. + *
    <li>Constructor validation. + *
    + * + * The end-to-end behaviour (actual summarization output, compacted-history structure, no re-trigger + * next turn) is covered by {@code CompactionMiddlewareLiveTest}, which is gated on real LLM + * credentials. + */ +public class CompactionMiddlewareTest { + + /** + * A minimally-configured real {@link OpenAIClient}. Never invoked in these tests because every + * code path exercised here returns before reaching the summarizer. Uses dummy credentials and a + * bogus base URL so an accidental invocation would fail loudly. + */ + private static final OpenAIClient DUMMY_CLIENT = + OpenAIOkHttpClient.builder() + .apiKey("dummy-for-unit-tests") + .baseUrl("http://127.0.0.1:0") + .build(); + + private ConversationMemory memory; + private AgentRunContext ctx; + + @Before + public void setUp() { + memory = new ConversationMemory(); + memory.setSystemPrompt("SYS"); + ctx = new AgentRunContext(memory, ApprovalMode.AUTO_APPROVE); + } + + // ----- computeSplit ----- + + @Test + public void computeSplit_shortHistoryReturnsAllInKept() { + List history = + Arrays.asList(userMsg("u0"), asstMsg(asstPlain("a0"))); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(2, s.kept.size()); + } + + @Test + public void computeSplit_simpleAlternationSplitsCorrectly() { + // 10 msgs: u0,a0,u1,a1,u2,a2,u3,a3,u4,a4. keep=2 → boundary at u3 (2 users from tail). 
+ List history = alternatingHistory(10); + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 2); + // 4 users from tail: u4,u3 → splitIdx = index of u3 = 6 → old=[0..5], kept=[6..9] + assertEquals(6, s.old.size()); + assertEquals(4, s.kept.size()); + assertTrue(s.kept.get(0).isUser()); + } + + @Test + public void computeSplit_neverOrphansToolResult() { + // Layout: u0 a0 u1 a1 u2 a2 u3 a3(tc1) tool(tc1) u4 a4 u5 a5 u6 a6 u7 a7 + // keep=4 → naive splitIdx lands between tool(tc1) and u4; pair-protection must shift it back + // to before a3(tc1), so a3(tc1) + tool(tc1) end up in kept. + List history = new ArrayList<>(); + history.add(userMsg("u0")); + history.add(asstMsg(asstPlain("a0"))); + history.add(userMsg("u1")); + history.add(asstMsg(asstPlain("a1"))); + history.add(userMsg("u2")); + history.add(asstMsg(asstPlain("a2"))); + history.add(userMsg("u3")); + history.add(asstMsg(asstWithToolCall("a3", "tc1", "sql_query", "{}"))); + history.add(toolMsg("tc1", "r1")); + history.add(userMsg("u4")); + history.add(asstMsg(asstPlain("a4"))); + history.add(userMsg("u5")); + history.add(asstMsg(asstPlain("a5"))); + history.add(userMsg("u6")); + history.add(asstMsg(asstPlain("a6"))); + history.add(userMsg("u7")); + history.add(asstMsg(asstPlain("a7"))); + + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertNoOrphanToolResult(s.kept); + // Verify the tc1 pair really did end up in kept, not old. + assertTrue("a3(tc1) must be in kept", containsToolCallId(s.kept, "tc1")); + assertTrue("tool_result(tc1) must be in kept", containsToolCallIdAsResult(s.kept, "tc1")); + } + + @Test + public void computeSplit_keepCountExceedsAvailableUsers() { + // Only 2 user msgs but we ask to keep 4 — boundary walks to the top, old=[], kept=all. 
+ List history = alternatingHistory(4); // u0,a0,u1,a1 + CompactionMiddleware.Split s = CompactionMiddleware.computeSplit(history, 4); + assertEquals(0, s.old.size()); + assertEquals(4, s.kept.size()); + } + + // ----- estimateTailAfterLastAssistant ----- + + @Test + public void estimateTail_afterLastAssistant() { + // u(200 chars) a(50) u(100) → last assistant at index 1, tail is the final user = 100 chars + // → 100/4 = 25 tokens + List msgs = + Arrays.asList( + userMsg(repeat('x', 200)), + asstMsg(asstPlain(repeat('y', 50))), + userMsg(repeat('z', 100))); + assertEquals(25L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_noAssistantMeansEverythingIsTail() { + List msgs = + Arrays.asList(userMsg(repeat('x', 400)), userMsg(repeat('y', 400))); + assertEquals(200L, CompactionMiddleware.estimateTailAfterLastAssistant(msgs)); + } + + @Test + public void estimateTail_emptyReturnsZero() { + assertEquals(0L, CompactionMiddleware.estimateTailAfterLastAssistant(Collections.emptyList())); + } + + // ----- beforeLlmCall pre-summarizer gating ----- + + @Test + public void belowThresholdReturnsNull() { + seedSimpleHistory(6); + ctx.addTokenUsage(1000, 0, 1000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + // Nothing was mutated. + assertEquals(6, memory.size()); + } + + @Test + public void aboveThresholdButHistoryTooShortReturnsNull() { + // Threshold crossed (60k cumulative) but history has only 2 user turns → computeSplit + // can't satisfy KEEP_RECENT_TURNS=4 and keeps everything, leaving split.old empty; so + // beforeLlmCall bails out before ever invoking the summarizer. 
+ memory.addUserMessage("u0"); + memory.addAssistantMessage(asstPlain("a0")); + memory.addUserMessage("u1"); + ctx.addTokenUsage(60_000, 0, 60_000); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + assertEquals(3, memory.size()); + } + + @Test + public void triggerUsesLastCallTotalNotCumulative() { + // Two consecutive calls with total_tokens below threshold. The middleware must key on the + // *last* call's total (prompt + completion), not the session cumulative — otherwise a session + // that has accumulated large cumulative cost but then compacted would misfire. Using total + // (not just prompt) also covers the last assistant message — e.g. a tool_call's completion + // tokens — which is part of the next prompt but sits beyond the tail estimator's window. + seedSimpleHistory(6); + CompactionMiddleware mw = new CompactionMiddleware(DUMMY_CLIENT, "m", 50_000L); + + ctx.addTokenUsage(4_000, 1_000, 5_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + ctx.addTokenUsage(8_000, 2_000, 10_000); + assertNull(mw.beforeLlmCall(ctx, memory.buildLlmMessages())); + + assertEquals(10_000L, memory.getLastTotalTokens()); + assertEquals(15_000L, memory.getCumulativeTotalTokens()); + } + + // ----- helpers ----- + + private void seedSimpleHistory(int n) { + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + memory.addUserMessage("u" + i); + } else { + memory.addAssistantMessage(asstPlain("a" + i)); + } + } + } + + private static List alternatingHistory(int n) { + List out = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + if (i % 2 == 0) { + out.add(userMsg("u" + i)); + } else { + out.add(asstMsg(asstPlain("a" + i))); + } + } + return out; + } + + private static ChatCompletionMessageParam userMsg(String text) { + return ChatCompletionMessageParam.ofUser( + ChatCompletionUserMessageParam.builder().content(text).build()); + } + + private static 
ChatCompletionMessageParam asstMsg(ChatCompletionAssistantMessageParam p) { + return ChatCompletionMessageParam.ofAssistant(p); + } + + private static ChatCompletionMessageParam toolMsg(String toolCallId, String content) { + return ChatCompletionMessageParam.ofTool( + ChatCompletionToolMessageParam.builder().toolCallId(toolCallId).content(content).build()); + } + + private static ChatCompletionAssistantMessageParam asstPlain(String text) { + return ChatCompletionAssistantMessageParam.builder().content(text).build(); + } + + private static ChatCompletionAssistantMessageParam asstWithToolCall( + String text, String toolCallId, String toolName, String args) { + List calls = new ArrayList<>(); + calls.add( + ChatCompletionMessageToolCall.ofFunction( + ChatCompletionMessageFunctionToolCall.builder() + .id(toolCallId) + .function( + ChatCompletionMessageFunctionToolCall.Function.builder() + .name(toolName) + .arguments(args) + .build()) + .build())); + return ChatCompletionAssistantMessageParam.builder().content(text).toolCalls(calls).build(); + } + + private static String repeat(char c, int n) { + char[] arr = new char[n]; + Arrays.fill(arr, c); + return new String(arr); + } + + private static boolean containsToolCallId(List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + List calls = m.asAssistant().toolCalls().orElse(null); + if (calls == null) continue; + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction() && id.equals(tc.asFunction().id())) return true; + } + } + } + return false; + } + + private static boolean containsToolCallIdAsResult( + List msgs, String id) { + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool() && id.equals(m.asTool().toolCallId())) return true; + } + return false; + } + + private static void assertNoOrphanToolResult(List msgs) { + Set issued = new HashSet<>(); + for (ChatCompletionMessageParam m : msgs) { + if (m.isAssistant()) { + m.asAssistant() + .toolCalls() + 
.ifPresent( + calls -> { + for (ChatCompletionMessageToolCall tc : calls) { + if (tc.isFunction()) issued.add(tc.asFunction().id()); + } + }); + } + } + for (ChatCompletionMessageParam m : msgs) { + if (m.isTool()) { + assertTrue( + "tool_result id=" + m.asTool().toolCallId() + " has no matching tool_call", + issued.contains(m.asTool().toolCallId())); + } + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java new file mode 100644 index 00000000000..cb107b775f5 --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/runtime/middleware/ToolResultOffloadMiddlewareTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kyuubi.engine.dataagent.runtime.middleware; + +import static org.junit.Assert.*; + +import java.util.Collections; +import org.apache.kyuubi.engine.dataagent.runtime.AgentRunContext; +import org.apache.kyuubi.engine.dataagent.runtime.ApprovalMode; +import org.apache.kyuubi.engine.dataagent.runtime.ConversationMemory; +import org.apache.kyuubi.engine.dataagent.tool.output.GrepToolOutputTool; +import org.apache.kyuubi.engine.dataagent.tool.output.ReadToolOutputTool; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class ToolResultOffloadMiddlewareTest { + + private ToolResultOffloadMiddleware mw; + private AgentRunContext ctxWithSession; + private AgentRunContext ctxNoSession; + + @Before + public void setUp() { + mw = new ToolResultOffloadMiddleware(); + ctxWithSession = + new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, "sess-1"); + ctxNoSession = new AgentRunContext(new ConversationMemory(), ApprovalMode.AUTO_APPROVE, null); + } + + @After + public void tearDown() { + mw.onStop(); + } + + @Test + public void underThresholdPassesThrough() { + String small = "row1\nrow2\nrow3\n"; + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), small); + assertNull(out); + } + + @Test + public void overLineThresholdTriggersOffload() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull(out); + assertTrue(out, out.contains("Tool output truncated")); + assertTrue(out, out.contains("Saved to:")); + assertTrue(out, out.contains(ReadToolOutputTool.NAME)); + assertTrue(out, out.contains(GrepToolOutputTool.NAME)); + assertTrue(out, out.contains("row0")); + assertTrue(out, out.contains("row599")); + } + + @Test + public void overByteThresholdTriggersOffload() { + // 60 lines 
of ~1 KB each = ~60 KB — over the byte threshold but well under the line threshold. + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 60; i++) { + for (int j = 0; j < 1024; j++) sb.append('a'); + sb.append('\n'); + } + String out = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + + assertNotNull("byte threshold should trigger", out); + assertTrue(out, out.contains("Tool output truncated")); + } + + @Test + public void retrievalToolsAreExemptFromGate() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 2000; i++) sb.append("row").append(i).append('\n'); + assertNull( + mw.afterToolCall( + ctxWithSession, ReadToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + assertNull( + mw.afterToolCall( + ctxWithSession, GrepToolOutputTool.NAME, Collections.emptyMap(), sb.toString())); + } + + @Test + public void missingSessionIdPassesThrough() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 1000; i++) sb.append("row").append(i).append('\n'); + String out = + mw.afterToolCall(ctxNoSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertNull("without sessionId, cannot offload safely — pass through", out); + } + + @Test + public void onSessionCloseClearsCounterAndFiles() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + assertEquals(1, mw.trackedSessions()); + + mw.onSessionClose("sess-1"); + assertEquals(0, mw.trackedSessions()); + } + + @Test + public void multipleOffloadsReuseSameSessionDir() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 600; i++) sb.append("row").append(i).append('\n'); + String out1 = + mw.afterToolCall(ctxWithSession, "run_select_query", Collections.emptyMap(), sb.toString()); + String out2 = + mw.afterToolCall(ctxWithSession, "run_select_query", 
Collections.emptyMap(), sb.toString()); + // Both previews reference the same session dir, different file names. + assertNotEquals(extractPath(out1), extractPath(out2)); + assertTrue(extractPath(out1).contains("sess-1")); + assertTrue(extractPath(out2).contains("sess-1")); + } + + private static String extractPath(String preview) { + int i = preview.indexOf("Saved to:"); + int eol = preview.indexOf('\n', i); + return preview.substring(i + "Saved to:".length(), eol).trim(); + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java index c3ceb1a4dc0..7f790e1238e 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolRegistryThreadSafetyTest.java @@ -66,7 +66,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "result_" + idx; } }); @@ -118,7 +118,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "existing_result"; } }); @@ -159,7 +159,7 @@ public Class argsType() { } @Override - public String execute(DummyArgs args) { + public String execute(DummyArgs args, ToolContext ctx) { return "dynamic"; } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java index 8e0a5cd01b0..777017c4438 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java +++ 
b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/ToolTest.java @@ -109,7 +109,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { try { Thread.sleep(60_000); } catch (InterruptedException e) { @@ -147,7 +147,7 @@ public Class argsType() { } @Override - public String execute(ToolRegistryThreadSafetyTest.DummyArgs args) { + public String execute(ToolRegistryThreadSafetyTest.DummyArgs args, ToolContext ctx) { throw new RuntimeException("intentional failure"); } }); diff --git a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java index c46fc2501bc..5384bd937d9 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunMutationQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.apache.kyuubi.engine.dataagent.tool.ToolRiskLevel; import org.junit.After; import org.junit.Before; @@ -60,7 +61,7 @@ public void testRiskLevelDestructive() { public void testInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO t VALUES (9999, 'hello')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("1 row(s) affected")); } @@ -68,21 +69,21 @@ public void testInsert() { public void testUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE t SET v = 'updated' WHERE id = 1"; - 
assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM t WHERE id = 1"; - assertTrue(tool.execute(args).contains("1 row(s) affected")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("1 row(s) affected")); } @Test public void testCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE new_t (id INTEGER PRIMARY KEY, v TEXT)"; - assertTrue(tool.execute(args).contains("executed successfully")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("executed successfully")); } @Test @@ -90,7 +91,7 @@ public void testAlsoAcceptsSelect() { // Mutation tool does not enforce read-only check; SELECT works fine here. SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT v FROM t WHERE id = 1"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); } @@ -98,18 +99,18 @@ public void testAlsoAcceptsSelect() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO nonexistent_table VALUES (1)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Helpers --- diff --git 
a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java index 3c6579cf9a7..d1015ede63d 100644 --- a/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java +++ b/externals/kyuubi-data-agent-engine/src/test/java/org/apache/kyuubi/engine/dataagent/tool/sql/RunSelectQueryToolTest.java @@ -24,6 +24,7 @@ import java.sql.Statement; import java.util.ArrayList; import java.util.List; +import org.apache.kyuubi.engine.dataagent.tool.ToolContext; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -56,7 +57,7 @@ public void tearDown() { public void testRejectsInsert() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "INSERT INTO large_table VALUES (9999, 'x')"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("read-only")); assertTrue(result.contains("run_mutation_query")); @@ -66,28 +67,28 @@ public void testRejectsInsert() { public void testRejectsUpdate() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "UPDATE large_table SET value = 'x' WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsDelete() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "DELETE FROM large_table WHERE id = 1"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsCreateTable() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "CREATE TABLE x (id INT)"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, 
ToolContext.EMPTY).startsWith("Error:")); } @Test public void testAllowsSelect() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 100"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[100 row(s) returned]")); } @@ -96,7 +97,7 @@ public void testAllowsSelect() { public void testAllowsCte() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "WITH cte AS (SELECT id, value FROM large_table LIMIT 5) SELECT * FROM cte"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("row(s)")); } @@ -107,7 +108,7 @@ public void testAllowsCte() { public void testRespectsLimitInSql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table LIMIT 5"; - assertTrue(tool.execute(args).contains("[5 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[5 row(s) returned]")); } @Test @@ -116,7 +117,7 @@ public void testNoClientSideCapWhenLimitOmitted() { // Cap discipline is delegated to the LLM via the system prompt. 
SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table"; - assertTrue(tool.execute(args).contains("[1500 row(s) returned]")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("[1500 row(s) returned]")); } // --- Zero-row result --- @@ -125,7 +126,7 @@ public void testNoClientSideCapWhenLimitOmitted() { public void testZeroRowsResult() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id FROM large_table WHERE id < 0"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertFalse(result.startsWith("Error:")); assertTrue(result.contains("[0 row(s) returned]")); } @@ -136,15 +137,15 @@ public void testZeroRowsResult() { public void testSelectWithLeadingBlockComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "/* get count */ SELECT COUNT(*) FROM large_table"; - assertFalse(tool.execute(args).startsWith("Error:")); + assertFalse(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testRejectsMutationHiddenBehindComment() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "-- looks innocent\nDROP TABLE large_table"; - assertTrue(tool.execute(args).startsWith("Error:")); - assertTrue(tool.execute(args).contains("read-only")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).contains("read-only")); } // --- Edge cases --- @@ -153,25 +154,25 @@ public void testRejectsMutationHiddenBehindComment() { public void testRejectsEmptyAndNullSql() { SqlQueryArgs emptyArgs = new SqlQueryArgs(); emptyArgs.sql = ""; - assertTrue(tool.execute(emptyArgs).startsWith("Error:")); + assertTrue(tool.execute(emptyArgs, ToolContext.EMPTY).startsWith("Error:")); SqlQueryArgs nullArgs = new SqlQueryArgs(); nullArgs.sql = null; - assertTrue(tool.execute(nullArgs).startsWith("Error:")); + assertTrue(tool.execute(nullArgs, ToolContext.EMPTY).startsWith("Error:")); } @Test public void 
testRejectsWhitespaceOnlySql() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = " \t\n "; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test public void testInvalidSqlReturnsError() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM nonexistent_table"; - assertTrue(tool.execute(args).startsWith("Error:")); + assertTrue(tool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } // --- Output formatting --- @@ -188,7 +189,7 @@ public void testNullValuesRenderedAsNULL() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT id, name FROM nullable_test ORDER BY ROWID"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.contains("NULL")); assertTrue(result.contains("Alice")); } @@ -204,7 +205,7 @@ public void testPipeCharacterEscapedInOutput() { } SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT val FROM pipe_test"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue("Pipe should be escaped for markdown table", result.contains("a\\|b\\|c")); } @@ -214,7 +215,7 @@ public void testPipeCharacterEscapedInOutput() { public void testExtractRootCauseFromNestedExceptions() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM this_table_does_not_exist_at_all"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); assertTrue(result.contains("this_table_does_not_exist_at_all")); } @@ -223,7 +224,7 @@ public void testExtractRootCauseFromNestedExceptions() { public void testErrorMessageIsConcise() { SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELEC INVALID SYNTAX HERE !!!"; - String result = tool.execute(args); + String result = tool.execute(args, ToolContext.EMPTY); assertTrue(result.startsWith("Error:")); long newlines = 
result.chars().filter(c -> c == '\n').count(); assertTrue("Error should be concise (<=2 newlines), got " + newlines, newlines <= 2); @@ -236,7 +237,7 @@ public void testCustomQueryTimeout() { RunSelectQueryTool customTool = new RunSelectQueryTool(ds, 5); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT COUNT(*) FROM large_table"; - assertFalse(customTool.execute(args).startsWith("Error:")); + assertFalse(customTool.execute(args, ToolContext.EMPTY).startsWith("Error:")); } @Test @@ -314,7 +315,7 @@ public boolean isWrapperFor(Class iface) { RunSelectQueryTool timeoutTool = new RunSelectQueryTool(slowDs, 1); SqlQueryArgs args = new SqlQueryArgs(); args.sql = "SELECT * FROM large_table"; - String result = timeoutTool.execute(args); + String result = timeoutTool.execute(args, ToolContext.EMPTY); assertTrue("Expected error on timeout", result.startsWith("Error:")); assertTrue("Expected timeout message", result.contains("timed out")); } diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala new file mode 100644 index 00000000000..5fd95bb0fdc --- /dev/null +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentCompactionE2ESuite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.dataagent.operation + +import java.sql.DriverManager + +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + +import org.apache.kyuubi.config.KyuubiConf._ +import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine +import org.apache.kyuubi.operation.HiveJDBCTestHelper + +/** + * End-to-end test that forces CompactionMiddleware to fire inside the engine and verifies + * two things simultaneously: + * 1. The `compaction` SSE event reaches the JDBC client (wiring works). + * 2. The agent still answers correctly *after* compaction, proving the summary preserved + * the facts the follow-up question depends on. + * + * The trigger threshold is set extremely low (500 tokens) so that the schema dump plus the + * first turn's completion already blow past it, forcing compaction before turn 2 or 3. + * + * Requires DATA_AGENT_LLM_API_KEY and DATA_AGENT_LLM_API_URL. 
+ */ +class DataAgentCompactionE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { + + private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") + private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") + private val dbPath = { + val tmp = System.getProperty("java.io.tmpdir") + val uid = java.util.UUID.randomUUID() + s"$tmp/dataagent_compaction_e2e_$uid.db" + } + + override def withKyuubiConf: Map[String, String] = Map( + ENGINE_DATA_AGENT_PROVIDER.key -> "OPENAI_COMPATIBLE", + ENGINE_DATA_AGENT_LLM_API_KEY.key -> apiKey, + ENGINE_DATA_AGENT_LLM_API_URL.key -> apiUrl, + ENGINE_DATA_AGENT_LLM_MODEL.key -> modelName, + ENGINE_DATA_AGENT_MAX_ITERATIONS.key -> "10", + ENGINE_DATA_AGENT_APPROVAL_MODE.key -> "AUTO_APPROVE", + // Force compaction to fire aggressively -- any realistic prompt + one LLM round-trip + // will exceed 500 tokens, so compaction must trigger by turn 2 or 3. + ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS.key -> "500", + ENGINE_DATA_AGENT_JDBC_URL.key -> s"jdbc:sqlite:$dbPath") + + override protected def jdbcUrl: String = jdbcConnectionUrl + + private val enabled: Boolean = apiKey.nonEmpty && apiUrl.nonEmpty + + override def beforeAll(): Unit = { + if (enabled) { + setupTestDatabase() + super.beforeAll() + } + } + + override def afterAll(): Unit = { + if (enabled) { + super.afterAll() + new java.io.File(dbPath).delete() + } + } + + private def setupTestDatabase(): Unit = { + new java.io.File(dbPath).delete() + val conn = DriverManager.getConnection(s"jdbc:sqlite:$dbPath") + try { + val stmt = conn.createStatement() + stmt.execute( + """ + |CREATE TABLE employees ( + | id INTEGER PRIMARY KEY, + | name TEXT NOT NULL, + | department TEXT NOT NULL, + | salary REAL NOT NULL + |)""".stripMargin) + // 6 employees across 3 departments; Frank is unambiguously the top earner. 
+ stmt.execute("INSERT INTO employees VALUES (1, 'Alice', 'Engineering', 25000)") + stmt.execute("INSERT INTO employees VALUES (2, 'Bob', 'Engineering', 30000)") + stmt.execute("INSERT INTO employees VALUES (3, 'Charlie', 'Sales', 20000)") + stmt.execute("INSERT INTO employees VALUES (4, 'Diana', 'Sales', 22000)") + stmt.execute("INSERT INTO employees VALUES (5, 'Eve', 'Marketing', 18000)") + stmt.execute("INSERT INTO employees VALUES (6, 'Frank', 'Engineering', 35000)") + } finally { + conn.close() + } + } + + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + sb.toString() + } + + private def parseEvents(stream: String): Seq[JsonNode] = { + val parser = mapper.getFactory.createParser(stream) + try mapper.readValues(parser, classOf[JsonNode]).asScala.toList + finally parser.close() + } + + private def extractAnswer(events: Seq[JsonNode]): String = { + val sb = new StringBuilder + events.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + + test("E2E: compaction fires mid-conversation and preserves facts across turns") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping") + + // CompactionMiddleware.KEEP_RECENT_TURNS is hardcoded to 4. computeSplit needs a + // non-empty 'old' slice, so at least 5 distinct user turns must accumulate before + // compaction can fire -- turns 1..(N-4) become the old slice, 4 most recent are + // kept verbatim. Turn 5 is the observable trigger point. + // + // Turn 5 is phrased to force a fresh SQL query rather than relying on recall, + // because summary quality varies across LLMs (some drop the top-earner fact). 
+ // Correctness of the final answer then validates that the post-compaction history + // still gives the agent enough context to pick the right tool and query -- which is + // the compaction contract we actually care about: mechanism fires, agent recovers. + withJdbcStatement() { stmt => + Seq( + "List every department that appears in the employees table.", + "How many employees work in Engineering?", + "What salaries do Sales employees earn?", + "Who works in Marketing?").zipWithIndex.foreach { case (q, i) => + val events = parseEvents(drainReply(stmt.executeQuery(q))) + info(s"Turn ${i + 1} answer: ${extractAnswer(events)}") + } + + // Turn 5 -- explicitly instruct the agent to re-query so the answer does not + // depend on summary fidelity, only on the agent still functioning after + // compaction rewrote history. + val events5 = parseEvents(drainReply(stmt.executeQuery( + "Run a SELECT against the employees table to find the single employee with" + + " the highest salary. Report ONLY that employee's name." + + s" $strictFormatHint"))) + val answer5 = extractAnswer(events5) + info(s"Turn 5 answer: $answer5") + + val compactionEvents = + events5.filter(_.path("type").asText() == "compaction") + assert( + compactionEvents.nonEmpty, + "Expected at least one compaction event in turn 5 (trigger=500 tokens, 5 turns)") + + // Sanity-check event shape -- field names must match ExecuteStatement's SSE encoder. + val c = compactionEvents.head + assert( + c.has("summarized") && c.get("summarized").asInt() > 0, + s"compaction event should carry a positive summarized count: $c") + assert(c.has("kept") && c.get("kept").asInt() >= 0, s"compaction event missing kept: $c") + assert( + c.has("triggerTokens") && c.get("triggerTokens").asLong() == 500L, + s"compaction event should echo configured trigger: $c") + + // Turn 5 was told to SELECT fresh; Frank (35000) is unambiguously the top earner. 
+ // If we don't get "Frank", either the agent failed to re-query after compaction + // (real bug in post-compaction history) or the tool layer is broken. + assert( + answer5.contains("Frank"), + s"Turn 5 should identify Frank as the top earner after re-querying; the agent" + + s" must remain functional post-compaction. Got: $answer5") + } + } +} diff --git a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala index 977c655dabd..cfb83ecc9b2 100644 --- a/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala +++ b/externals/kyuubi-data-agent-engine/src/test/scala/org/apache/kyuubi/engine/dataagent/operation/DataAgentE2ESuite.scala @@ -19,6 +19,10 @@ package org.apache.kyuubi.engine.dataagent.operation import java.sql.DriverManager +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} + import org.apache.kyuubi.config.KyuubiConf._ import org.apache.kyuubi.engine.dataagent.WithDataAgentEngine import org.apache.kyuubi.operation.HiveJDBCTestHelper @@ -34,7 +38,7 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { private val apiKey = sys.env.getOrElse("DATA_AGENT_LLM_API_KEY", "") private val apiUrl = sys.env.getOrElse("DATA_AGENT_LLM_API_URL", "") - private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "gpt-4o") + private val modelName = sys.env.getOrElse("DATA_AGENT_LLM_MODEL", "") private val dbPath = s"${System.getProperty("java.io.tmpdir")}/dataagent_e2e_test_${java.util.UUID.randomUUID()}.db" @@ -107,33 +111,71 @@ class DataAgentE2ESuite extends HiveJDBCTestHelper with WithDataAgentEngine { new java.io.File(dbPath).delete() } + private val mapper = new ObjectMapper() + + private def drainReply(rs: java.sql.ResultSet): 
String = { + val sb = new StringBuilder + while (rs.next()) { + sb.append(rs.getString("reply")) + } + val stream = sb.toString() + info(s"Agent event stream: $stream") + stream + } + + /** + * The JDBC `reply` column is a concatenated stream of SSE events + * (`agent_start`, `tool_call`, `tool_result`, `content_delta`, ...). Only + * `content_delta.text` is actual model output - this pulls those out and + * joins them to recover the final natural-language answer. + */ + private def extractAnswer(eventStream: String): String = { + val parser = mapper.getFactory.createParser(eventStream) + val sb = new StringBuilder + try { + mapper.readValues(parser, classOf[JsonNode]).asScala.foreach { node => + if ("content_delta" == node.path("type").asText()) { + sb.append(node.path("text").asText("")) + } + } + } finally { + parser.close() + } + sb.toString() + } + + private val strictFormatHint = + "Respond with ONLY the answer, no explanation, no markdown, no punctuation." + test("E2E: agent answers data question through full Kyuubi pipeline") { assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") - // scalastyle:off println withJdbcStatement() { stmt => - // Ask a question that requires schema exploration + SQL execution - val result = stmt.executeQuery( - "Which department has the highest average salary?") - - val sb = new StringBuilder - while (result.next()) { - val chunk = result.getString("reply") - sb.append(chunk) - print(chunk) // real-time output for debugging - } - println() - - val reply = sb.toString() - - // The agent should have: - // 1. Explored the schema (mentioned table names or columns) - // 2. Executed SQL (the reply should contain actual data) - // 3. 
Answered with "Engineering" (avg salary 30000) - assert(reply.nonEmpty, "Agent should return a non-empty response") - assert( - reply.toLowerCase.contains("engineering") || reply.contains("30000"), - s"Expected the answer to mention 'Engineering' or '30000', got: ${reply.take(500)}") + val stream = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream) == "Engineering") + } + } + + test("E2E: agent resolves follow-up question using prior conversation context") { + assume(enabled, "DATA_AGENT_LLM_API_KEY/API_URL not set, skipping E2E tests") + // Two executeQuery calls on the same Statement share the JDBC session, which means + // the provider reuses the same ConversationMemory across turns. Turn 2 uses the + // demonstrative "that department" - it can only be answered correctly if Turn 1's + // answer (Engineering) is carried over in the agent's conversation history. + withJdbcStatement() { stmt => + val stream1 = drainReply( + stmt.executeQuery( + s"Which department has the highest average salary? $strictFormatHint")) + assert(extractAnswer(stream1) == "Engineering") + + // Engineering has 3 employees (Alice, Bob, Frank). If memory is not shared + // the agent cannot resolve "that department" and cannot produce the exact + // integer 3 - nothing in Turn 2's prompt points to Engineering. + val stream2 = drainReply( + stmt.executeQuery( + s"How many employees work in that department? 
$strictFormatHint")) + assert(extractAnswer(stream2) == "3") } - // scalastyle:on println } } diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala index 215a76e26d4..31d308e9c67 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -3848,6 +3848,21 @@ object KyuubiConf { .checkValue(_ > 0, "must be positive number") .createWithDefault(100) + val ENGINE_DATA_AGENT_COMPACTION_TRIGGER_TOKENS: ConfigEntry[Long] = + buildConf("kyuubi.engine.data.agent.compaction.trigger.tokens") + .doc("The prompt-token threshold above which the Data Agent's compaction middleware " + + "summarizes older conversation history into a compact message. The check is made each " + + "turn as " + + "real_prompt_tokens_of_previous_LLM_call + estimate_of_newly_appended_tail; " + + "when this predicted prompt size reaches the configured value, older messages are " + + "replaced by a single summary message while the most recent exchanges are kept verbatim. " + + "Set to a very large value (e.g., 9223372036854775807) to effectively " + + "disable compaction.") + .version("1.12.0") + .longConf + .checkValue(_ > 0, "must be positive number") + .createWithDefault(128000L) + val ENGINE_DATA_AGENT_QUERY_TIMEOUT: ConfigEntry[Long] = buildConf("kyuubi.engine.data.agent.query.timeout") .doc("The JDBC query execution timeout for the Data Agent SQL tools. " +