diff --git a/Docs/README_MCP.md b/Docs/README_MCP.md index 90436bbd..1edb4956 100644 --- a/Docs/README_MCP.md +++ b/Docs/README_MCP.md @@ -131,10 +131,28 @@ TelegramSearchBot 内置了以下工具,通过 `BuiltInToolAttribute` 标记 ```json { + "EnableLLMAgentProcess": false, + "AgentHeartbeatIntervalSeconds": 10, + "AgentHeartbeatTimeoutSeconds": 60, + "AgentChunkPollingIntervalMilliseconds": 200, + "AgentIdleTimeoutMinutes": 15, + "MaxConcurrentAgents": 8, + "AgentTaskTimeoutSeconds": 300, + "AgentShutdownGracePeriodSeconds": 15, + "AgentMaxRecoveryAttempts": 2, + "AgentQueueBacklogWarningThreshold": 20, + "AgentProcessMemoryLimitMb": 256, "MaxToolCycles": 25 } ``` +- `EnableLLMAgentProcess=true` 时,LLM 对话循环会迁移到独立 Agent 进程,主进程仅负责 Telegram 收发、任务队列和流式转发。 +- `AgentHeartbeatIntervalSeconds` / `AgentHeartbeatTimeoutSeconds` 控制主进程对 Agent 存活状态的检测。 +- `AgentChunkPollingIntervalMilliseconds` 控制主进程从 Garnet 轮询流式输出块的频率。 +- `AgentIdleTimeoutMinutes` / `AgentShutdownGracePeriodSeconds` 控制 Agent 的空闲回收和优雅停机窗口。 +- `MaxConcurrentAgents` / `AgentProcessMemoryLimitMb` 用于约束 Agent 并发数量和内存占用。 +- `AgentTaskTimeoutSeconds` / `AgentMaxRecoveryAttempts` 控制任务超时后的重试和死信恢复策略。 + ## 五、安全考虑 ### 5.1 管理员专用工具 diff --git a/README.md b/README.md index b301613b..f96672ea 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,17 @@ "EnableVideoASR": false, "EnableOpenAI": false, "OpenAIModelName": "gpt-4o", + "EnableLLMAgentProcess": false, + "AgentHeartbeatIntervalSeconds": 10, + "AgentHeartbeatTimeoutSeconds": 60, + "AgentChunkPollingIntervalMilliseconds": 200, + "AgentIdleTimeoutMinutes": 15, + "MaxConcurrentAgents": 8, + "AgentTaskTimeoutSeconds": 300, + "AgentShutdownGracePeriodSeconds": 15, + "AgentMaxRecoveryAttempts": 2, + "AgentQueueBacklogWarningThreshold": 20, + "AgentProcessMemoryLimitMb": 256, "MaxToolCycles": 25, "OLTPAuth": "", "OLTPAuthUrl": "", @@ -79,8 +90,21 @@ - `OllamaModelName`: 本地模型名称(默认"qwen2.5:72b-instruct-q2_K") - `EnableOpenAI`: 是否启用OpenAI(默认false) - `OpenAIModelName`: OpenAI模型名称(默认"gpt-4o") + - `EnableLLMAgentProcess`: 是否启用独立 LLM Agent 进程模式(默认false) + - `AgentHeartbeatIntervalSeconds`: Agent 心跳上报间隔(默认10秒) + - `AgentHeartbeatTimeoutSeconds`: 主进程判定 Agent 失活的超时时间(默认60秒) + - `AgentChunkPollingIntervalMilliseconds`: 主进程轮询流式输出块的间隔(默认200毫秒) + - `AgentIdleTimeoutMinutes`: Agent 空闲超时时间(默认15分钟) + - `MaxConcurrentAgents`: 同时允许的 Agent 进程数上限(默认8) + - `AgentTaskTimeoutSeconds`: 单个 Agent 任务无进展时的超时时间(默认300秒) + - `AgentShutdownGracePeriodSeconds`: Agent 收到停机请求后的优雅退出等待时间(默认15秒) + - `AgentMaxRecoveryAttempts`: Agent 崩溃或超时后的最大恢复重试次数(默认2) + - `AgentQueueBacklogWarningThreshold`: Agent 任务队列告警阈值(默认20) + - `AgentProcessMemoryLimitMb`: Agent 进程工作集上限(默认256MB) - `MaxToolCycles`: LLM工具调用最大迭代次数(默认25),防止无限循环 +启用 `EnableLLMAgentProcess=true` 后,主进程会负责任务排队、Telegram 发消息和流式转发;独立 Agent 进程负责执行 LLM 循环、本地工具和故障恢复。主进程会在 Agent 心跳超时、任务超时或配置切换时执行恢复、重试、死信投递和优雅停机。 + - **日志推送**: - `OLTPAuth`: OLTP日志推送认证密钥 - `OLTPAuthUrl`: OLTP日志推送URL diff --git a/TelegramSearchBot.Common/Env.cs b/TelegramSearchBot.Common/Env.cs index 158f830f..3fcee22b 100644 --- a/TelegramSearchBot.Common/Env.cs +++ b/TelegramSearchBot.Common/Env.cs @@ -42,6 +42,17 @@ static Env() { BraveApiKey = config.BraveApiKey; EnableAccounting = config.EnableAccounting; MaxToolCycles = config.MaxToolCycles; + EnableLLMAgentProcess = config.EnableLLMAgentProcess; + AgentHeartbeatIntervalSeconds = config.AgentHeartbeatIntervalSeconds; + AgentHeartbeatTimeoutSeconds = config.AgentHeartbeatTimeoutSeconds; + AgentChunkPollingIntervalMilliseconds = config.AgentChunkPollingIntervalMilliseconds; + AgentIdleTimeoutMinutes = config.AgentIdleTimeoutMinutes; + MaxConcurrentAgents = config.MaxConcurrentAgents; + AgentTaskTimeoutSeconds = config.AgentTaskTimeoutSeconds; + AgentShutdownGracePeriodSeconds = config.AgentShutdownGracePeriodSeconds; + AgentMaxRecoveryAttempts = config.AgentMaxRecoveryAttempts; + AgentQueueBacklogWarningThreshold = config.AgentQueueBacklogWarningThreshold; + AgentProcessMemoryLimitMb = config.AgentProcessMemoryLimitMb; } catch { } @@ -73,6 +84,17 @@ static Env() { public static string BraveApiKey { get; set; } = null!; public static bool EnableAccounting { get; set; } = false; public static int MaxToolCycles { get; set; } + public static bool EnableLLMAgentProcess { get; set; } = false; + public static int AgentHeartbeatIntervalSeconds { get; set; } = 10; + public static int AgentHeartbeatTimeoutSeconds { get; set; } = 60; + public static int AgentChunkPollingIntervalMilliseconds { get; set; } = 200; + public static int AgentIdleTimeoutMinutes { get; set; } = 15; + public static int MaxConcurrentAgents { get; set; } = 8; + public static int AgentTaskTimeoutSeconds { get; set; } = 300; + public static int AgentShutdownGracePeriodSeconds { get; set; } = 15; + public static int AgentMaxRecoveryAttempts { get; set; } = 2; + public static int AgentQueueBacklogWarningThreshold { get; set; } = 20; + public static int AgentProcessMemoryLimitMb { get; set; } = 256; public static Dictionary Configuration { get; set; } = new Dictionary(); } @@ -100,5 +122,16 @@ public class Config { public string BraveApiKey { get; set; } = null!; public bool EnableAccounting { get; set; } = false; public int MaxToolCycles { get; set; } = 25; + public bool EnableLLMAgentProcess { get; set; } = false; + public int AgentHeartbeatIntervalSeconds { get; set; } = 10; + public int AgentHeartbeatTimeoutSeconds { get; set; } = 60; + public int AgentChunkPollingIntervalMilliseconds { get; set; } = 200; + public int AgentIdleTimeoutMinutes { get; set; } = 15; + public int MaxConcurrentAgents { get; set; } = 8; + public int AgentTaskTimeoutSeconds { get; set; } = 300; + public int AgentShutdownGracePeriodSeconds { get; set; } = 15; + public int AgentMaxRecoveryAttempts { get; set; } = 2; + public int AgentQueueBacklogWarningThreshold { get; set; } = 20; + public int AgentProcessMemoryLimitMb { get; set; } = 256; } } diff --git a/TelegramSearchBot.Database/Model/AI/LLMProvider.cs b/TelegramSearchBot.Common/Model/AI/LLMProvider.cs similarity index 63% rename from TelegramSearchBot.Database/Model/AI/LLMProvider.cs rename to TelegramSearchBot.Common/Model/AI/LLMProvider.cs index 7c278a7f..a653f278 100644 --- a/TelegramSearchBot.Database/Model/AI/LLMProvider.cs +++ b/TelegramSearchBot.Common/Model/AI/LLMProvider.cs @@ -1,9 +1,3 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - namespace TelegramSearchBot.Model.AI { public enum LLMProvider { None, diff --git a/TelegramSearchBot.Common/Model/AI/LlmAgentContracts.cs b/TelegramSearchBot.Common/Model/AI/LlmAgentContracts.cs new file mode 100644 index 00000000..9b048bc2 --- /dev/null +++ b/TelegramSearchBot.Common/Model/AI/LlmAgentContracts.cs @@ -0,0 +1,199 @@ +using System; +using System.Collections.Generic; + +namespace TelegramSearchBot.Model.AI { + public enum AgentTaskKind { + Message = 0, + Continuation = 1 + } + + public enum AgentTaskStatus { + Pending = 0, + Running = 1, + Completed = 2, + Failed = 3, + Recovering = 4, + Cancelled = 5 + } + + public enum AgentChunkType { + Snapshot = 0, + Done = 1, + Error = 2, + IterationLimitReached = 3 + } + + public sealed class AgentUserSnapshot { + public long UserId { get; set; } + public string FirstName { get; set; } = string.Empty; + public string LastName { get; set; } = string.Empty; + public string UserName { get; set; } = string.Empty; + public bool? IsPremium { get; set; } + public bool? IsBot { get; set; } + } + + public sealed class AgentMessageExtensionSnapshot { + public string Name { get; set; } = string.Empty; + public string Value { get; set; } = string.Empty; + } + + public sealed class AgentHistoryMessage { + public long DataId { get; set; } + public DateTime DateTime { get; set; } + public long GroupId { get; set; } + public long MessageId { get; set; } + public long FromUserId { get; set; } + public long ReplyToUserId { get; set; } + public long ReplyToMessageId { get; set; } + public string Content { get; set; } = string.Empty; + public AgentUserSnapshot User { get; set; } = new AgentUserSnapshot(); + public List Extensions { get; set; } = []; + } + + public sealed class AgentModelCapability { + public string Name { get; set; } = string.Empty; + public string Value { get; set; } = string.Empty; + public string Description { get; set; } = string.Empty; + } + + public sealed class AgentChannelConfig { + public int ChannelId { get; set; } + public string Name { get; set; } = string.Empty; + public string Gateway { get; set; } = string.Empty; + public string ApiKey { get; set; } = string.Empty; + public LLMProvider Provider { get; set; } + public int Parallel { get; set; } + public int Priority { get; set; } + public string ModelName { get; set; } = string.Empty; + public List Capabilities { get; set; } = []; + } + + public sealed class AgentExecutionTask { + public string TaskId { get; set; } = Guid.NewGuid().ToString("N"); + public AgentTaskKind Kind { get; set; } = AgentTaskKind.Message; + public long ChatId { get; set; } + public long UserId { get; set; } + public long MessageId { get; set; } + public long BotUserId { get; set; } + public string BotName { get; set; } = string.Empty; + public string InputMessage { get; set; } = string.Empty; + public string ModelName { get; set; } = string.Empty; + public int MaxToolCycles { get; set; } + public AgentChannelConfig Channel { get; set; } = new AgentChannelConfig(); + public List History { get; set; } = []; + public LlmContinuationSnapshot? ContinuationSnapshot { get; set; } + public DateTime CreatedAtUtc { get; set; } = DateTime.UtcNow; + public int RecoveryAttempt { get; set; } + } + + public sealed class AgentStreamChunk { + public string TaskId { get; set; } = string.Empty; + public AgentChunkType Type { get; set; } = AgentChunkType.Snapshot; + public int Sequence { get; set; } + public string Content { get; set; } = string.Empty; + public string ErrorMessage { get; set; } = string.Empty; + public LlmContinuationSnapshot? ContinuationSnapshot { get; set; } + public DateTime CreatedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class TelegramAgentToolTask { + public string RequestId { get; set; } = Guid.NewGuid().ToString("N"); + public string ToolName { get; set; } = string.Empty; + public Dictionary Arguments { get; set; } = new Dictionary(StringComparer.OrdinalIgnoreCase); + public long ChatId { get; set; } + public long UserId { get; set; } + public long MessageId { get; set; } + public DateTime CreatedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class TelegramAgentToolResult { + public string RequestId { get; set; } = string.Empty; + public bool Success { get; set; } + public string Result { get; set; } = string.Empty; + public string ErrorMessage { get; set; } = string.Empty; + public long TelegramMessageId { get; set; } + public DateTime CompletedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class AgentSessionInfo { + public long ChatId { get; set; } + public int ProcessId { get; set; } + public int Port { get; set; } + public string Status { get; set; } = "starting"; + public string CurrentTaskId { get; set; } = string.Empty; + public DateTime StartedAtUtc { get; set; } = DateTime.UtcNow; + public DateTime LastHeartbeatUtc { get; set; } = DateTime.UtcNow; + public DateTime LastActiveAtUtc { get; set; } = DateTime.UtcNow; + public DateTime ShutdownRequestedAtUtc { get; set; } = DateTime.MinValue; + public string ErrorMessage { get; set; } = string.Empty; + } + + public sealed class AgentControlCommand { + public long ChatId { get; set; } + public string Action { get; set; } = string.Empty; + public string Reason { get; set; } = string.Empty; + public DateTime RequestedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class AgentDeadLetterEntry { + public string TaskId { get; set; } = string.Empty; + public long ChatId { get; set; } + public string Reason { get; set; } = string.Empty; + public int RecoveryAttempt { get; set; } + public string Payload { get; set; } = string.Empty; + public string LastContent { get; set; } = string.Empty; + public DateTime FailedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class SubAgentTaskEnvelope { + public string RequestId { get; set; } = Guid.NewGuid().ToString("N"); + public string Type { get; set; } = "echo"; + public string Payload { get; set; } = string.Empty; + public SubAgentMcpExecuteRequest? McpExecute { get; set; } + public SubAgentBackgroundTaskRequest? BackgroundTask { get; set; } + public DateTime CreatedAtUtc { get; set; } = DateTime.UtcNow; + } + + public sealed class SubAgentMcpExecuteRequest { + public string ServerName { get; set; } = "subagent"; + public string Command { get; set; } = string.Empty; + public List Args { get; set; } = []; + public Dictionary Env { get; set; } = new(StringComparer.OrdinalIgnoreCase); + public int TimeoutSeconds { get; set; } = 30; + public string ToolName { get; set; } = string.Empty; + public Dictionary Arguments { get; set; } = new(StringComparer.OrdinalIgnoreCase); + } + + public sealed class SubAgentBackgroundTaskRequest { + public string Command { get; set; } = string.Empty; + public List Args { get; set; } = []; + public Dictionary Env { get; set; } = new(StringComparer.OrdinalIgnoreCase); + public string WorkingDirectory { get; set; } = string.Empty; + public int TimeoutSeconds { get; set; } = 30; + } + + public sealed class SubAgentTaskResult { + public string RequestId { get; set; } = string.Empty; + public bool Success { get; set; } + public string Result { get; set; } = string.Empty; + public string ErrorMessage { get; set; } = string.Empty; + public int ExitCode { get; set; } + public DateTime CompletedAtUtc { get; set; } = DateTime.UtcNow; + } + + public static class LlmAgentRedisKeys { + public const string AgentTaskQueue = "AGENT_TASKS"; + public const string AgentTaskDeadLetterQueue = "AGENT_TASKS:DEAD"; + public const string TelegramTaskQueue = "TELEGRAM_TASKS"; + public const string ActiveTaskSet = "AGENT_ACTIVE_TASKS"; + public const string SubAgentTaskQueue = "SUBAGENT_TASKS"; + + public static string AgentTaskState(string taskId) => $"AGENT_TASK:{taskId}"; + public static string AgentChunks(string taskId) => $"AGENT_CHUNKS:{taskId}"; + public static string AgentChunkIndex(string taskId) => $"AGENT_CHUNK_INDEX:{taskId}"; + public static string AgentSession(long chatId) => $"AGENT_SESSION:{chatId}"; + public static string AgentControl(long chatId) => $"AGENT_CONTROL:{chatId}"; + public static string TelegramResult(string requestId) => $"TELEGRAM_RESULT:{requestId}"; + public static string SubAgentResult(string requestId) => $"SUBAGENT_RESULT:{requestId}"; + } +} diff --git a/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetClientTests.cs b/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetClientTests.cs new file mode 100644 index 00000000..41e42cf5 --- /dev/null +++ b/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetClientTests.cs @@ -0,0 +1,44 @@ +using System.Threading.Tasks; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.LLMAgent.Service; +using TelegramSearchBot.Model.AI; +using Xunit; + +namespace TelegramSearchBot.LLM.Test.Service.AI.LLM { + public class GarnetClientTests { + private readonly Mock _redisMock = new(); + private readonly Mock _dbMock = new(); + + public GarnetClientTests() { + _redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(_dbMock.Object); + } + + [Fact] + public async Task PublishChunkAsync_WritesSerializedChunkToRedisList() { + RedisKey capturedKey = default; + RedisValue capturedValue = default; + + _dbMock.Setup(d => d.ListRightPushAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((key, value, _, _) => { + capturedKey = key; + capturedValue = value; + }) + .ReturnsAsync(1); + + var client = new GarnetClient(_redisMock.Object); + await client.PublishChunkAsync(new AgentStreamChunk { + TaskId = "task-1", + Type = AgentChunkType.Snapshot, + Content = "hello" + }); + + Assert.Equal(LlmAgentRedisKeys.AgentChunks("task-1"), capturedKey.ToString()); + Assert.Contains("\"Content\":\"hello\"", capturedValue.ToString()); + } + } +} diff --git a/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetRpcClientTests.cs b/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetRpcClientTests.cs new file mode 100644 index 00000000..a67d1661 --- /dev/null +++ b/TelegramSearchBot.LLM.Test/Service/AI/LLM/GarnetRpcClientTests.cs @@ -0,0 +1,62 @@ +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.LLMAgent.Service; +using TelegramSearchBot.Model.AI; +using Xunit; + +namespace TelegramSearchBot.LLM.Test.Service.AI.LLM { + public class GarnetRpcClientTests { + private readonly Mock _redisMock = new(); + private readonly Mock _dbMock = new(); + + public GarnetRpcClientTests() { + _redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(_dbMock.Object); + } + + [Fact] + public async Task SaveTaskStateAsync_WritesStatusErrorAndExtraFields() { + var writes = new Dictionary(); + _dbMock.Setup(d => d.HashSetAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, field, value, _, _) => { + writes[field.ToString()] = value.ToString(); + }) + .ReturnsAsync(true); + + var client = new GarnetRpcClient(_redisMock.Object); + await client.SaveTaskStateAsync("task-1", AgentTaskStatus.Running, null, new Dictionary { + ["payload"] = "json", + ["lastContent"] = "hello" + }); + + Assert.Equal(AgentTaskStatus.Running.ToString(), writes["status"]); + Assert.Equal(string.Empty, writes["error"]); + Assert.Equal("json", writes["payload"]); + Assert.Equal("hello", writes["lastContent"]); + Assert.True(writes.ContainsKey("updatedAtUtc")); + } + + [Fact] + public async Task WaitForTelegramResultAsync_ReturnsDeserializedResult() { + var json = "{\"RequestId\":\"req-1\",\"Success\":true,\"Result\":\"42\"}"; + _dbMock.Setup(d => d.StringGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new RedisValue(json)); + _dbMock.Setup(d => d.KeyDeleteAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(true); + + var client = new GarnetRpcClient(_redisMock.Object); + var result = await client.WaitForTelegramResultAsync("req-1", TimeSpan.FromSeconds(1), CancellationToken.None); + + Assert.NotNull(result); + Assert.True(result.Success); + Assert.Equal("42", result.Result); + } + } +} diff --git a/TelegramSearchBot.LLM.Test/Service/AI/LLM/ToolExecutorTests.cs b/TelegramSearchBot.LLM.Test/Service/AI/LLM/ToolExecutorTests.cs new file mode 100644 index 00000000..e2729252 --- /dev/null +++ b/TelegramSearchBot.LLM.Test/Service/AI/LLM/ToolExecutorTests.cs @@ -0,0 +1,61 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.LLMAgent.Service; +using TelegramSearchBot.Model.AI; +using Xunit; + +namespace TelegramSearchBot.LLM.Test.Service.AI.LLM { + public class ToolExecutorTests { + [Fact] + public async Task EchoAsync_ReturnsInputText() { + var executor = new ToolExecutor(null!, null!); + + var result = await executor.EchoAsync("hello"); + + Assert.Equal("hello", result); + } + + [Fact] + public async Task CalculateAsync_EvaluatesExpression() { + var executor = new ToolExecutor(null!, null!); + + var result = await executor.CalculateAsync("1 + 2 * 3"); + + Assert.Equal("7", result); + } + + [Fact] + public async Task SendMessageAsync_QueuesTelegramTaskAndReturnsResult() { + var redisMock = new Mock(); + var dbMock = new Mock(); + redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(dbMock.Object); + + RedisKey pushedKey = default; + RedisValue pushedValue = default; + dbMock.Setup(d => d.ListRightPushAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((key, value, _, _) => { + pushedKey = key; + pushedValue = value; + }) + .ReturnsAsync(1); + dbMock.Setup(d => d.StringGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new RedisValue("{\"RequestId\":\"req\",\"Success\":true,\"Result\":\"123\"}")); + dbMock.Setup(d => d.KeyDeleteAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(true); + + var executor = new ToolExecutor(new GarnetClient(redisMock.Object), new GarnetRpcClient(redisMock.Object)); + var result = await executor.SendMessageAsync(100, "hello", 1, 2, CancellationToken.None); + + Assert.Equal(LlmAgentRedisKeys.TelegramTaskQueue, pushedKey.ToString()); + Assert.Contains("\"ToolName\":\"send_message\"", pushedValue.ToString()); + Assert.Equal("123", result); + } + } +} diff --git a/TelegramSearchBot.LLM.Test/TelegramSearchBot.LLM.Test.csproj b/TelegramSearchBot.LLM.Test/TelegramSearchBot.LLM.Test.csproj index bcd323e1..a8fa6596 100644 --- a/TelegramSearchBot.LLM.Test/TelegramSearchBot.LLM.Test.csproj +++ b/TelegramSearchBot.LLM.Test/TelegramSearchBot.LLM.Test.csproj @@ -14,10 +14,10 @@ - - + + - + @@ -30,6 +30,7 @@ + - \ No newline at end of file + diff --git a/TelegramSearchBot.LLMAgent/LLMAgentProgram.cs b/TelegramSearchBot.LLMAgent/LLMAgentProgram.cs new file mode 100644 index 00000000..3b9e7c8d --- /dev/null +++ b/TelegramSearchBot.LLMAgent/LLMAgentProgram.cs @@ -0,0 +1,74 @@ +using System.Reflection; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Interface; +using TelegramSearchBot.Interface.AI.LLM; +using TelegramSearchBot.Model; +using TelegramSearchBot.Service.AI.LLM; + +namespace TelegramSearchBot.LLMAgent { + public static class LLMAgentProgram { + public static async Task RunAsync(string[] args) { + var effectiveArgs = NormalizeArgs(args); + if (effectiveArgs.Length != 2 || + !long.TryParse(effectiveArgs[0], out var chatId) || + !int.TryParse(effectiveArgs[1], out var port)) { + Console.Error.WriteLine("Usage: LLMAgent "); + Environment.ExitCode = 1; + return; + } + + using var services = BuildServices(port); + var logger = services.GetRequiredService().CreateLogger("LLMAgent"); + McpToolHelper.EnsureInitialized(typeof(Service.AgentToolService).Assembly, services, logger); + + var loop = services.GetRequiredService(); + using var shutdownCts = new CancellationTokenSource(); + Console.CancelKeyPress += (_, eventArgs) => { + eventArgs.Cancel = true; + shutdownCts.Cancel(); + }; + AppDomain.CurrentDomain.ProcessExit += (_, _) => shutdownCts.Cancel(); + + await loop.RunAsync(chatId, port, shutdownCts.Token); + } + + private static string[] NormalizeArgs(string[] args) { + if (args.Length > 0 && args[0].Equals("LLMAgent", StringComparison.OrdinalIgnoreCase)) { + return args.Skip(1).ToArray(); + } + + return args; + } + + private static ServiceProvider BuildServices(int port) { + var services = new ServiceCollection(); + services.AddLogging(builder => builder.AddSimpleConsole(options => { + options.SingleLine = true; + options.TimestampFormat = "[yyyy-MM-dd HH:mm:ss] "; + })); + services.AddHttpClient(); + services.AddHttpClient("OllamaClient"); + services.AddDbContext(options => { + options.UseInMemoryDatabase($"llm-agent-{port}"); + }, contextLifetime: ServiceLifetime.Scoped, optionsLifetime: ServiceLifetime.Singleton); + services.AddSingleton(_ => ConnectionMultiplexer.Connect($"localhost:{port},abortConnect=false,connectTimeout=5000,connectRetry=5")); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddScoped(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + return services.BuildServiceProvider(); + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Program.cs b/TelegramSearchBot.LLMAgent/Program.cs new file mode 100644 index 00000000..fcdd09fc --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Program.cs @@ -0,0 +1,7 @@ +namespace TelegramSearchBot.LLMAgent { + public static class Program { + public static Task Main(string[] args) { + return LLMAgentProgram.RunAsync(args); + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/AgentLoopService.cs b/TelegramSearchBot.LLMAgent/Service/AgentLoopService.cs new file mode 100644 index 00000000..591c8dc8 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/AgentLoopService.cs @@ -0,0 +1,227 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class AgentLoopService { + private readonly IServiceProvider _serviceProvider; + private readonly GarnetClient _garnetClient; + private readonly GarnetRpcClient _rpcClient; + private readonly ILogger _logger; + + public AgentLoopService( + IServiceProvider serviceProvider, + GarnetClient garnetClient, + GarnetRpcClient rpcClient, + ILogger logger) { + _serviceProvider = serviceProvider; + _garnetClient = garnetClient; + _rpcClient = rpcClient; + _logger = logger; + } + + public async Task RunAsync(long chatId, int port, CancellationToken cancellationToken) { + using var heartbeatCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var session = new AgentSessionInfo { + ChatId = chatId, + Port = port, + ProcessId = Environment.ProcessId, + Status = "idle" + }; + + var heartbeatTask = RunHeartbeatAsync(session, heartbeatCts.Token); + await _rpcClient.SaveSessionAsync(session); + + try { + while (!cancellationToken.IsCancellationRequested) { + if (await IsShutdownRequestedAsync(chatId)) { + session.Status = "shutting_down"; + await _rpcClient.SaveSessionAsync(session); + break; + } + + var payload = await _garnetClient.BRPopAsync(LlmAgentRedisKeys.AgentTaskQueue, TimeSpan.FromSeconds(5)); + if (string.IsNullOrWhiteSpace(payload)) { + continue; + } + + var task = JsonConvert.DeserializeObject(payload); + if (task == null || task.ChatId != chatId) { + if (task != null) { + await _garnetClient.RPushAsync(LlmAgentRedisKeys.AgentTaskQueue, payload); + } + + continue; + } + + if (await IsShutdownRequestedAsync(chatId)) { + await _garnetClient.RPushAsync(LlmAgentRedisKeys.AgentTaskQueue, payload); + session.Status = "shutting_down"; + await _rpcClient.SaveSessionAsync(session); + break; + } + + session.Status = "processing"; + session.CurrentTaskId = task.TaskId; + session.LastActiveAtUtc = DateTime.UtcNow; + await _rpcClient.SaveSessionAsync(session); + + var taskState = await _rpcClient.GetTaskStateAsync(task.TaskId); + var recoveredContent = taskState.TryGetValue("lastContent", out var existingContent) + ? existingContent + : string.Empty; + var recoveryAttempt = taskState.TryGetValue("recoveryCount", out var recoveryCountString) && + int.TryParse(recoveryCountString, out var recoveryAttemptValue) + ? recoveryAttemptValue + : 0; + + task.RecoveryAttempt = recoveryAttempt; + + await _rpcClient.SaveTaskStateAsync(task.TaskId, AgentTaskStatus.Running, null, new Dictionary { + ["payload"] = payload, + ["workerChatId"] = chatId.ToString(), + ["startedAtUtc"] = DateTime.UtcNow.ToString("O"), + ["recoveryCount"] = recoveryAttempt.ToString() + }); + + var shouldStopAfterTask = false; + try { + await ProcessTaskAsync(task, payload, chatId, recoveredContent, cancellationToken); + } finally { + session.Status = await IsShutdownRequestedAsync(chatId) ? "shutting_down" : "idle"; + session.CurrentTaskId = string.Empty; + session.LastActiveAtUtc = DateTime.UtcNow; + await _rpcClient.SaveSessionAsync(session); + shouldStopAfterTask = session.Status == "shutting_down"; + } + + if (shouldStopAfterTask) { + break; + } + } + } finally { + heartbeatCts.Cancel(); + session.Status = "stopped"; + session.LastHeartbeatUtc = DateTime.UtcNow; + await _rpcClient.SaveSessionAsync(session); + await _rpcClient.KeyDeleteAsync(LlmAgentRedisKeys.AgentControl(chatId)); + await heartbeatTask; + } + } + + public async Task ProcessTaskAsync( + AgentExecutionTask task, + string payload, + long workerChatId, + string recoveredContent, + CancellationToken cancellationToken) { + using var scope = _serviceProvider.CreateScope(); + var executor = scope.ServiceProvider.GetRequiredService(); + var executionContext = new LlmExecutionContext(); + var sequence = 0; + var suppressUntilRecoveryCatchup = !string.IsNullOrWhiteSpace(recoveredContent); + + try { + await foreach (var snapshot in executor.CallAsync(task, executionContext, cancellationToken).WithCancellation(cancellationToken)) { + await _rpcClient.SaveTaskStateAsync(task.TaskId, AgentTaskStatus.Running, null, new Dictionary { + ["lastContent"] = snapshot, + ["lastSequence"] = sequence.ToString() + }); + + if (!ShouldPublishSnapshot(snapshot, recoveredContent, ref suppressUntilRecoveryCatchup)) { + continue; + } + + await _garnetClient.PublishChunkAsync(new AgentStreamChunk { + TaskId = task.TaskId, + Type = AgentChunkType.Snapshot, + Sequence = sequence++, + Content = snapshot + }); + } + + if (executionContext.IterationLimitReached && executionContext.SnapshotData != null) { + await _garnetClient.PublishChunkAsync(new AgentStreamChunk { + TaskId = task.TaskId, + Type = AgentChunkType.IterationLimitReached, + Sequence = sequence++, + Content = executionContext.SnapshotData.LastAccumulatedContent ?? string.Empty, + ContinuationSnapshot = executionContext.SnapshotData + }); + await _rpcClient.SaveTaskStateAsync(task.TaskId, AgentTaskStatus.Completed, null, new Dictionary { + ["payload"] = payload, + ["workerChatId"] = workerChatId.ToString(), + ["lastContent"] = executionContext.SnapshotData.LastAccumulatedContent ?? string.Empty, + ["completedAtUtc"] = DateTime.UtcNow.ToString("O") + }); + } else { + await _garnetClient.PublishChunkAsync(new AgentStreamChunk { + TaskId = task.TaskId, + Type = AgentChunkType.Done, + Sequence = sequence, + }); + await _rpcClient.SaveTaskStateAsync(task.TaskId, AgentTaskStatus.Completed, null, new Dictionary { + ["payload"] = payload, + ["workerChatId"] = workerChatId.ToString(), + ["completedAtUtc"] = DateTime.UtcNow.ToString("O") + }); + } + } catch (Exception ex) { + _logger.LogError(ex, "Agent task {TaskId} failed", task.TaskId); + await _garnetClient.PublishChunkAsync(new AgentStreamChunk { + TaskId = task.TaskId, + Type = AgentChunkType.Error, + Sequence = sequence, + ErrorMessage = ex.Message + }); + await _rpcClient.SaveTaskStateAsync(task.TaskId, AgentTaskStatus.Failed, ex.Message, new Dictionary { + ["payload"] = payload, + ["workerChatId"] = workerChatId.ToString(), + ["failedAtUtc"] = DateTime.UtcNow.ToString("O") + }); + } + } + + private async Task RunHeartbeatAsync(AgentSessionInfo session, CancellationToken cancellationToken) { + using var timer = new PeriodicTimer(TimeSpan.FromSeconds(Math.Max(1, Env.AgentHeartbeatIntervalSeconds))); + while (!cancellationToken.IsCancellationRequested && await timer.WaitForNextTickAsync(cancellationToken)) { + session.LastHeartbeatUtc = DateTime.UtcNow; + await _rpcClient.SaveSessionAsync(session); + } + } + + private async Task IsShutdownRequestedAsync(long chatId) { + var command = await _rpcClient.GetControlCommandAsync(chatId); + return command != null && command.Action.Equals("shutdown", StringComparison.OrdinalIgnoreCase); + } + + private static bool ShouldPublishSnapshot(string snapshot, string recoveredContent, ref bool suppressUntilRecoveryCatchup) { + if (!suppressUntilRecoveryCatchup) { + return true; + } + + if (string.IsNullOrWhiteSpace(recoveredContent)) { + suppressUntilRecoveryCatchup = false; + return true; + } + + if (string.Equals(snapshot, recoveredContent, StringComparison.Ordinal)) { + return false; + } + + if (snapshot.Length < recoveredContent.Length && recoveredContent.StartsWith(snapshot, StringComparison.Ordinal)) { + return false; + } + + if (snapshot.StartsWith(recoveredContent, StringComparison.Ordinal)) { + suppressUntilRecoveryCatchup = false; + return snapshot.Length > recoveredContent.Length; + } + + suppressUntilRecoveryCatchup = false; + return true; + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/AgentToolService.cs b/TelegramSearchBot.LLMAgent/Service/AgentToolService.cs new file mode 100644 index 00000000..7b255118 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/AgentToolService.cs @@ -0,0 +1,30 @@ +using TelegramSearchBot.Attributes; +using TelegramSearchBot.Model; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class AgentToolService { + private readonly ToolExecutor _toolExecutor; + + public AgentToolService(ToolExecutor toolExecutor) { + _toolExecutor = toolExecutor; + } + + [BuiltInTool("Return the input text unchanged.", Name = "echo")] + public Task EchoAsync([BuiltInParameter("Text to echo back.")] string text) { + return _toolExecutor.EchoAsync(text); + } + + [BuiltInTool("Evaluate a simple arithmetic expression.", Name = "calculator")] + public Task CalculatorAsync([BuiltInParameter("Arithmetic expression such as 1+2*3.")] string expression) { + return _toolExecutor.CalculateAsync(expression); + } + + [BuiltInTool("Send a plain text Telegram message via the main process.", Name = "send_message")] + public Task SendMessageAsync( + [BuiltInParameter("Target Telegram chat ID.")] long chatId, + [BuiltInParameter("Message text to send.")] string text, + ToolContext toolContext) { + return _toolExecutor.SendMessageAsync(chatId, text, toolContext?.UserId ?? 0, toolContext?.MessageId ?? 0, CancellationToken.None); + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/GarnetClient.cs b/TelegramSearchBot.LLMAgent/Service/GarnetClient.cs new file mode 100644 index 00000000..cffab9e8 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/GarnetClient.cs @@ -0,0 +1,39 @@ +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class GarnetClient { + private readonly IConnectionMultiplexer _redis; + + public GarnetClient(IConnectionMultiplexer redis) { + _redis = redis; + } + + public Task LPushAsync(string key, string value) { + return _redis.GetDatabase().ListLeftPushAsync(key, value); + } + + public Task RPushAsync(string key, string value) { + return _redis.GetDatabase().ListRightPushAsync(key, value); + } + + public async Task BRPopAsync(string key, TimeSpan timeout) { + var result = await _redis.GetDatabase().ExecuteAsync("BRPOP", key, (int)Math.Ceiling(timeout.TotalSeconds)); + if (result.IsNull) { + return null; + } + + var parts = (RedisResult[])result!; + if (parts.Length == 2) { + return parts[1].ToString(); + } + + return null; + } + + public Task PublishChunkAsync(AgentStreamChunk chunk) { + return RPushAsync(LlmAgentRedisKeys.AgentChunks(chunk.TaskId), JsonConvert.SerializeObject(chunk)); + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/GarnetRpcClient.cs b/TelegramSearchBot.LLMAgent/Service/GarnetRpcClient.cs new file mode 100644 index 00000000..2227b429 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/GarnetRpcClient.cs @@ -0,0 +1,121 @@ +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class GarnetRpcClient { + private readonly IConnectionMultiplexer _redis; + + public GarnetRpcClient(IConnectionMultiplexer redis) { + _redis = redis; + } + + private IDatabase Db => _redis.GetDatabase(); + + public Task HashSetAsync(string key, string field, string value) => Db.HashSetAsync(key, field, value); + public async Task HashGetAsync(string key, string field) => (await Db.HashGetAsync(key, field)).ToString(); + public async Task> HashGetAllAsync(string key) { + var entries = await Db.HashGetAllAsync(key); + return entries.ToDictionary(x => x.Name.ToString(), x => x.Value.ToString(), StringComparer.OrdinalIgnoreCase); + } + + public async Task> ListRangeAsync(string key, long start, long stop) { + var values = await Db.ListRangeAsync(key, start, stop); + return values.Select(x => x.ToString()).ToList(); + } + + public Task ListTrimAsync(string key, long start, long stop) => Db.ListTrimAsync(key, start, stop); + public Task IncrementAsync(string key) => Db.StringIncrementAsync(key); + public Task DecrementAsync(string key) => Db.StringDecrementAsync(key); + public Task KeyDeleteAsync(string key) => Db.KeyDeleteAsync(key); + public Task KeyExpireAsync(string key, TimeSpan expiry) => Db.KeyExpireAsync(key, expiry); + public async Task StringSetAsync(string key, string value, TimeSpan? expiry = null) { + if (expiry.HasValue) { + await Db.ExecuteAsync("SETEX", key, Math.Max(1, (int)Math.Ceiling(expiry.Value.TotalSeconds)), value); + return true; + } + + await Db.ExecuteAsync("SET", key, value); + return true; + } + public async Task StringGetAsync(string key) => (await Db.StringGetAsync(key)).ToString(); + + public Task SaveTaskStateAsync(string taskId, AgentTaskStatus status, string? error = null, IReadOnlyDictionary? extraFields = null) { + var key = LlmAgentRedisKeys.AgentTaskState(taskId); + var tasks = new List { + HashSetAsync(key, "status", status.ToString()), + HashSetAsync(key, "updatedAtUtc", DateTime.UtcNow.ToString("O")), + HashSetAsync(key, "error", error ?? string.Empty) + }; + + if (extraFields != null) { + tasks.AddRange(extraFields.Select(entry => HashSetAsync(key, entry.Key, entry.Value ?? string.Empty))); + } + + return Task.WhenAll(tasks); + } + + public Task> GetTaskStateAsync(string taskId) { + return HashGetAllAsync(LlmAgentRedisKeys.AgentTaskState(taskId)); + } + + public async Task SaveSessionAsync(AgentSessionInfo session) { + var key = LlmAgentRedisKeys.AgentSession(session.ChatId); + var fields = new Dictionary { + ["chatId"] = session.ChatId.ToString(), + ["processId"] = session.ProcessId.ToString(), + ["port"] = session.Port.ToString(), + ["status"] = session.Status, + ["currentTaskId"] = session.CurrentTaskId, + ["startedAtUtc"] = session.StartedAtUtc.ToString("O"), + ["lastHeartbeatUtc"] = session.LastHeartbeatUtc.ToString("O"), + ["lastActiveAtUtc"] = session.LastActiveAtUtc.ToString("O"), + ["shutdownRequestedAtUtc"] = session.ShutdownRequestedAtUtc == DateTime.MinValue ? string.Empty : session.ShutdownRequestedAtUtc.ToString("O"), + ["error"] = session.ErrorMessage + }; + + foreach (var entry in fields) { + await HashSetAsync(key, entry.Key, entry.Value); + } + + await KeyExpireAsync(key, TimeSpan.FromSeconds(Math.Max(Env.AgentHeartbeatTimeoutSeconds * 2, 30))); + } + + public async Task WaitForTelegramResultAsync(string requestId, TimeSpan timeout, CancellationToken cancellationToken) { + var key = LlmAgentRedisKeys.TelegramResult(requestId); + var startedAt = DateTime.UtcNow; + + while (DateTime.UtcNow - startedAt < timeout && !cancellationToken.IsCancellationRequested) { + var json = await StringGetAsync(key); + if (!string.IsNullOrWhiteSpace(json)) { + await KeyDeleteAsync(key); + return JsonConvert.DeserializeObject(json); + } + + await Task.Delay(200, cancellationToken); + } + + return null; + } + + public Task RequestShutdownAsync(long chatId, string reason) { + var command = new AgentControlCommand { + ChatId = chatId, + Action = "shutdown", + Reason = reason, + RequestedAtUtc = DateTime.UtcNow + }; + + return StringSetAsync( + LlmAgentRedisKeys.AgentControl(chatId), + JsonConvert.SerializeObject(command), + TimeSpan.FromSeconds(Math.Max(Env.AgentShutdownGracePeriodSeconds * 2, 30))); + } + + public async Task GetControlCommandAsync(long chatId) { + var json = await StringGetAsync(LlmAgentRedisKeys.AgentControl(chatId)); + return string.IsNullOrWhiteSpace(json) ? null : JsonConvert.DeserializeObject(json); + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/IAgentTaskExecutor.cs b/TelegramSearchBot.LLMAgent/Service/IAgentTaskExecutor.cs new file mode 100644 index 00000000..bd8d1eaa --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/IAgentTaskExecutor.cs @@ -0,0 +1,10 @@ +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.LLMAgent.Service { + public interface IAgentTaskExecutor { + IAsyncEnumerable CallAsync( + AgentExecutionTask task, + LlmExecutionContext executionContext, + CancellationToken cancellationToken); + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/InMemoryMessageExtensionService.cs b/TelegramSearchBot.LLMAgent/Service/InMemoryMessageExtensionService.cs new file mode 100644 index 00000000..1fe7d10c --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/InMemoryMessageExtensionService.cs @@ -0,0 +1,63 @@ +using Microsoft.EntityFrameworkCore; +using TelegramSearchBot.Interface; +using TelegramSearchBot.Model; +using TelegramSearchBot.Model.Data; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class InMemoryMessageExtensionService : IMessageExtensionService { + private readonly DataDbContext _dbContext; + + public InMemoryMessageExtensionService(DataDbContext dbContext) { + _dbContext = dbContext; + } + + public string ServiceName => nameof(InMemoryMessageExtensionService); + + public Task GetByIdAsync(int id) => _dbContext.MessageExtensions.FindAsync(id).AsTask(); + + public Task> GetByMessageDataIdAsync(long messageDataId) { + return _dbContext.MessageExtensions.Where(x => x.MessageDataId == messageDataId).ToListAsync(); + } + + public async Task AddOrUpdateAsync(MessageExtension extension) { + var existing = await _dbContext.MessageExtensions.FirstOrDefaultAsync(x => + x.MessageDataId == extension.MessageDataId && + x.Name == extension.Name); + + if (existing == null) { + await _dbContext.MessageExtensions.AddAsync(extension); + } else { + existing.Value = extension.Value; + } + + await _dbContext.SaveChangesAsync(); + } + + public Task AddOrUpdateAsync(long messageDataId, string name, string value) { + return AddOrUpdateAsync(new MessageExtension { + MessageDataId = messageDataId, + Name = name, + Value = value + }); + } + + public async Task DeleteAsync(int id) { + var entity = await _dbContext.MessageExtensions.FindAsync(id); + if (entity != null) { + _dbContext.MessageExtensions.Remove(entity); + await _dbContext.SaveChangesAsync(); + } + } + + public async Task DeleteByMessageDataIdAsync(long messageDataId) { + var items = await _dbContext.MessageExtensions.Where(x => x.MessageDataId == messageDataId).ToListAsync(); + _dbContext.MessageExtensions.RemoveRange(items); + await _dbContext.SaveChangesAsync(); + } + + public async Task GetMessageIdByMessageIdAndGroupId(long messageId, long groupId) { + var entity = await _dbContext.Messages.FirstOrDefaultAsync(x => x.MessageId == messageId && x.GroupId == groupId); + return entity?.Id; + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/LlmServiceProxy.cs b/TelegramSearchBot.LLMAgent/Service/LlmServiceProxy.cs new file mode 100644 index 00000000..57320341 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/LlmServiceProxy.cs @@ -0,0 +1,158 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using TelegramSearchBot.Common; +using TelegramSearchBot.Interface.AI.LLM; +using TelegramSearchBot.Model; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Model.Data; +using TelegramSearchBot.Service.AI.LLM; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class LlmServiceProxy : IAgentTaskExecutor { + private readonly IServiceProvider _serviceProvider; + private readonly ILogger _logger; + + public LlmServiceProxy(IServiceProvider serviceProvider, ILogger logger) { + _serviceProvider = serviceProvider; + _logger = logger; + } + + public async IAsyncEnumerable CallAsync( + AgentExecutionTask task, + LlmExecutionContext executionContext, + [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) { + await SeedTaskDataAsync(task, cancellationToken); + + var service = ResolveService(task.Channel.Provider); + ApplyBotIdentity(service, task.BotName, task.BotUserId); + var channel = ToEntity(task.Channel); + + if (task.Kind == AgentTaskKind.Continuation && task.ContinuationSnapshot != null) { + await foreach (var chunk in service.ResumeFromSnapshotAsync(task.ContinuationSnapshot, channel, executionContext, cancellationToken) + .WithCancellation(cancellationToken)) { + yield return chunk; + } + + yield break; + } + + var message = new Message { + Id = -1, + GroupId = task.ChatId, + MessageId = task.MessageId, + FromUserId = task.UserId, + ReplyToMessageId = 0, + Content = task.InputMessage, + DateTime = task.CreatedAtUtc + }; + + await foreach (var chunk in service.ExecAsync(message, task.ChatId, task.ModelName, channel, executionContext, cancellationToken) + .WithCancellation(cancellationToken)) { + yield return chunk; + } + } + + private ILLMService ResolveService(LLMProvider provider) { + return provider switch { + LLMProvider.Ollama => _serviceProvider.GetRequiredService(), + LLMProvider.Gemini => _serviceProvider.GetRequiredService(), + LLMProvider.Anthropic => _serviceProvider.GetRequiredService(), + _ => _serviceProvider.GetRequiredService() + }; + } + + private static void ApplyBotIdentity(ILLMService service, string botName, long botUserId) { + Env.BotId = botUserId; + switch (service) { + case OpenAIService openAi: + openAi.BotName = botName; + break; + case OllamaService ollama: + ollama.BotName = botName; + break; + case GeminiService gemini: + gemini.BotName = botName; + break; + case AnthropicService anthropic: + anthropic.BotName = botName; + break; + } + } + + private async Task SeedTaskDataAsync(AgentExecutionTask task, CancellationToken cancellationToken) { + var dbContext = _serviceProvider.GetRequiredService(); + await dbContext.Database.EnsureDeletedAsync(cancellationToken); + await dbContext.Database.EnsureCreatedAsync(cancellationToken); + + dbContext.LLMChannels.Add(ToEntity(task.Channel)); + var channelWithModel = new ChannelWithModel { + Id = 1, + LLMChannelId = task.Channel.ChannelId, + ModelName = task.ModelName, + IsDeleted = false + }; + dbContext.ChannelsWithModel.Add(channelWithModel); + dbContext.GroupSettings.Add(new GroupSettings { + GroupId = task.ChatId, + LLMModelName = task.ModelName + }); + + foreach (var capability in task.Channel.Capabilities) { + dbContext.ModelCapabilities.Add(new ModelCapability { + ChannelWithModelId = channelWithModel.Id, + CapabilityName = capability.Name, + CapabilityValue = capability.Value, + Description = capability.Description + }); + } + + var seededUsers = new HashSet(); + foreach (var historyMessage in task.History) { + dbContext.Messages.Add(new Message { + Id = historyMessage.DataId, + DateTime = historyMessage.DateTime, + GroupId = historyMessage.GroupId, + MessageId = historyMessage.MessageId, + FromUserId = historyMessage.FromUserId, + ReplyToUserId = historyMessage.ReplyToUserId, + ReplyToMessageId = historyMessage.ReplyToMessageId, + Content = historyMessage.Content + }); + + if (seededUsers.Add(historyMessage.User.UserId)) { + dbContext.UserData.Add(new UserData { + Id = historyMessage.User.UserId, + FirstName = historyMessage.User.FirstName, + LastName = historyMessage.User.LastName, + UserName = historyMessage.User.UserName, + IsBot = historyMessage.User.IsBot, + IsPremium = historyMessage.User.IsPremium + }); + } + + foreach (var extension in historyMessage.Extensions) { + dbContext.MessageExtensions.Add(new MessageExtension { + MessageDataId = historyMessage.DataId, + Name = extension.Name, + Value = extension.Value + }); + } + } + + await dbContext.SaveChangesAsync(cancellationToken); + } + + private static LLMChannel ToEntity(AgentChannelConfig config) { + return new LLMChannel { + Id = config.ChannelId, + Name = config.Name, + Gateway = config.Gateway, + ApiKey = config.ApiKey, + Parallel = config.Parallel, + Priority = config.Priority, + Provider = config.Provider + }; + } + } +} diff --git a/TelegramSearchBot.LLMAgent/Service/ToolExecutor.cs b/TelegramSearchBot.LLMAgent/Service/ToolExecutor.cs new file mode 100644 index 00000000..3987ff53 --- /dev/null +++ b/TelegramSearchBot.LLMAgent/Service/ToolExecutor.cs @@ -0,0 +1,48 @@ +using System.Data; +using Newtonsoft.Json; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.LLMAgent.Service { + public sealed class ToolExecutor { + private readonly GarnetClient _garnetClient; + private readonly GarnetRpcClient _rpcClient; + + public ToolExecutor(GarnetClient garnetClient, GarnetRpcClient rpcClient) { + _garnetClient = garnetClient; + _rpcClient = rpcClient; + } + + public Task EchoAsync(string text) => Task.FromResult(text); + + public Task CalculateAsync(string expression) { + var table = new DataTable(); + var result = table.Compute(expression, string.Empty); + return Task.FromResult(Convert.ToString(result, System.Globalization.CultureInfo.InvariantCulture) ?? string.Empty); + } + + public async Task SendMessageAsync(long chatId, string text, long userId, long messageId, CancellationToken cancellationToken) { + var task = new TelegramAgentToolTask { + ToolName = "send_message", + ChatId = chatId, + UserId = userId, + MessageId = messageId, + Arguments = new Dictionary(StringComparer.OrdinalIgnoreCase) { + ["chatId"] = chatId.ToString(), + ["text"] = text + } + }; + + await _garnetClient.RPushAsync(LlmAgentRedisKeys.TelegramTaskQueue, JsonConvert.SerializeObject(task)); + var result = await _rpcClient.WaitForTelegramResultAsync(task.RequestId, TimeSpan.FromSeconds(30), cancellationToken); + if (result == null) { + throw new TimeoutException("Timed out waiting for TELEGRAM_RESULT."); + } + + if (!result.Success) { + throw new InvalidOperationException(result.ErrorMessage); + } + + return string.IsNullOrWhiteSpace(result.Result) ? "ok" : result.Result; + } + } +} diff --git a/TelegramSearchBot.LLMAgent/TelegramSearchBot.LLMAgent.csproj b/TelegramSearchBot.LLMAgent/TelegramSearchBot.LLMAgent.csproj new file mode 100644 index 00000000..612881ff --- /dev/null +++ b/TelegramSearchBot.LLMAgent/TelegramSearchBot.LLMAgent.csproj @@ -0,0 +1,25 @@ + + + Exe + net10.0 + enable + enable + + + + + + + + + + + + + + + + + + + diff --git a/TelegramSearchBot.SubAgent/Program.cs b/TelegramSearchBot.SubAgent/Program.cs new file mode 100644 index 00000000..4935d139 --- /dev/null +++ b/TelegramSearchBot.SubAgent/Program.cs @@ -0,0 +1,37 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; + +namespace TelegramSearchBot.SubAgent { + public static class Program { + public static async Task Main(string[] args) { + var effectiveArgs = args.Length > 0 && args[0].Equals("SubAgent", StringComparison.OrdinalIgnoreCase) + ? args.Skip(1).ToArray() + : args; + + if (effectiveArgs.Length != 1 || !int.TryParse(effectiveArgs[0], out var port)) { + Console.Error.WriteLine("Usage: SubAgent "); + Environment.ExitCode = 1; + return; + } + + var services = new ServiceCollection(); + services.AddLogging(builder => builder.AddSimpleConsole(options => { + options.SingleLine = true; + options.TimestampFormat = "[yyyy-MM-dd HH:mm:ss] "; + })); + services.AddSingleton(_ => ConnectionMultiplexer.Connect($"localhost:{port},abortConnect=false,connectTimeout=5000,connectRetry=5")); + services.AddSingleton(); + + using var provider = services.BuildServiceProvider(); + using var shutdownCts = new CancellationTokenSource(); + Console.CancelKeyPress += (_, eventArgs) => { + eventArgs.Cancel = true; + shutdownCts.Cancel(); + }; + AppDomain.CurrentDomain.ProcessExit += (_, _) => shutdownCts.Cancel(); + + await provider.GetRequiredService().RunAsync(shutdownCts.Token); + } + } +} diff --git a/TelegramSearchBot.SubAgent/Service/SubAgentService.cs b/TelegramSearchBot.SubAgent/Service/SubAgentService.cs new file mode 100644 index 00000000..295a91dc --- /dev/null +++ b/TelegramSearchBot.SubAgent/Service/SubAgentService.cs @@ -0,0 +1,128 @@ +using System.Diagnostics; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Model.Mcp; +using TelegramSearchBot.Service.Mcp; + +namespace TelegramSearchBot.SubAgent.Service { + public sealed class SubAgentService { + private readonly IConnectionMultiplexer _redis; + private readonly ILogger _logger; + private readonly ILoggerFactory _loggerFactory; + + public SubAgentService(IConnectionMultiplexer redis, ILogger logger, ILoggerFactory loggerFactory) { + _redis = redis; + _logger = logger; + _loggerFactory = loggerFactory; + } + + public async Task RunAsync(CancellationToken cancellationToken) { + while (!cancellationToken.IsCancellationRequested) { + var result = await _redis.GetDatabase().ExecuteAsync("BRPOP", LlmAgentRedisKeys.SubAgentTaskQueue, 5); + if (result.IsNull) { + continue; + } + + var items = (RedisResult[])result!; + if (items.Length != 2) { + continue; + } + + var payload = items[1].ToString(); + if (string.IsNullOrWhiteSpace(payload)) { + continue; + } + + var response = new SubAgentTaskResult { + Success = false + }; + try { + var task = JsonConvert.DeserializeObject(payload); + if (task == null) { + throw new InvalidOperationException("Invalid sub-agent payload."); + } + + response.RequestId = task.RequestId; + response.Result = task.Type switch { + "echo" => task.Payload, + "mcp_execute" => await ExecuteMcpAsync(task, cancellationToken), + "background_task" => await ExecuteBackgroundTaskAsync(task, cancellationToken), + _ => throw new InvalidOperationException($"unsupported:{task.Type}") + }; + response.Success = true; + } catch (Exception ex) { + _logger.LogError(ex, "SubAgent task failed"); + response.ErrorMessage = ex.Message; + } + + await _redis.GetDatabase().StringSetAsync( + LlmAgentRedisKeys.SubAgentResult(response.RequestId), + JsonConvert.SerializeObject(response), + TimeSpan.FromMinutes(5)); + } + } + + private async Task ExecuteMcpAsync(SubAgentTaskEnvelope task, CancellationToken cancellationToken) { + var request = task.McpExecute ?? throw new InvalidOperationException("Missing mcp_execute payload."); + var config = new McpServerConfig { + Name = request.ServerName, + Command = request.Command, + Args = request.Args, + Env = request.Env, + TimeoutSeconds = request.TimeoutSeconds + }; + + using var client = new McpClient(config, _loggerFactory.CreateLogger("SubAgent.McpClient")); + await client.ConnectAsync(cancellationToken); + var result = await client.CallToolAsync( + request.ToolName, + request.Arguments.ToDictionary(x => x.Key, x => x.Value ?? string.Empty), + cancellationToken); + await client.DisconnectAsync(); + + return JsonConvert.SerializeObject(result); + } + + private static async Task ExecuteBackgroundTaskAsync(SubAgentTaskEnvelope task, CancellationToken cancellationToken) { + var request = task.BackgroundTask ?? throw new InvalidOperationException("Missing background_task payload."); + var startInfo = new ProcessStartInfo { + FileName = request.Command, + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false + }; + + if (!string.IsNullOrWhiteSpace(request.WorkingDirectory)) { + startInfo.WorkingDirectory = request.WorkingDirectory; + } + + foreach (var arg in request.Args) { + startInfo.ArgumentList.Add(arg); + } + + foreach (var env in request.Env) { + startInfo.Environment[env.Key] = env.Value; + } + + using var process = Process.Start(startInfo) ?? throw new InvalidOperationException("Failed to start background task process."); + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(Math.Max(1, request.TimeoutSeconds))); + + var stdoutTask = process.StandardOutput.ReadToEndAsync(timeoutCts.Token); + var stderrTask = process.StandardError.ReadToEndAsync(timeoutCts.Token); + await process.WaitForExitAsync(timeoutCts.Token); + + var stdout = await stdoutTask; + var stderr = await stderrTask; + if (process.ExitCode != 0) { + throw new InvalidOperationException(string.IsNullOrWhiteSpace(stderr) + ? $"Background task exited with code {process.ExitCode}." + : stderr.Trim()); + } + + return string.IsNullOrWhiteSpace(stdout) ? stderr.Trim() : stdout.Trim(); + } + } +} diff --git a/TelegramSearchBot.SubAgent/TelegramSearchBot.SubAgent.csproj b/TelegramSearchBot.SubAgent/TelegramSearchBot.SubAgent.csproj new file mode 100644 index 00000000..b005206c --- /dev/null +++ b/TelegramSearchBot.SubAgent/TelegramSearchBot.SubAgent.csproj @@ -0,0 +1,22 @@ + + + Exe + net10.0 + enable + enable + + + + + + + + + + + + + + + + diff --git a/TelegramSearchBot.Test/Service/AI/LLM/AgentEnvCollection.cs b/TelegramSearchBot.Test/Service/AI/LLM/AgentEnvCollection.cs new file mode 100644 index 00000000..b57b0cab --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/AgentEnvCollection.cs @@ -0,0 +1,7 @@ +using Xunit; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + [CollectionDefinition("AgentEnvSerial", DisableParallelization = true)] + public class AgentEnvCollection { + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/AgentIntegrationTests.cs b/TelegramSearchBot.Test/Service/AI/LLM/AgentIntegrationTests.cs new file mode 100644 index 00000000..6084c0eb --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/AgentIntegrationTests.cs @@ -0,0 +1,281 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Moq; +using Telegram.Bot.Types; +using Telegram.Bot.Types.Enums; +using TelegramSearchBot.Common; +using TelegramSearchBot.LLMAgent.Service; +using TelegramSearchBot.Model; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Model.Data; +using TelegramSearchBot.Service.AI.LLM; +using Xunit; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + [Collection("AgentEnvSerial")] + public class AgentIntegrationTests { + [Fact] + public async Task MessageFlow_EndToEnd_QueuesExecutesAndStreamsResponse() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = true; + try { + var harness = new InMemoryRedisTestHarness(); + await using var dbContext = CreateDbContext(); + SeedChannelAndGroup(dbContext, -1001, "gpt-endtoend"); + SeedAliveAgent(harness, -1001, 3456); + + var registry = CreateRegistry(harness); + var polling = new ChunkPollingService(harness.Connection.Object); + var queue = new LLMTaskQueueService(dbContext, harness.Connection.Object, polling, registry); + var loop = CreateAgentLoop(harness, CreateExecutor((task, _, _) => YieldSnapshotsAsync($"{task.InputMessage}-1", $"{task.InputMessage}-2"))); + + var handle = await queue.EnqueueMessageTaskAsync(CreateTelegramMessage(-1001, 11, 99, "hello-agent"), "bot", 999); + var payload = harness.PopFirstListValue(LlmAgentRedisKeys.AgentTaskQueue); + Assert.NotNull(payload); + var task = Newtonsoft.Json.JsonConvert.DeserializeObject(payload!); + Assert.NotNull(task); + + var snapshotsTask = ReadSnapshotsAsync(handle); + await loop.ProcessTaskAsync(task!, payload!, task!.ChatId, string.Empty, CancellationToken.None); + await DrainUntilCompletedAsync(polling, handle); + + var snapshots = await snapshotsTask; + var terminal = await handle.Completion; + + Assert.Equal(new[] { "hello-agent-1", "hello-agent-2" }, snapshots); + Assert.Equal(AgentChunkType.Done, terminal.Type); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + [Fact] + public async Task ConcurrentSessions_IsolateQueuedTasksAndStreams() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = true; + try { + var harness = new InMemoryRedisTestHarness(); + await using var dbContext = CreateDbContext(); + SeedChannelAndGroup(dbContext, -2001, "gpt-concurrent"); + SeedChannelAndGroup(dbContext, -2002, "gpt-concurrent"); + SeedAliveAgent(harness, -2001, 4001); + SeedAliveAgent(harness, -2002, 4002); + + var registry = CreateRegistry(harness); + var polling = new ChunkPollingService(harness.Connection.Object); + var queue = new LLMTaskQueueService(dbContext, harness.Connection.Object, polling, registry); + var executor = CreateExecutor((task, _, _) => YieldSnapshotsAsync($"chat:{task.ChatId}:1", $"chat:{task.ChatId}:2")); + var loop = CreateAgentLoop(harness, executor); + + var handle1 = await queue.EnqueueMessageTaskAsync(CreateTelegramMessage(-2001, 21, 1, "first"), "bot", 999); + var handle2 = await queue.EnqueueMessageTaskAsync(CreateTelegramMessage(-2002, 22, 2, "second"), "bot", 999); + var payload1 = harness.PopFirstListValue(LlmAgentRedisKeys.AgentTaskQueue); + var payload2 = harness.PopFirstListValue(LlmAgentRedisKeys.AgentTaskQueue); + Assert.NotNull(payload1); + Assert.NotNull(payload2); + var task1 = Newtonsoft.Json.JsonConvert.DeserializeObject(payload1!); + var task2 = Newtonsoft.Json.JsonConvert.DeserializeObject(payload2!); + Assert.NotNull(task1); + Assert.NotNull(task2); + + var snapshotsTask1 = ReadSnapshotsAsync(handle1); + var snapshotsTask2 = ReadSnapshotsAsync(handle2); + await Task.WhenAll( + loop.ProcessTaskAsync(task1!, payload1!, task1!.ChatId, string.Empty, CancellationToken.None), + loop.ProcessTaskAsync(task2!, payload2!, task2!.ChatId, string.Empty, CancellationToken.None)); + await DrainUntilCompletedAsync(polling, handle1, handle2); + + Assert.Equal(new[] { "chat:-2001:1", "chat:-2001:2" }, await snapshotsTask1); + Assert.Equal(new[] { "chat:-2002:1", "chat:-2002:2" }, await snapshotsTask2); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + [Fact] + public async Task RecoveryFlow_RequeuesTimedOutTaskAndCompletesOnRetry() { + var originalFlag = Env.EnableLLMAgentProcess; + var originalTimeout = Env.AgentHeartbeatTimeoutSeconds; + Env.EnableLLMAgentProcess = true; + Env.AgentHeartbeatTimeoutSeconds = 1; + try { + var harness = new InMemoryRedisTestHarness(); + await using var dbContext = CreateDbContext(); + SeedChannelAndGroup(dbContext, -3001, "gpt-recovery"); + SeedAliveAgent(harness, -3001, 5001); + + var launcher = new Mock(); + launcher.Setup(l => l.TryKill(It.IsAny())).Returns(true); + launcher.Setup(l => l.StartAsync(It.IsAny(), It.IsAny())) + .Returns((chatId, _) => { + SeedAliveAgent(harness, chatId, 7777); + return Task.FromResult(7777); + }); + var registry = new AgentRegistryService(harness.Connection.Object, launcher.Object, Mock.Of>()); + var polling = new ChunkPollingService(harness.Connection.Object); + var queue = new LLMTaskQueueService(dbContext, harness.Connection.Object, polling, registry); + var loop = CreateAgentLoop(harness, CreateExecutor((task, _, _) => YieldSnapshotsAsync($"recovered:{task.InputMessage}"))); + + var handle = await queue.EnqueueMessageTaskAsync(CreateTelegramMessage(-3001, 31, 3, "recover-me"), "bot", 999); + var payload = harness.PopFirstListValue(LlmAgentRedisKeys.AgentTaskQueue); + Assert.NotNull(payload); + var task = Newtonsoft.Json.JsonConvert.DeserializeObject(payload!); + Assert.NotNull(task); + SeedAliveAgent(harness, -3001, 5001, "processing", task!.TaskId, DateTime.UtcNow.AddMinutes(-10)); + harness.SetHash(LlmAgentRedisKeys.AgentTaskState(task!.TaskId), new Dictionary(StringComparer.OrdinalIgnoreCase) { + ["status"] = AgentTaskStatus.Running.ToString(), + ["chatId"] = task.ChatId.ToString(), + ["messageId"] = task.MessageId.ToString(), + ["modelName"] = task.ModelName, + ["createdAtUtc"] = task.CreatedAtUtc.ToString("O"), + ["updatedAtUtc"] = DateTime.UtcNow.AddMinutes(-10).ToString("O"), + ["payload"] = payload!, + ["recoveryCount"] = "0", + ["maxRecoveryAttempts"] = Env.AgentMaxRecoveryAttempts.ToString(), + ["lastContent"] = string.Empty + }); + await registry.GetSessionAsync(task.ChatId); + + await registry.RunMaintenanceOnceAsync(); + var requeuedPayload = harness.PopFirstListValue(LlmAgentRedisKeys.AgentTaskQueue); + Assert.NotNull(requeuedPayload); + var requeuedTask = Newtonsoft.Json.JsonConvert.DeserializeObject(requeuedPayload!); + Assert.NotNull(requeuedTask); + + var snapshotsTask = ReadSnapshotsAsync(handle); + await loop.ProcessTaskAsync(requeuedTask!, requeuedPayload!, requeuedTask!.ChatId, string.Empty, CancellationToken.None); + await DrainUntilCompletedAsync(polling, handle); + + Assert.Equal(new[] { "recovered:recover-me" }, await snapshotsTask); + launcher.Verify(l => l.StartAsync(task.ChatId, It.IsAny()), Times.Once); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + Env.AgentHeartbeatTimeoutSeconds = originalTimeout; + } + } + + [Fact] + public async Task ConfigToggle_DisabledModeRequestsAgentDrain() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = false; + try { + var harness = new InMemoryRedisTestHarness(); + SeedAliveAgent(harness, -4001, 8888); + var registry = CreateRegistry(harness); + await registry.GetSessionAsync(-4001); + + await registry.RunMaintenanceOnceAsync(); + + var command = harness.GetString(LlmAgentRedisKeys.AgentControl(-4001)); + Assert.NotNull(command); + Assert.Contains("\"Action\":\"shutdown\"", command); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + private static DataDbContext CreateDbContext() { + var options = new DbContextOptionsBuilder() + .UseInMemoryDatabase($"AgentIntegrationTests_{Guid.NewGuid():N}") + .Options; + return new DataDbContext(options); + } + + private static void SeedChannelAndGroup(DataDbContext dbContext, long groupId, string modelName) { + var channelId = (int)Math.Abs(groupId % int.MaxValue); + var channel = new LLMChannel { + Id = channelId, + Name = $"channel-{groupId}", + Gateway = "https://example.invalid", + ApiKey = "key", + Provider = LLMProvider.OpenAI, + Parallel = 1, + Priority = 10 + }; + dbContext.LLMChannels.Add(channel); + dbContext.ChannelsWithModel.Add(new ChannelWithModel { + Id = channelId, + LLMChannelId = channelId, + LLMChannel = channel, + ModelName = modelName, + IsDeleted = false + }); + dbContext.GroupSettings.Add(new GroupSettings { + GroupId = groupId, + LLMModelName = modelName + }); + dbContext.SaveChanges(); + } + + private static void SeedAliveAgent(InMemoryRedisTestHarness harness, long chatId, int processId, string status = "idle", string currentTaskId = "", DateTime? heartbeatUtc = null) { + var heartbeat = heartbeatUtc ?? DateTime.UtcNow; + harness.SetHash(LlmAgentRedisKeys.AgentSession(chatId), new Dictionary(StringComparer.OrdinalIgnoreCase) { + ["chatId"] = chatId.ToString(), + ["processId"] = processId.ToString(), + ["port"] = Env.SchedulerPort.ToString(), + ["status"] = status, + ["currentTaskId"] = currentTaskId, + ["startedAtUtc"] = heartbeat.ToString("O"), + ["lastHeartbeatUtc"] = heartbeat.ToString("O"), + ["lastActiveAtUtc"] = heartbeat.ToString("O"), + ["error"] = string.Empty + }); + } + + private static AgentRegistryService CreateRegistry(InMemoryRedisTestHarness harness) { + return new AgentRegistryService( + harness.Connection.Object, + Mock.Of(), + Mock.Of>()); + } + + private static AgentLoopService CreateAgentLoop(InMemoryRedisTestHarness harness, IAgentTaskExecutor executor) { + var services = new ServiceCollection(); + services.AddScoped(_ => executor); + var provider = services.BuildServiceProvider(); + return new AgentLoopService( + provider, + new GarnetClient(harness.Connection.Object), + new GarnetRpcClient(harness.Connection.Object), + Mock.Of>()); + } + + private static Telegram.Bot.Types.Message CreateTelegramMessage(long chatId, int messageId, long userId, string text) { + return new Telegram.Bot.Types.Message { + Id = messageId, + Date = DateTime.UtcNow, + Text = text, + Chat = new Chat { Id = chatId, Type = ChatType.Group }, + From = new User { Id = userId, FirstName = "Tester" } + }; + } + + private static async Task> ReadSnapshotsAsync(AgentTaskStreamHandle handle) { + var results = new List(); + await foreach (var snapshot in handle.ReadSnapshotsAsync()) { + results.Add(snapshot); + } + + return results; + } + + private static async Task DrainUntilCompletedAsync(ChunkPollingService polling, params AgentTaskStreamHandle[] handles) { + for (var i = 0; i < 50 && handles.Any(h => !h.Completion.IsCompleted); i++) { + await polling.RunPollCycleAsync(); + await Task.Delay(10); + } + } + + private static FakeAgentTaskExecutor CreateExecutor(Func> handler) { + return new FakeAgentTaskExecutor(handler); + } + + private static async IAsyncEnumerable YieldSnapshotsAsync(params string[] snapshots) { + foreach (var snapshot in snapshots) { + yield return snapshot; + await Task.Yield(); + } + } + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/AgentRegistryServiceTests.cs b/TelegramSearchBot.Test/Service/AI/LLM/AgentRegistryServiceTests.cs new file mode 100644 index 00000000..11180c1c --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/AgentRegistryServiceTests.cs @@ -0,0 +1,138 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Service.AI.LLM; +using Xunit; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + [Collection("AgentEnvSerial")] + public class AgentRegistryServiceTests { + [Fact] + public async Task EnsureAgentAsync_WhenAgentModeDisabled_Throws() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = false; + + try { + var (service, _, _, _, _) = CreateService(); + await Assert.ThrowsAsync(() => service.EnsureAgentAsync(55)); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + [Fact] + public async Task EnsureAgentAsync_WhenAliveSessionExists_DoesNotStartNewProcess() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = true; + + try { + var (service, hashes, _, _, launcherMock) = CreateService(); + TrackAliveSession(service, hashes, 77); + + await service.EnsureAgentAsync(77); + + launcherMock.Verify(l => l.StartAsync(It.IsAny(), It.IsAny()), Times.Never); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + private static void TrackAliveSession(AgentRegistryService service, Dictionary> hashes, long chatId) { + var session = new AgentSessionInfo { + ChatId = chatId, + ProcessId = 1234, + Status = "idle", + LastHeartbeatUtc = DateTime.UtcNow, + LastActiveAtUtc = DateTime.UtcNow + }; + + var sessionsDict = typeof(AgentRegistryService) + .GetField("_knownSessions", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)! + .GetValue(service) as System.Collections.Concurrent.ConcurrentDictionary; + if (sessionsDict == null) { + throw new InvalidOperationException("Unable to access known sessions."); + } + + sessionsDict[chatId] = session; + hashes[LlmAgentRedisKeys.AgentSession(chatId)] = new Dictionary(StringComparer.OrdinalIgnoreCase) { + ["chatId"] = chatId.ToString(), + ["processId"] = "1234", + ["port"] = "0", + ["status"] = "idle", + ["currentTaskId"] = string.Empty, + ["lastHeartbeatUtc"] = DateTime.UtcNow.ToString("O"), + ["lastActiveAtUtc"] = DateTime.UtcNow.ToString("O") + }; + } + + private static (AgentRegistryService service, Dictionary> hashes, Dictionary> lists, Dictionary strings, Mock launcherMock) CreateService() { + var hashes = new Dictionary>(StringComparer.OrdinalIgnoreCase); + var lists = new Dictionary>(StringComparer.OrdinalIgnoreCase); + var strings = new Dictionary(StringComparer.OrdinalIgnoreCase); + var redisMock = new Mock(); + var dbMock = new Mock(); + redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(dbMock.Object); + + dbMock.Setup(d => d.HashGetAllAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags _) => hashes.TryGetValue(key.ToString(), out var values) + ? values.Select(entry => new HashEntry(entry.Key, entry.Value)).ToArray() + : []); + dbMock.Setup(d => d.HashSetAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, entries, _) => { + if (!hashes.TryGetValue(key.ToString(), out var values)) { + values = new Dictionary(StringComparer.OrdinalIgnoreCase); + hashes[key.ToString()] = values; + } + + foreach (var entry in entries) { + values[entry.Name.ToString()] = entry.Value.ToString(); + } + }) + .Returns(Task.CompletedTask); + dbMock.Setup(d => d.StringSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, value, _, _, _, _) => strings[key.ToString()] = value.ToString()) + .ReturnsAsync(true); + dbMock.Setup(d => d.StringSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, value, _, _, _) => strings[key.ToString()] = value.ToString()) + .ReturnsAsync(true); + dbMock.Setup(d => d.KeyDeleteAsync(It.IsAny(), It.IsAny())) + .Callback((key, _) => { + hashes.Remove(key.ToString()); + strings.Remove(key.ToString()); + lists.Remove(key.ToString()); + }) + .ReturnsAsync(true); + dbMock.Setup(d => d.KeyExpireAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(true); + dbMock.Setup(d => d.ListLeftPushAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, value, _, _) => { + if (!lists.TryGetValue(key.ToString(), out var values)) { + values = []; + lists[key.ToString()] = values; + } + + values.Insert(0, value.ToString()); + }) + .ReturnsAsync(1); + dbMock.Setup(d => d.ListLengthAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags _) => lists.TryGetValue(key.ToString(), out var values) ? values.Count : 0); + + var launcherMock = new Mock(); + launcherMock.Setup(l => l.TryKill(It.IsAny())).Returns(true); + launcherMock.Setup(l => l.StartAsync(It.IsAny(), It.IsAny())).ReturnsAsync(9999); + + var service = new AgentRegistryService( + redisMock.Object, + launcherMock.Object, + Mock.Of>()); + return (service, hashes, lists, strings, launcherMock); + } + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/ChunkPollingServiceTests.cs b/TelegramSearchBot.Test/Service/AI/LLM/ChunkPollingServiceTests.cs new file mode 100644 index 00000000..ccaf2c98 --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/ChunkPollingServiceTests.cs @@ -0,0 +1,43 @@ +using System; +using System.Threading.Tasks; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Service.AI.LLM; +using Xunit; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + public class ChunkPollingServiceTests { + [Fact] + public async Task RunPollCycleAsync_CompletesTrackedTaskWhenTaskStateFails() { + var redisMock = new Mock(); + var dbMock = new Mock(); + redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(dbMock.Object); + + dbMock.Setup(d => d.ListRangeAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync([]); + dbMock.Setup(d => d.HashGetAllAsync( + It.Is(key => key == LlmAgentRedisKeys.AgentTaskState("task-1")), + It.IsAny())) + .ReturnsAsync([ + new HashEntry("status", AgentTaskStatus.Failed.ToString()), + new HashEntry("error", "boom") + ]); + dbMock.Setup(d => d.KeyDeleteAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(true); + + var service = new ChunkPollingService(redisMock.Object); + var handle = service.TrackTask("task-1"); + + await service.RunPollCycleAsync(); + var terminal = await handle.Completion; + + Assert.Equal(AgentChunkType.Error, terminal.Type); + Assert.Equal("boom", terminal.ErrorMessage); + } + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/FakeAgentTaskExecutor.cs b/TelegramSearchBot.Test/Service/AI/LLM/FakeAgentTaskExecutor.cs new file mode 100644 index 00000000..5c936ad5 --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/FakeAgentTaskExecutor.cs @@ -0,0 +1,16 @@ +using TelegramSearchBot.LLMAgent.Service; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + internal sealed class FakeAgentTaskExecutor : IAgentTaskExecutor { + private readonly Func> _handler; + + public FakeAgentTaskExecutor(Func> handler) { + _handler = handler; + } + + public IAsyncEnumerable CallAsync(AgentExecutionTask task, LlmExecutionContext executionContext, CancellationToken cancellationToken) { + return _handler(task, executionContext, cancellationToken); + } + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/InMemoryRedisTestHarness.cs b/TelegramSearchBot.Test/Service/AI/LLM/InMemoryRedisTestHarness.cs new file mode 100644 index 00000000..db0d47e5 --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/InMemoryRedisTestHarness.cs @@ -0,0 +1,165 @@ +using System.Collections.Concurrent; +using Moq; +using StackExchange.Redis; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + internal sealed class InMemoryRedisTestHarness { + private readonly object _gate = new(); + private readonly ConcurrentDictionary> _lists = new(StringComparer.OrdinalIgnoreCase); + private readonly ConcurrentDictionary> _hashes = new(StringComparer.OrdinalIgnoreCase); + private readonly ConcurrentDictionary _strings = new(StringComparer.OrdinalIgnoreCase); + + public InMemoryRedisTestHarness() { + Database = new Mock(MockBehavior.Strict); + Connection = new Mock(MockBehavior.Strict); + Connection.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(Database.Object); + + Database.Setup(d => d.ListLeftPushAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue value, When _, CommandFlags _) => { + lock (_gate) { + var list = _lists.GetOrAdd(key.ToString(), _ => []); + list.Insert(0, value.ToString()); + return list.Count; + } + }); + + Database.Setup(d => d.ListRightPushAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue value, When _, CommandFlags _) => { + lock (_gate) { + var list = _lists.GetOrAdd(key.ToString(), _ => []); + list.Add(value.ToString()); + return list.Count; + } + }); + + Database.Setup(d => d.ListRangeAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, long start, long stop, CommandFlags _) => { + lock (_gate) { + if (!_lists.TryGetValue(key.ToString(), out var list)) { + return Array.Empty(); + } + + var normalizedStart = (int)Math.Max(0, start); + var normalizedStop = stop < 0 ? list.Count - 1 : (int)Math.Min(stop, list.Count - 1); + if (normalizedStart > normalizedStop || normalizedStart >= list.Count) { + return Array.Empty(); + } + + return list.Skip(normalizedStart).Take(normalizedStop - normalizedStart + 1).Select(x => (RedisValue)x).ToArray(); + } + }); + + Database.Setup(d => d.ListLengthAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags _) => { + lock (_gate) { + return _lists.TryGetValue(key.ToString(), out var list) ? list.Count : 0; + } + }); + + Database.Setup(d => d.HashSetAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((RedisKey key, HashEntry[] entries, CommandFlags _) => { + lock (_gate) { + var hash = _hashes.GetOrAdd(key.ToString(), _ => new Dictionary(StringComparer.OrdinalIgnoreCase)); + foreach (var entry in entries) { + hash[entry.Name.ToString()] = entry.Value.ToString(); + } + } + + return Task.CompletedTask; + }); + + Database.Setup(d => d.HashSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue field, RedisValue value, When _, CommandFlags _) => { + lock (_gate) { + var hash = _hashes.GetOrAdd(key.ToString(), _ => new Dictionary(StringComparer.OrdinalIgnoreCase)); + hash[field.ToString()] = value.ToString(); + } + + return true; + }); + + Database.Setup(d => d.HashGetAllAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags _) => { + lock (_gate) { + return _hashes.TryGetValue(key.ToString(), out var hash) + ? hash.Select(entry => new HashEntry(entry.Key, entry.Value)).ToArray() + : Array.Empty(); + } + }); + + Database.Setup(d => d.StringSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue value, TimeSpan? _, When _, CommandFlags _) => { + _strings[key.ToString()] = value.ToString(); + return true; + }); + + Database.Setup(d => d.StringSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue value, TimeSpan? _, When _) => { + _strings[key.ToString()] = value.ToString(); + return true; + }); + + Database.Setup(d => d.StringSetAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, RedisValue value, TimeSpan? _, bool _, When _, CommandFlags _) => { + _strings[key.ToString()] = value.ToString(); + return true; + }); + + Database.Setup(d => d.StringGetAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags _) => _strings.TryGetValue(key.ToString(), out var value) ? (RedisValue)value : RedisValue.Null); + + Database.Setup(d => d.KeyDeleteAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync((RedisKey key, CommandFlags flags) => { + var removed = false; + removed |= _strings.TryRemove(key.ToString(), out _); + removed |= _lists.TryRemove(key.ToString(), out _); + removed |= _hashes.TryRemove(key.ToString(), out _); + return removed; + }); + + Database.Setup(d => d.KeyExpireAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(true); + } + + public Mock Connection { get; } + public Mock Database { get; } + + public string? PeekFirstListValue(string key) { + lock (_gate) { + return _lists.TryGetValue(key, out var list) && list.Count > 0 ? list[0] : null; + } + } + + public string? PopFirstListValue(string key) { + lock (_gate) { + if (!_lists.TryGetValue(key, out var list) || list.Count == 0) { + return null; + } + + var value = list[0]; + list.RemoveAt(0); + return value; + } + } + + public IReadOnlyList GetListValues(string key) { + lock (_gate) { + return _lists.TryGetValue(key, out var list) ? list.ToList() : []; + } + } + + public IReadOnlyDictionary GetHash(string key) { + lock (_gate) { + return _hashes.TryGetValue(key, out var hash) + ? new Dictionary(hash, StringComparer.OrdinalIgnoreCase) + : new Dictionary(StringComparer.OrdinalIgnoreCase); + } + } + + public void SetHash(string key, IDictionary values) { + _hashes[key] = new Dictionary(values, StringComparer.OrdinalIgnoreCase); + } + + public string? GetString(string key) => _strings.TryGetValue(key, out var value) ? value : null; + } +} diff --git a/TelegramSearchBot.Test/Service/AI/LLM/LLMTaskQueueServiceTests.cs b/TelegramSearchBot.Test/Service/AI/LLM/LLMTaskQueueServiceTests.cs new file mode 100644 index 00000000..a24084e9 --- /dev/null +++ b/TelegramSearchBot.Test/Service/AI/LLM/LLMTaskQueueServiceTests.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; +using Moq; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Model.Data; +using TelegramSearchBot.Service.AI.LLM; +using Xunit; + +namespace TelegramSearchBot.Test.Service.AI.LLM { + [Collection("AgentEnvSerial")] + public class LLMTaskQueueServiceTests { + [Fact] + public async Task EnqueueContinuationTaskAsync_PersistsPayloadAndRecoveryMetadata() { + var originalFlag = Env.EnableLLMAgentProcess; + Env.EnableLLMAgentProcess = true; + + try { + await using var dbContext = CreateDbContext(); + SeedChannel(dbContext, 321, "gpt-test"); + + var redisMock = new Mock(); + var dbMock = new Mock(); + redisMock.Setup(r => r.GetDatabase(It.IsAny(), It.IsAny())).Returns(dbMock.Object); + + dbMock.Setup(d => d.HashGetAllAsync( + It.Is(key => key == LlmAgentRedisKeys.AgentSession(123)), + It.IsAny())) + .ReturnsAsync([ + new HashEntry("chatId", 123), + new HashEntry("processId", 999), + new HashEntry("port", 0), + new HashEntry("status", "idle"), + new HashEntry("lastHeartbeatUtc", DateTime.UtcNow.ToString("O")), + new HashEntry("lastActiveAtUtc", DateTime.UtcNow.ToString("O")) + ]); + + string pushedPayload = string.Empty; + HashEntry[] persistedState = []; + dbMock.Setup(d => d.ListLeftPushAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Callback((_, value, _, _) => pushedPayload = value.ToString()) + .ReturnsAsync(1); + dbMock.Setup(d => d.HashSetAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((_, entries, _) => persistedState = entries) + .Returns(Task.CompletedTask); + + var registry = new AgentRegistryService( + redisMock.Object, + Mock.Of(), + Mock.Of>()); + var polling = new ChunkPollingService(redisMock.Object); + var service = new LLMTaskQueueService(dbContext, redisMock.Object, polling, registry); + + var snapshot = new LlmContinuationSnapshot { + ChatId = 123, + OriginalMessageId = 456, + UserId = 789, + ModelName = "gpt-test", + Provider = "OpenAI", + ChannelId = 321, + LastAccumulatedContent = "partial" + }; + + var handle = await service.EnqueueContinuationTaskAsync(snapshot, "bot", 1001); + + Assert.NotNull(handle); + Assert.Contains("\"ChatId\":123", pushedPayload); + Assert.Contains("\"Kind\":1", pushedPayload); + + var state = persistedState.ToDictionary(x => x.Name.ToString(), x => x.Value.ToString(), StringComparer.OrdinalIgnoreCase); + Assert.Equal(AgentTaskStatus.Pending.ToString(), state["status"]); + Assert.Equal("0", state["recoveryCount"]); + Assert.Equal(Env.AgentMaxRecoveryAttempts.ToString(), state["maxRecoveryAttempts"]); + Assert.Equal(pushedPayload, state["payload"]); + } finally { + Env.EnableLLMAgentProcess = originalFlag; + } + } + + private static DataDbContext CreateDbContext() { + var options = new DbContextOptionsBuilder() + .UseInMemoryDatabase($"LLMTaskQueueServiceTests_{Guid.NewGuid():N}") + .Options; + return new DataDbContext(options); + } + + private static void SeedChannel(DataDbContext dbContext, int channelId, string modelName) { + var channel = new LLMChannel { + Id = channelId, + Name = "test-channel", + Gateway = "https://example.invalid", + ApiKey = "key", + Provider = LLMProvider.OpenAI, + Parallel = 1, + Priority = 10 + }; + var channelWithModel = new ChannelWithModel { + Id = 1, + LLMChannelId = channelId, + LLMChannel = channel, + ModelName = modelName, + IsDeleted = false, + Capabilities = new List { + new() { Id = 1, ChannelWithModelId = 1, CapabilityName = "function_calling", CapabilityValue = "true", Description = "enabled" } + } + }; + + dbContext.LLMChannels.Add(channel); + dbContext.ChannelsWithModel.Add(channelWithModel); + dbContext.SaveChanges(); + } + } +} diff --git a/TelegramSearchBot.Test/TelegramSearchBot.Test.csproj b/TelegramSearchBot.Test/TelegramSearchBot.Test.csproj index d9e2226f..16752449 100644 --- a/TelegramSearchBot.Test/TelegramSearchBot.Test.csproj +++ b/TelegramSearchBot.Test/TelegramSearchBot.Test.csproj @@ -25,6 +25,7 @@ + diff --git a/TelegramSearchBot.sln b/TelegramSearchBot.sln index 49774e1f..2585a3fd 100644 --- a/TelegramSearchBot.sln +++ b/TelegramSearchBot.sln @@ -31,6 +31,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TelegramSearchBot.LLM", "Te EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TelegramSearchBot.LLM.Test", "TelegramSearchBot.LLM.Test\TelegramSearchBot.LLM.Test.csproj", "{60EB5F23-139A-4BD3-96F7-DE3A1C3BE8E1}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TelegramSearchBot.LLMAgent", "TelegramSearchBot.LLMAgent\TelegramSearchBot.LLMAgent.csproj", "{AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TelegramSearchBot.SubAgent", "TelegramSearchBot.SubAgent\TelegramSearchBot.SubAgent.csproj", "{EA1B5C5E-088E-4D4D-8835-95D10625B72A}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -137,6 +141,30 @@ Global {60EB5F23-139A-4BD3-96F7-DE3A1C3BE8E1}.Release|x64.Build.0 = Release|Any CPU {60EB5F23-139A-4BD3-96F7-DE3A1C3BE8E1}.Release|x86.ActiveCfg = Release|Any CPU {60EB5F23-139A-4BD3-96F7-DE3A1C3BE8E1}.Release|x86.Build.0 = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|x64.ActiveCfg = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|x64.Build.0 = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|x86.ActiveCfg = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Debug|x86.Build.0 = Debug|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|Any CPU.Build.0 = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|x64.ActiveCfg = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|x64.Build.0 = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|x86.ActiveCfg = Release|Any CPU + {AF9FCEE4-D58F-4183-8C0E-7D872F2668C2}.Release|x86.Build.0 = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|x64.ActiveCfg = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|x64.Build.0 = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|x86.ActiveCfg = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Debug|x86.Build.0 = Debug|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|Any CPU.Build.0 = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|x64.ActiveCfg = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|x64.Build.0 = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|x86.ActiveCfg = Release|Any CPU + {EA1B5C5E-088E-4D4D-8835-95D10625B72A}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/TelegramSearchBot/AppBootstrap/AppBootstrap.cs b/TelegramSearchBot/AppBootstrap/AppBootstrap.cs index 5847e421..6120d48c 100644 --- a/TelegramSearchBot/AppBootstrap/AppBootstrap.cs +++ b/TelegramSearchBot/AppBootstrap/AppBootstrap.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using System.Reflection.Metadata; @@ -17,56 +16,76 @@ namespace TelegramSearchBot.AppBootstrap { public class AppBootstrap { public sealed class ChildProcessManager : IDisposable { - private SafeJobHandle? _handle; + private readonly List _handles = []; private bool _disposed; - public ChildProcessManager() { - _handle = new SafeJobHandle(CreateJobObject(IntPtr.Zero, null)); - - var info = new JOBOBJECT_BASIC_LIMIT_INFORMATION { - LimitFlags = 0x2000 - }; - - var extendedInfo = new JOBOBJECT_EXTENDED_LIMIT_INFORMATION { - BasicLimitInformation = info - }; - - var length = Marshal.SizeOf(typeof(JOBOBJECT_EXTENDED_LIMIT_INFORMATION)); - var extendedInfoPtr = Marshal.AllocHGlobal(length); - Marshal.StructureToPtr(extendedInfo, extendedInfoPtr, false); - - if (!SetInformationJobObject(_handle, JobObjectInfoType.ExtendedLimitInformation, extendedInfoPtr, ( uint ) length)) { - throw new InvalidOperationException("Unable to set information", new Win32Exception()); - } - } - public void Dispose() { if (_disposed) return; - _handle?.Dispose(); - _handle = null; + foreach (var handle in _handles) { + handle.Dispose(); + } + _handles.Clear(); _disposed = true; } - [MemberNotNull(nameof(_handle))] private void ValidateDisposed() { - ObjectDisposedException.ThrowIf(_disposed || _handle is null, this); + ObjectDisposedException.ThrowIf(_disposed, this); } - public void AddProcess(SafeProcessHandle processHandle) { + public void AddProcess(SafeProcessHandle processHandle, long? processMemoryLimitBytes = null) { ValidateDisposed(); - if (!AssignProcessToJobObject(_handle, processHandle)) { + var jobHandle = CreateConfiguredJobHandle(processMemoryLimitBytes); + if (!AssignProcessToJobObject(jobHandle, processHandle)) { + jobHandle.Dispose(); throw new InvalidOperationException("Unable to add the process"); } + _handles.Add(jobHandle); } - public void AddProcess(Process process) { - AddProcess(process.SafeHandle); + public void AddProcess(Process process, long? processMemoryLimitBytes = null) { + AddProcess(process.SafeHandle, processMemoryLimitBytes); } - public void AddProcess(int processId) { + public void AddProcess(int processId, long? processMemoryLimitBytes = null) { using var process = Process.GetProcessById(processId); - AddProcess(process); + AddProcess(process, processMemoryLimitBytes); + } + + private static SafeJobHandle CreateConfiguredJobHandle(long? processMemoryLimitBytes) { + var handle = new SafeJobHandle(CreateJobObject(IntPtr.Zero, null)); + if (handle.IsInvalid) { + throw new InvalidOperationException("Unable to create job object", new Win32Exception()); + } + + var limitFlags = JobObjectLimitFlags.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + if (processMemoryLimitBytes.HasValue && processMemoryLimitBytes.Value > 0) { + limitFlags |= JobObjectLimitFlags.JOB_OBJECT_LIMIT_PROCESS_MEMORY; + } + + var info = new JOBOBJECT_BASIC_LIMIT_INFORMATION { + LimitFlags = (uint)limitFlags + }; + + var extendedInfo = new JOBOBJECT_EXTENDED_LIMIT_INFORMATION { + BasicLimitInformation = info, + ProcessMemoryLimit = processMemoryLimitBytes.HasValue && processMemoryLimitBytes.Value > 0 + ? (UIntPtr)processMemoryLimitBytes.Value + : UIntPtr.Zero + }; + + var length = Marshal.SizeOf(typeof(JOBOBJECT_EXTENDED_LIMIT_INFORMATION)); + var extendedInfoPtr = Marshal.AllocHGlobal(length); + try { + Marshal.StructureToPtr(extendedInfo, extendedInfoPtr, false); + if (!SetInformationJobObject(handle, JobObjectInfoType.ExtendedLimitInformation, extendedInfoPtr, (uint)length)) { + throw new InvalidOperationException("Unable to set information", new Win32Exception()); + } + } finally { + Marshal.FreeHGlobal(extendedInfoPtr); + } + + return handle; } private sealed class SafeJobHandle : SafeHandleZeroOrMinusOneIsInvalid { @@ -140,10 +159,16 @@ private enum JobObjectInfoType { SecurityLimitInformation = 5, GroupInformation = 11 } + + [Flags] + private enum JobObjectLimitFlags : uint { + JOB_OBJECT_LIMIT_PROCESS_MEMORY = 0x00000100, + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = 0x00002000 + } } public static ChildProcessManager childProcessManager = new ChildProcessManager(); - public static void Fork(string[] args) { + public static void Fork(string[] args, long? processMemoryLimitBytes = null) { string exePath = Environment.ProcessPath; // 将参数数组转换为空格分隔的字符串,并正确处理包含空格的参数 @@ -162,7 +187,7 @@ public static void Fork(string[] args) { if (newProcess == null) { throw new Exception("启动新进程失败"); } - childProcessManager.AddProcess(newProcess); + childProcessManager.AddProcess(newProcess, processMemoryLimitBytes); Log.Logger.Information($"主进程:{args[0]} {args[1]}已启动"); } private static Dictionary ForkLock = new Dictionary(); @@ -181,7 +206,7 @@ public static async Task RateLimitForkAsync(string[] args) { } } - public static Process Fork(string exePath, string[] args) { + public static Process Fork(string exePath, string[] args, long? processMemoryLimitBytes = null) { // 将参数数组转换为空格分隔的字符串,并正确处理包含空格的参数 string arguments = string.Join(" ", args.Select(arg => $"{arg}")); @@ -198,7 +223,7 @@ public static Process Fork(string exePath, string[] args) { if (newProcess == null) { throw new Exception("启动新进程失败"); } - childProcessManager.AddProcess(newProcess); + childProcessManager.AddProcess(newProcess, processMemoryLimitBytes); Log.Logger.Information($"进程:{exePath} {string.Join(" ", args)}已启动"); return newProcess; } diff --git a/TelegramSearchBot/AppBootstrap/LLMAgentBootstrap.cs b/TelegramSearchBot/AppBootstrap/LLMAgentBootstrap.cs new file mode 100644 index 00000000..64e7db59 --- /dev/null +++ b/TelegramSearchBot/AppBootstrap/LLMAgentBootstrap.cs @@ -0,0 +1,32 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using Serilog; + +namespace TelegramSearchBot.AppBootstrap { + public class LLMAgentBootstrap : AppBootstrap { + public static void Startup(string[] args) { + try { + var effectiveArgs = args.Length > 0 && args[0].Equals("LLMAgent", StringComparison.OrdinalIgnoreCase) + ? args.Skip(1).ToArray() + : args; + var dllPath = Path.Combine(AppContext.BaseDirectory, "TelegramSearchBot.LLMAgent.dll"); + if (!File.Exists(dllPath)) { + throw new FileNotFoundException("LLMAgent executable not found.", dllPath); + } + + using var process = Process.Start(new ProcessStartInfo { + FileName = "dotnet", + Arguments = $"\"{dllPath}\" {string.Join(" ", effectiveArgs)}", + UseShellExecute = false + }); + process?.WaitForExit(); + Environment.ExitCode = process?.ExitCode ?? 1; + } catch (Exception ex) { + Log.Error(ex, "LLMAgent startup failed."); + Environment.ExitCode = 1; + } + } + } +} diff --git a/TelegramSearchBot/AppBootstrap/SubAgentBootstrap.cs b/TelegramSearchBot/AppBootstrap/SubAgentBootstrap.cs new file mode 100644 index 00000000..3d635643 --- /dev/null +++ b/TelegramSearchBot/AppBootstrap/SubAgentBootstrap.cs @@ -0,0 +1,32 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Linq; +using Serilog; + +namespace TelegramSearchBot.AppBootstrap { + public class SubAgentBootstrap : AppBootstrap { + public static void Startup(string[] args) { + try { + var effectiveArgs = args.Length > 0 && args[0].Equals("SubAgent", StringComparison.OrdinalIgnoreCase) + ? args.Skip(1).ToArray() + : args; + var dllPath = Path.Combine(AppContext.BaseDirectory, "TelegramSearchBot.SubAgent.dll"); + if (!File.Exists(dllPath)) { + throw new FileNotFoundException("SubAgent executable not found.", dllPath); + } + + using var process = Process.Start(new ProcessStartInfo { + FileName = "dotnet", + Arguments = $"\"{dllPath}\" {string.Join(" ", effectiveArgs)}", + UseShellExecute = false + }); + process?.WaitForExit(); + Environment.ExitCode = process?.ExitCode ?? 1; + } catch (Exception ex) { + Log.Error(ex, "SubAgent startup failed."); + Environment.ExitCode = 1; + } + } + } +} diff --git a/TelegramSearchBot/Controller/AI/LLM/AgentMonitorController.cs b/TelegramSearchBot/Controller/AI/LLM/AgentMonitorController.cs new file mode 100644 index 00000000..f22fd95f --- /dev/null +++ b/TelegramSearchBot/Controller/AI/LLM/AgentMonitorController.cs @@ -0,0 +1,88 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using StackExchange.Redis; +using TelegramSearchBot.Interface; +using TelegramSearchBot.Interface.Controller; +using TelegramSearchBot.Model; +using TelegramSearchBot.Service.AI.LLM; +using TelegramSearchBot.Service.Manage; + +namespace TelegramSearchBot.Controller.AI.LLM { + public class AgentMonitorController : IOnUpdate { + private readonly AdminService _adminService; + private readonly AgentRegistryService _agentRegistryService; + private readonly IConnectionMultiplexer _redis; + private readonly ISendMessageService _sendMessageService; + + public AgentMonitorController( + AdminService adminService, + AgentRegistryService agentRegistryService, + IConnectionMultiplexer redis, + ISendMessageService sendMessageService) { + _adminService = adminService; + _agentRegistryService = agentRegistryService; + _redis = redis; + _sendMessageService = sendMessageService; + } + + public List Dependencies => new(); + + public async Task ExecuteAsync(PipelineContext p) { + var message = p.Update?.Message; + var text = message?.Text; + if (message?.From == null || string.IsNullOrWhiteSpace(text) || !text.StartsWith("/agent", StringComparison.OrdinalIgnoreCase)) { + return; + } + + if (!await _adminService.IsNormalAdmin(message.From.Id) && !_adminService.IsGlobalAdmin(message.From.Id)) { + return; + } + + if (text.Equals("/agent list", StringComparison.OrdinalIgnoreCase)) { + var sessions = await _agentRegistryService.ListActiveAsync(); + if (!sessions.Any()) { + await _sendMessageService.SendMessage("当前没有活跃的 LLM Agent。", message.Chat.Id, message.MessageId); + return; + } + + var sb = new StringBuilder(); + sb.AppendLine("活跃 LLM Agent:"); + foreach (var session in sessions) { + sb.AppendLine($"- ChatId={session.ChatId}, PID={session.ProcessId}, Status={session.Status}, LastHeartbeat={session.LastHeartbeatUtc:O}"); + } + + await _sendMessageService.SendMessage(sb.ToString(), message.Chat.Id, message.MessageId); + return; + } + + if (text.Equals("/agent stats", StringComparison.OrdinalIgnoreCase)) { + var db = _redis.GetDatabase(); + var pending = await db.ListLengthAsync(TelegramSearchBot.Model.AI.LlmAgentRedisKeys.AgentTaskQueue); + var telegramTasks = await db.ListLengthAsync(TelegramSearchBot.Model.AI.LlmAgentRedisKeys.TelegramTaskQueue); + var deadLetter = await db.ListLengthAsync(TelegramSearchBot.Model.AI.LlmAgentRedisKeys.AgentTaskDeadLetterQueue); + var sessions = await _agentRegistryService.ListActiveAsync(); + var processing = sessions.Count(x => !string.IsNullOrWhiteSpace(x.CurrentTaskId)); + var stats = $"Agents={sessions.Count}\nProcessingAgents={processing}\nPendingAgentTasks={pending}\nPendingTelegramTasks={telegramTasks}\nDeadLetterTasks={deadLetter}"; + await _sendMessageService.SendMessage(stats, message.Chat.Id, message.MessageId); + return; + } + + if (text.StartsWith("/agent kill ", StringComparison.OrdinalIgnoreCase)) { + var suffix = text.Substring("/agent kill ".Length).Trim(); + if (!long.TryParse(suffix, out var chatId)) { + await _sendMessageService.SendMessage("用法:/agent kill ", message.Chat.Id, message.MessageId); + return; + } + + var killed = await _agentRegistryService.TryKillAsync(chatId); + await _sendMessageService.SendMessage( + killed ? $"已终止 chatId={chatId} 的 Agent。" : $"无法终止 chatId={chatId} 的 Agent(可能不存在或仍在处理任务)。", + message.Chat.Id, + message.MessageId); + } + } + } +} diff --git a/TelegramSearchBot/Controller/AI/LLM/GeneralLLMController.cs b/TelegramSearchBot/Controller/AI/LLM/GeneralLLMController.cs index 8de420c1..b3f58e96 100644 --- a/TelegramSearchBot/Controller/AI/LLM/GeneralLLMController.cs +++ b/TelegramSearchBot/Controller/AI/LLM/GeneralLLMController.cs @@ -32,6 +32,7 @@ public class GeneralLLMController : IOnUpdate { public ISendMessageService SendMessageService { get; set; } public IGeneralLLMService GeneralLLMService { get; set; } public ILlmContinuationService ContinuationService { get; set; } + public LLMTaskQueueService LlmTaskQueueService { get; set; } public GeneralLLMController( MessageService messageService, ITelegramBotClient botClient, @@ -41,7 +42,8 @@ public GeneralLLMController( AdminService adminService, ISendMessageService SendMessageService, IGeneralLLMService generalLLMService, - ILlmContinuationService continuationService + ILlmContinuationService continuationService, + LLMTaskQueueService llmTaskQueueService ) { this.logger = logger; this.botClient = botClient; @@ -52,6 +54,7 @@ ILlmContinuationService continuationService this.SendMessageService = SendMessageService; GeneralLLMService = generalLLMService; ContinuationService = continuationService; + LlmTaskQueueService = llmTaskQueueService; } public async Task ExecuteAsync(PipelineContext p) { @@ -116,8 +119,19 @@ public async Task ExecuteAsync(PipelineContext p) { // Use execution context to detect iteration limit (no stream pollution) var executionContext = new LlmExecutionContext(); - IAsyncEnumerable fullMessageStream = GeneralLLMService.ExecAsync( - inputLlMessage, e.Message.Chat.Id, executionContext, CancellationToken.None); + IAsyncEnumerable fullMessageStream; + AgentTaskStreamHandle? agentTaskHandle = null; + if (Env.EnableLLMAgentProcess) { + agentTaskHandle = await LlmTaskQueueService.EnqueueMessageTaskAsync( + e.Message, + service.BotName, + Env.BotId, + CancellationToken.None); + fullMessageStream = agentTaskHandle.ReadSnapshotsAsync(CancellationToken.None); + } else { + fullMessageStream = GeneralLLMService.ExecAsync( + inputLlMessage, e.Message.Chat.Id, executionContext, CancellationToken.None); + } // Use sendMessageDraft API for LLM streaming (better performance, no send+edit) List sentMessagesForDb = await SendMessageService.SendDraftStream( @@ -146,6 +160,31 @@ await messageService.ExecuteAsync(new MessageOption() { }); } + if (agentTaskHandle != null) { + var terminalChunk = await agentTaskHandle.Completion; + if (terminalChunk.Type == AgentChunkType.Error) { + await SendMessageService.SendMessage($"AI Agent 执行失败:{terminalChunk.ErrorMessage}", e.Message.Chat.Id, e.Message.MessageId); + } else if (terminalChunk.Type == AgentChunkType.IterationLimitReached && terminalChunk.ContinuationSnapshot != null) { + var snapshotId = await ContinuationService.SaveSnapshotAsync(terminalChunk.ContinuationSnapshot); + + var keyboard = new InlineKeyboardMarkup(new[] { + new[] { + InlineKeyboardButton.WithCallbackData("✅ 继续迭代", $"llm_continue:{snapshotId}"), + InlineKeyboardButton.WithCallbackData("❌ 停止", $"llm_stop:{snapshotId}"), + } + }); + + await botClient.SendMessage( + e.Message.Chat.Id, + $"⚠️ AI 已达到最大迭代次数限制({Env.MaxToolCycles} 次),是否继续迭代?", + replyMarkup: keyboard, + replyParameters: new ReplyParameters { MessageId = e.Message.MessageId } + ); + } + + return; + } + // Check if the iteration limit was reached via execution context if (executionContext.IterationLimitReached && executionContext.SnapshotData != null) { logger.LogInformation("Iteration limit reached for ChatId {ChatId}, MessageId {MessageId}. Saving snapshot and prompting user.", diff --git a/TelegramSearchBot/Controller/AI/LLM/LLMIterationCallbackController.cs b/TelegramSearchBot/Controller/AI/LLM/LLMIterationCallbackController.cs index e51db817..2af4c7c0 100644 --- a/TelegramSearchBot/Controller/AI/LLM/LLMIterationCallbackController.cs +++ b/TelegramSearchBot/Controller/AI/LLM/LLMIterationCallbackController.cs @@ -36,6 +36,7 @@ public class LLMIterationCallbackController : IOnUpdate { private readonly ISendMessageService _sendMessageService; private readonly MessageService _messageService; private readonly ILlmContinuationService _continuationService; + private readonly LLMTaskQueueService _llmTaskQueueService; public List Dependencies => new List(); @@ -45,13 +46,15 @@ public LLMIterationCallbackController( IGeneralLLMService generalLLMService, ISendMessageService sendMessageService, MessageService messageService, - ILlmContinuationService continuationService) { + ILlmContinuationService continuationService, + LLMTaskQueueService llmTaskQueueService) { _logger = logger; _botClient = botClient; _generalLLMService = generalLLMService; _sendMessageService = sendMessageService; _messageService = messageService; _continuationService = continuationService; + _llmTaskQueueService = llmTaskQueueService; } public async Task ExecuteAsync(PipelineContext p) { @@ -146,8 +149,20 @@ await _botClient.EditMessageReplyMarkup( // Resume with full context — yields only NEW content (not re-sending old history) var executionContext = new LlmExecutionContext(); - IAsyncEnumerable resumeStream = _generalLLMService.ResumeFromSnapshotAsync( - snapshot, executionContext, CancellationToken.None); + AgentTaskStreamHandle? agentTaskHandle = null; + IAsyncEnumerable resumeStream; + if (Env.EnableLLMAgentProcess) { + var me = await _botClient.GetMe(); + agentTaskHandle = await _llmTaskQueueService.EnqueueContinuationTaskAsync( + snapshot, + me.Username ?? string.Empty, + me.Id, + CancellationToken.None); + resumeStream = agentTaskHandle.ReadSnapshotsAsync(CancellationToken.None); + } else { + resumeStream = _generalLLMService.ResumeFromSnapshotAsync( + snapshot, executionContext, CancellationToken.None); + } var initialContent = $"{snapshot.ModelName} 继续迭代中..."; // Use SendDraftStream for continuation — only new content is streamed @@ -181,6 +196,31 @@ await _messageService.ExecuteAsync(new MessageOption { // Delete the used snapshot await _continuationService.DeleteSnapshotAsync(snapshotId); + if (agentTaskHandle != null) { + var terminalChunk = await agentTaskHandle.Completion; + if (terminalChunk.Type == AgentChunkType.Error) { + await _sendMessageService.SendMessage($"AI Agent 执行失败:{terminalChunk.ErrorMessage}", snapshot.ChatId, (int)snapshot.OriginalMessageId); + } else if (terminalChunk.Type == AgentChunkType.IterationLimitReached && terminalChunk.ContinuationSnapshot != null) { + var newSnapshotId = await _continuationService.SaveSnapshotAsync(terminalChunk.ContinuationSnapshot); + + var keyboard = new InlineKeyboardMarkup(new[] { + new[] { + InlineKeyboardButton.WithCallbackData("✅ 继续迭代", $"llm_continue:{newSnapshotId}"), + InlineKeyboardButton.WithCallbackData("❌ 停止", $"llm_stop:{newSnapshotId}"), + } + }); + + await _botClient.SendMessage( + snapshot.ChatId, + $"⚠️ AI 再次达到最大迭代次数限制({Env.MaxToolCycles} 次),是否继续迭代?", + replyMarkup: keyboard, + replyParameters: new ReplyParameters { MessageId = (int)snapshot.OriginalMessageId } + ); + } + + return; + } + // If iteration limit reached again, save new snapshot and show prompt if (executionContext.IterationLimitReached && executionContext.SnapshotData != null) { _logger.LogInformation("Iteration limit reached again after continuation for ChatId {ChatId}.", snapshot.ChatId); diff --git a/TelegramSearchBot/Extension/ServiceCollectionExtension.cs b/TelegramSearchBot/Extension/ServiceCollectionExtension.cs index f7135663..34bdffbb 100644 --- a/TelegramSearchBot/Extension/ServiceCollectionExtension.cs +++ b/TelegramSearchBot/Extension/ServiceCollectionExtension.cs @@ -27,6 +27,7 @@ using TelegramSearchBot.Manager; using TelegramSearchBot.Model; using TelegramSearchBot.Search.Tool; +using TelegramSearchBot.Service.AI.LLM; using TelegramSearchBot.Service.BotAPI; using TelegramSearchBot.Service.Storage; using TelegramSearchBot.View; @@ -63,6 +64,12 @@ public static IServiceCollection AddCoreServices(this IServiceCollection service .AddSingleton() .AddHostedService() .AddHostedService() + .AddSingleton() + .AddHostedService(sp => sp.GetRequiredService()) + .AddSingleton() + .AddSingleton() + .AddHostedService(sp => sp.GetRequiredService()) + .AddHostedService() .AddSingleton>(sp => sp.GetRequiredService().Log) .AddSingleton() .AddSingleton() diff --git a/TelegramSearchBot/Service/AI/LLM/AgentRegistryService.cs b/TelegramSearchBot/Service/AI/LLM/AgentRegistryService.cs new file mode 100644 index 00000000..4e292865 --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/AgentRegistryService.cs @@ -0,0 +1,350 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.Service.AI.LLM { + public sealed class AgentRegistryService : BackgroundService { + private readonly IConnectionMultiplexer _redis; + private readonly IAgentProcessLauncher _processLauncher; + private readonly ILogger _logger; + private readonly ConcurrentDictionary _knownSessions = new(); + + public AgentRegistryService( + IConnectionMultiplexer redis, + IAgentProcessLauncher processLauncher, + ILogger logger) { + _redis = redis; + _processLauncher = processLauncher; + _logger = logger; + } + + public async Task EnsureAgentAsync(long chatId, CancellationToken cancellationToken = default(CancellationToken)) { + if (!Env.EnableLLMAgentProcess) { + throw new InvalidOperationException("LLM agent process mode is currently disabled."); + } + + var existing = await GetSessionAsync(chatId); + if (existing != null && await IsAliveAsync(existing) && !IsShuttingDown(existing)) { + return; + } + + if (existing == null && await CountLiveSessionsAsync() >= Env.MaxConcurrentAgents) { + throw new InvalidOperationException($"当前 Agent 数量已达到上限 {Env.MaxConcurrentAgents}。"); + } + + var processId = await _processLauncher.StartAsync(chatId, cancellationToken); + _knownSessions[chatId] = new AgentSessionInfo { + ChatId = chatId, + ProcessId = processId, + Port = Env.SchedulerPort, + Status = "starting" + }; + + var startedAt = DateTime.UtcNow; + while (DateTime.UtcNow - startedAt < TimeSpan.FromSeconds(10) && !cancellationToken.IsCancellationRequested) { + if (await IsAliveAsync(chatId)) { + return; + } + + await Task.Delay(200, cancellationToken); + } + } + + public async Task> ListActiveAsync() { + var result = new List(); + foreach (var chatId in _knownSessions.Keys.ToArray()) { + var session = await GetSessionAsync(chatId); + if (session != null && await IsAliveAsync(session)) { + result.Add(session); + } + } + + return result.OrderBy(x => x.ChatId).ToList(); + } + + public async Task GetSessionAsync(long chatId) { + var entries = await _redis.GetDatabase().HashGetAllAsync(LlmAgentRedisKeys.AgentSession(chatId)); + if (entries.Length == 0) { + _knownSessions.TryRemove(chatId, out _); + return null; + } + + var session = new AgentSessionInfo { + ChatId = chatId, + ProcessId = ParseInt(entries, "processId"), + Port = ParseInt(entries, "port"), + Status = Parse(entries, "status"), + CurrentTaskId = Parse(entries, "currentTaskId"), + ErrorMessage = Parse(entries, "error"), + StartedAtUtc = ParseDate(entries, "startedAtUtc"), + LastHeartbeatUtc = ParseDate(entries, "lastHeartbeatUtc"), + LastActiveAtUtc = ParseDate(entries, "lastActiveAtUtc"), + ShutdownRequestedAtUtc = ParseDate(entries, "shutdownRequestedAtUtc") + }; + _knownSessions[chatId] = session; + return session; + } + + public async Task IsAliveAsync(long chatId) { + var session = await GetSessionAsync(chatId); + return await IsAliveAsync(session); + } + + public Task RequestShutdownAsync(long chatId, string reason) { + var command = new AgentControlCommand { + ChatId = chatId, + Action = "shutdown", + Reason = reason, + RequestedAtUtc = DateTime.UtcNow + }; + + return _redis.GetDatabase().StringSetAsync( + LlmAgentRedisKeys.AgentControl(chatId), + JsonConvert.SerializeObject(command), + TimeSpan.FromSeconds(Math.Max(Env.AgentShutdownGracePeriodSeconds * 2, 30)), + When.Always); + } + + public async Task IsAliveAsync(AgentSessionInfo? session) { + if (session == null) { + return false; + } + + return DateTime.UtcNow - session.LastHeartbeatUtc <= TimeSpan.FromSeconds(Env.AgentHeartbeatTimeoutSeconds); + } + + public async Task TryKillAsync(long chatId) { + var session = await GetSessionAsync(chatId); + if (session == null || !string.IsNullOrWhiteSpace(session.CurrentTaskId)) { + return false; + } + + if (!_processLauncher.TryKill(session.ProcessId)) { + return false; + } + + await CleanupSessionAsync(chatId); + return true; + } + + public async Task RunMaintenanceOnceAsync(CancellationToken cancellationToken = default) { + var db = _redis.GetDatabase(); + var backlog = await db.ListLengthAsync(LlmAgentRedisKeys.AgentTaskQueue); + if (Env.AgentQueueBacklogWarningThreshold > 0 && backlog >= Env.AgentQueueBacklogWarningThreshold) { + _logger.LogWarning("LLM agent backlog is high: {Backlog}", backlog); + } + + if (!Env.EnableLLMAgentProcess) { + foreach (var chatId in _knownSessions.Keys.ToArray()) { + await RequestShutdownAsync(chatId, "agent mode disabled"); + } + } + + foreach (var entry in _knownSessions.ToArray()) { + var session = await GetSessionAsync(entry.Key); + if (session == null) { + _knownSessions.TryRemove(entry.Key, out _); + continue; + } + + if (!await IsAliveAsync(session)) { + await RecoverSessionAsync(session, "heartbeat timeout", cancellationToken); + continue; + } + + if (await IsTaskTimedOutAsync(session)) { + _processLauncher.TryKill(session.ProcessId); + await RecoverSessionAsync(session, "task timeout", cancellationToken); + continue; + } + + if (ShouldRequestIdleShutdown(session)) { + await RequestShutdownAsync(session.ChatId, "idle timeout"); + session.ShutdownRequestedAtUtc = DateTime.UtcNow; + session.Status = "shutting_down"; + await SaveSessionAsync(session); + } else if (ShouldForceShutdown(session)) { + _processLauncher.TryKill(session.ProcessId); + await CleanupSessionAsync(session.ChatId); + } + } + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { + while (!stoppingToken.IsCancellationRequested) { + await RunMaintenanceOnceAsync(stoppingToken); + await Task.Delay(TimeSpan.FromSeconds(Math.Max(5, Env.AgentHeartbeatIntervalSeconds)), stoppingToken); + } + } + + private async Task CountLiveSessionsAsync() { + var count = 0; + foreach (var chatId in _knownSessions.Keys.ToArray()) { + var session = await GetSessionAsync(chatId); + if (session != null && await IsAliveAsync(session) && !IsShuttingDown(session)) { + count++; + } + } + + return count; + } + + private async Task IsTaskTimedOutAsync(AgentSessionInfo session) { + if (string.IsNullOrWhiteSpace(session.CurrentTaskId) || Env.AgentTaskTimeoutSeconds <= 0) { + return false; + } + + var entries = await _redis.GetDatabase().HashGetAllAsync(LlmAgentRedisKeys.AgentTaskState(session.CurrentTaskId)); + if (entries.Length == 0) { + return false; + } + + var updatedAt = ParseDate(entries, "updatedAtUtc"); + if (updatedAt == DateTime.MinValue) { + updatedAt = ParseDate(entries, "startedAtUtc"); + } + + return updatedAt != DateTime.MinValue && + DateTime.UtcNow - updatedAt > TimeSpan.FromSeconds(Env.AgentTaskTimeoutSeconds); + } + + private async Task RecoverSessionAsync(AgentSessionInfo session, string reason, CancellationToken cancellationToken) { + _logger.LogWarning("Recovering LLM agent session for chat {ChatId}: {Reason}", session.ChatId, reason); + await CleanupSessionAsync(session.ChatId); + + if (string.IsNullOrWhiteSpace(session.CurrentTaskId)) { + return; + } + + var db = _redis.GetDatabase(); + var entries = await db.HashGetAllAsync(LlmAgentRedisKeys.AgentTaskState(session.CurrentTaskId)); + if (entries.Length == 0) { + return; + } + + var status = Parse(entries, "status"); + if (status.Equals(AgentTaskStatus.Completed.ToString(), StringComparison.OrdinalIgnoreCase) || + status.Equals(AgentTaskStatus.Failed.ToString(), StringComparison.OrdinalIgnoreCase) || + status.Equals(AgentTaskStatus.Cancelled.ToString(), StringComparison.OrdinalIgnoreCase)) { + return; + } + + var payload = Parse(entries, "payload"); + var lastContent = Parse(entries, "lastContent"); + var recoveryCount = int.TryParse(Parse(entries, "recoveryCount"), out var parsedRecoveryCount) + ? parsedRecoveryCount + : 0; + var maxRecoveryAttempts = int.TryParse(Parse(entries, "maxRecoveryAttempts"), out var parsedMaxAttempts) + ? parsedMaxAttempts + : Env.AgentMaxRecoveryAttempts; + + if (string.IsNullOrWhiteSpace(payload) || recoveryCount >= maxRecoveryAttempts) { + await db.HashSetAsync(LlmAgentRedisKeys.AgentTaskState(session.CurrentTaskId), [ + new HashEntry("status", AgentTaskStatus.Failed.ToString()), + new HashEntry("error", reason), + new HashEntry("updatedAtUtc", DateTime.UtcNow.ToString("O")) + ]); + await db.ListLeftPushAsync(LlmAgentRedisKeys.AgentTaskDeadLetterQueue, JsonConvert.SerializeObject(new AgentDeadLetterEntry { + TaskId = session.CurrentTaskId, + ChatId = session.ChatId, + Reason = reason, + RecoveryAttempt = recoveryCount, + Payload = payload ?? string.Empty, + LastContent = lastContent ?? string.Empty + })); + return; + } + + var task = JsonConvert.DeserializeObject(payload); + if (task != null) { + task.RecoveryAttempt = recoveryCount + 1; + payload = JsonConvert.SerializeObject(task); + } + + await db.HashSetAsync(LlmAgentRedisKeys.AgentTaskState(session.CurrentTaskId), [ + new HashEntry("status", AgentTaskStatus.Recovering.ToString()), + new HashEntry("error", reason), + new HashEntry("updatedAtUtc", DateTime.UtcNow.ToString("O")), + new HashEntry("recoveryCount", recoveryCount + 1), + new HashEntry("payload", payload) + ]); + await db.ListLeftPushAsync(LlmAgentRedisKeys.AgentTaskQueue, payload); + + if (Env.EnableLLMAgentProcess) { + try { + await EnsureAgentAsync(session.ChatId, cancellationToken); + } catch (Exception ex) { + _logger.LogError(ex, "Failed to respawn LLM agent for chat {ChatId}", session.ChatId); + } + } + } + + private async Task CleanupSessionAsync(long chatId) { + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentSession(chatId)); + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentControl(chatId)); + _knownSessions.TryRemove(chatId, out _); + } + + private async Task SaveSessionAsync(AgentSessionInfo session) { + await _redis.GetDatabase().HashSetAsync(LlmAgentRedisKeys.AgentSession(session.ChatId), [ + new HashEntry("chatId", session.ChatId), + new HashEntry("processId", session.ProcessId), + new HashEntry("port", session.Port), + new HashEntry("status", session.Status), + new HashEntry("currentTaskId", session.CurrentTaskId ?? string.Empty), + new HashEntry("startedAtUtc", session.StartedAtUtc.ToString("O")), + new HashEntry("lastHeartbeatUtc", session.LastHeartbeatUtc.ToString("O")), + new HashEntry("lastActiveAtUtc", session.LastActiveAtUtc.ToString("O")), + new HashEntry("shutdownRequestedAtUtc", session.ShutdownRequestedAtUtc == DateTime.MinValue ? string.Empty : session.ShutdownRequestedAtUtc.ToString("O")), + new HashEntry("error", session.ErrorMessage ?? string.Empty) + ]); + await _redis.GetDatabase().KeyExpireAsync( + LlmAgentRedisKeys.AgentSession(session.ChatId), + TimeSpan.FromSeconds(Math.Max(Env.AgentHeartbeatTimeoutSeconds * 2, 30))); + } + + private static bool IsShuttingDown(AgentSessionInfo session) { + return session.Status.Equals("shutting_down", StringComparison.OrdinalIgnoreCase); + } + + private static bool ShouldRequestIdleShutdown(AgentSessionInfo session) { + return Env.AgentIdleTimeoutMinutes > 0 && + string.IsNullOrWhiteSpace(session.CurrentTaskId) && + !IsShuttingDown(session) && + session.LastActiveAtUtc != DateTime.MinValue && + DateTime.UtcNow - session.LastActiveAtUtc > TimeSpan.FromMinutes(Env.AgentIdleTimeoutMinutes); + } + + private static bool ShouldForceShutdown(AgentSessionInfo session) { + return IsShuttingDown(session) && + session.ShutdownRequestedAtUtc != DateTime.MinValue && + DateTime.UtcNow - session.ShutdownRequestedAtUtc > TimeSpan.FromSeconds(Math.Max(5, Env.AgentShutdownGracePeriodSeconds)); + } + + private static string Parse(HashEntry[] entries, string key) { + return entries.FirstOrDefault(x => x.Name == key).Value.ToString(); + } + + private static int ParseInt(HashEntry[] entries, string key) { + return int.TryParse(Parse(entries, key), out var value) ? value : 0; + } + + private static DateTime ParseDate(HashEntry[] entries, string key) { + return DateTime.TryParse(Parse(entries, key), CultureInfo.InvariantCulture, DateTimeStyles.RoundtripKind, out var value) + ? value + : DateTime.MinValue; + } + } +} diff --git a/TelegramSearchBot/Service/AI/LLM/ChunkPollingService.cs b/TelegramSearchBot/Service/AI/LLM/ChunkPollingService.cs new file mode 100644 index 00000000..ed437738 --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/ChunkPollingService.cs @@ -0,0 +1,139 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Channels; +using Microsoft.Extensions.Hosting; +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramSearchBot.Common; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.Service.AI.LLM { + public sealed class AgentTaskStreamHandle { + private readonly Channel _channel; + + internal AgentTaskStreamHandle(Channel channel, Task completion) { + _channel = channel; + Completion = completion; + } + + public Task Completion { get; } + + public async IAsyncEnumerable ReadSnapshotsAsync( + [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default) { + await foreach (var chunk in _channel.Reader.ReadAllAsync(cancellationToken)) { + if (chunk.Type is AgentChunkType.Snapshot or AgentChunkType.IterationLimitReached) { + yield return chunk.Content; + } + } + } + } + + public sealed class ChunkPollingService : BackgroundService { + private readonly IConnectionMultiplexer _redis; + private readonly ConcurrentDictionary _trackedTasks = new(StringComparer.OrdinalIgnoreCase); + + public ChunkPollingService(IConnectionMultiplexer redis) { + _redis = redis; + } + + public AgentTaskStreamHandle TrackTask(string taskId) { + var tracked = _trackedTasks.GetOrAdd(taskId, _ => new TrackedTask()); + return tracked.Handle; + } + + public async Task RunPollCycleAsync(CancellationToken cancellationToken = default) { + foreach (var entry in _trackedTasks.ToArray()) { + await PollTaskAsync(entry.Key, entry.Value, cancellationToken); + } + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { + while (!stoppingToken.IsCancellationRequested) { + await RunPollCycleAsync(stoppingToken); + await Task.Delay(Math.Max(50, Env.AgentChunkPollingIntervalMilliseconds), stoppingToken); + } + } + + private async Task PollTaskAsync(string taskId, TrackedTask tracked, CancellationToken cancellationToken) { + var values = await _redis.GetDatabase().ListRangeAsync(LlmAgentRedisKeys.AgentChunks(taskId), tracked.NextIndex, -1); + if (values.Length == 0) { + await TryCompleteFromTaskStateAsync(taskId, tracked, cancellationToken); + return; + } + + foreach (var value in values) { + var chunk = JsonConvert.DeserializeObject(value.ToString()); + if (chunk == null) { + tracked.NextIndex++; + continue; + } + + await tracked.Channel.Writer.WriteAsync(chunk, cancellationToken); + tracked.NextIndex++; + await _redis.GetDatabase().StringSetAsync( + LlmAgentRedisKeys.AgentChunkIndex(taskId), + tracked.NextIndex, + TimeSpan.FromHours(1), + When.Always); + + if (chunk.Type is AgentChunkType.Done or AgentChunkType.Error or AgentChunkType.IterationLimitReached) { + tracked.Completion.TrySetResult(chunk); + tracked.Channel.Writer.TryComplete(); + _trackedTasks.TryRemove(taskId, out _); + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentChunkIndex(taskId)); + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentChunks(taskId)); + break; + } + } + } + + private async Task TryCompleteFromTaskStateAsync(string taskId, TrackedTask tracked, CancellationToken cancellationToken) { + var entries = await _redis.GetDatabase().HashGetAllAsync(LlmAgentRedisKeys.AgentTaskState(taskId)); + if (entries.Length == 0) { + return; + } + + var statusEntry = entries.FirstOrDefault(x => x.Name == "status").Value.ToString(); + if (!Enum.TryParse(statusEntry, ignoreCase: true, out var status)) { + return; + } + + if (status == AgentTaskStatus.Failed || status == AgentTaskStatus.Cancelled) { + var error = entries.FirstOrDefault(x => x.Name == "error").Value.ToString(); + await CompleteTrackedTaskAsync(taskId, tracked, new AgentStreamChunk { + TaskId = taskId, + Type = AgentChunkType.Error, + ErrorMessage = string.IsNullOrWhiteSpace(error) ? "Agent task failed." : error + }, cancellationToken); + return; + } + + if (status == AgentTaskStatus.Completed) { + await CompleteTrackedTaskAsync(taskId, tracked, new AgentStreamChunk { + TaskId = taskId, + Type = AgentChunkType.Done + }, cancellationToken); + } + } + + private async Task CompleteTrackedTaskAsync(string taskId, TrackedTask tracked, AgentStreamChunk chunk, CancellationToken cancellationToken) { + await tracked.Channel.Writer.WriteAsync(chunk, cancellationToken); + tracked.Completion.TrySetResult(chunk); + tracked.Channel.Writer.TryComplete(); + _trackedTasks.TryRemove(taskId, out _); + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentChunkIndex(taskId)); + await _redis.GetDatabase().KeyDeleteAsync(LlmAgentRedisKeys.AgentChunks(taskId)); + } + + private sealed class TrackedTask { + public Channel Channel { get; } = System.Threading.Channels.Channel.CreateUnbounded(); + public TaskCompletionSource Completion { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + public long NextIndex { get; set; } + public AgentTaskStreamHandle Handle => new AgentTaskStreamHandle(Channel, Completion.Task); + } + } +} diff --git a/TelegramSearchBot/Service/AI/LLM/IAgentProcessLauncher.cs b/TelegramSearchBot/Service/AI/LLM/IAgentProcessLauncher.cs new file mode 100644 index 00000000..0e1ae863 --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/IAgentProcessLauncher.cs @@ -0,0 +1,9 @@ +using System.Threading; +using System.Threading.Tasks; + +namespace TelegramSearchBot.Service.AI.LLM { + public interface IAgentProcessLauncher { + Task StartAsync(long chatId, CancellationToken cancellationToken = default); + bool TryKill(int processId); + } +} diff --git a/TelegramSearchBot/Service/AI/LLM/LLMTaskQueueService.cs b/TelegramSearchBot/Service/AI/LLM/LLMTaskQueueService.cs new file mode 100644 index 00000000..eef26e2e --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/LLMTaskQueueService.cs @@ -0,0 +1,224 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Newtonsoft.Json; +using StackExchange.Redis; +using TelegramMessage = Telegram.Bot.Types.Message; +using TelegramSearchBot.Attributes; +using TelegramSearchBot.Common; +using TelegramSearchBot.Interface; +using TelegramSearchBot.Model; +using TelegramSearchBot.Model.AI; +using TelegramSearchBot.Model.Data; + +namespace TelegramSearchBot.Service.AI.LLM { + [Injectable(ServiceLifetime.Transient)] + public class LLMTaskQueueService : IService { + private readonly DataDbContext _dbContext; + private readonly IConnectionMultiplexer _redis; + private readonly ChunkPollingService _chunkPollingService; + private readonly AgentRegistryService _agentRegistryService; + + public LLMTaskQueueService( + DataDbContext dbContext, + IConnectionMultiplexer redis, + ChunkPollingService chunkPollingService, + AgentRegistryService agentRegistryService) { + _dbContext = dbContext; + _redis = redis; + _chunkPollingService = chunkPollingService; + _agentRegistryService = agentRegistryService; + } + + public string ServiceName => nameof(LLMTaskQueueService); + + public async Task EnqueueMessageTaskAsync( + TelegramMessage telegramMessage, + string botName, + long botUserId, + CancellationToken cancellationToken = default) { + ArgumentNullException.ThrowIfNull(telegramMessage); + + var task = await BuildMessageTaskAsync(telegramMessage, botName, botUserId, cancellationToken); + await _agentRegistryService.EnsureAgentAsync(task.ChatId, cancellationToken); + return await EnqueueTaskAsync(task); + } + + public async Task EnqueueContinuationTaskAsync( + LlmContinuationSnapshot snapshot, + string botName, + long botUserId, + CancellationToken cancellationToken = default) { + ArgumentNullException.ThrowIfNull(snapshot); + + var channelInfo = await LoadChannelAsync(snapshot.ModelName, snapshot.ChannelId, cancellationToken); + var task = new AgentExecutionTask { + Kind = AgentTaskKind.Continuation, + TaskId = Guid.NewGuid().ToString("N"), + ChatId = snapshot.ChatId, + UserId = snapshot.UserId, + MessageId = snapshot.OriginalMessageId, + BotName = botName, + BotUserId = botUserId, + ModelName = snapshot.ModelName, + MaxToolCycles = Env.MaxToolCycles, + Channel = channelInfo, + ContinuationSnapshot = snapshot + }; + + await _agentRegistryService.EnsureAgentAsync(task.ChatId, cancellationToken); + return await EnqueueTaskAsync(task); + } + + private async Task EnqueueTaskAsync(AgentExecutionTask task) { + var db = _redis.GetDatabase(); + var payload = JsonConvert.SerializeObject(task); + await db.ListLeftPushAsync(LlmAgentRedisKeys.AgentTaskQueue, payload); + await db.HashSetAsync(LlmAgentRedisKeys.AgentTaskState(task.TaskId), [ + new HashEntry("status", AgentTaskStatus.Pending.ToString()), + new HashEntry("chatId", task.ChatId), + new HashEntry("messageId", task.MessageId), + new HashEntry("modelName", task.ModelName), + new HashEntry("createdAtUtc", task.CreatedAtUtc.ToString("O")), + new HashEntry("updatedAtUtc", DateTime.UtcNow.ToString("O")), + new HashEntry("payload", payload), + new HashEntry("recoveryCount", 0), + new HashEntry("maxRecoveryAttempts", Env.AgentMaxRecoveryAttempts), + new HashEntry("lastContent", string.Empty) + ]); + + return _chunkPollingService.TrackTask(task.TaskId); + } + + private async Task BuildMessageTaskAsync( + TelegramMessage telegramMessage, + string botName, + long botUserId, + CancellationToken cancellationToken) { + var modelName = await _dbContext.GroupSettings.AsNoTracking() + .Where(x => x.GroupId == telegramMessage.Chat.Id) + .Select(x => x.LLMModelName) + .FirstOrDefaultAsync(cancellationToken); + + if (string.IsNullOrWhiteSpace(modelName)) { + throw new InvalidOperationException("请先为当前群组设置模型。"); + } + + var channelInfo = await LoadChannelAsync(modelName, null, cancellationToken); + var history = await LoadHistoryAsync(telegramMessage.Chat.Id, cancellationToken); + return new AgentExecutionTask { + TaskId = Guid.NewGuid().ToString("N"), + Kind = AgentTaskKind.Message, + ChatId = telegramMessage.Chat.Id, + UserId = telegramMessage.From?.Id ?? 0, + MessageId = telegramMessage.MessageId, + BotName = botName, + BotUserId = botUserId, + ModelName = modelName, + InputMessage = string.IsNullOrWhiteSpace(telegramMessage.Text) ? telegramMessage.Caption ?? string.Empty : telegramMessage.Text, + MaxToolCycles = Env.MaxToolCycles, + Channel = channelInfo, + History = history, + CreatedAtUtc = telegramMessage.Date.ToUniversalTime() + }; + } + + private async Task LoadChannelAsync(string modelName, int? channelId, CancellationToken cancellationToken) { + var query = _dbContext.ChannelsWithModel.AsNoTracking() + .Include(x => x.LLMChannel) + .Include(x => x.Capabilities) + .Where(x => !x.IsDeleted && x.ModelName == modelName); + + if (channelId.HasValue) { + query = query.Where(x => x.LLMChannelId == channelId.Value); + } + + var channelWithModel = await query + .OrderByDescending(x => x.LLMChannel.Priority) + .FirstOrDefaultAsync(cancellationToken); + + if (channelWithModel?.LLMChannel == null) { + throw new InvalidOperationException($"找不到模型 {modelName} 可用的渠道配置。"); + } + + return new AgentChannelConfig { + ChannelId = channelWithModel.LLMChannel.Id, + Name = channelWithModel.LLMChannel.Name, + Gateway = channelWithModel.LLMChannel.Gateway, + ApiKey = channelWithModel.LLMChannel.ApiKey, + Provider = channelWithModel.LLMChannel.Provider, + Parallel = channelWithModel.LLMChannel.Parallel, + Priority = channelWithModel.LLMChannel.Priority, + ModelName = channelWithModel.ModelName, + Capabilities = channelWithModel.Capabilities + .Select(x => new AgentModelCapability { + Name = x.CapabilityName, + Value = x.CapabilityValue, + Description = x.Description ?? string.Empty + }) + .ToList() + }; + } + + private async Task> LoadHistoryAsync(long chatId, CancellationToken cancellationToken) { + var history = await _dbContext.Messages.AsNoTracking() + .Where(x => x.GroupId == chatId && x.DateTime > DateTime.UtcNow.AddHours(-1)) + .OrderBy(x => x.DateTime) + .ToListAsync(cancellationToken); + + if (history.Count < 10) { + history = await _dbContext.Messages.AsNoTracking() + .Where(x => x.GroupId == chatId) + .OrderByDescending(x => x.DateTime) + .Take(10) + .OrderBy(x => x.DateTime) + .ToListAsync(cancellationToken); + } + + var userIds = history.Select(x => x.FromUserId).Distinct().ToList(); + var users = await _dbContext.UserData.AsNoTracking() + .Where(x => userIds.Contains(x.Id)) + .ToDictionaryAsync(x => x.Id, cancellationToken); + var messageIds = history.Select(x => x.Id).ToList(); + var extensionRecords = await _dbContext.MessageExtensions.AsNoTracking() + .Where(x => messageIds.Contains(x.MessageDataId)) + .ToListAsync(cancellationToken); + var extensions = extensionRecords + .GroupBy(x => x.MessageDataId) + .ToDictionary( + x => x.Key, + x => x.Select(e => new AgentMessageExtensionSnapshot { + Name = e.Name, + Value = e.Value + }).ToList()); + + return history.Select(message => { + users.TryGetValue(message.FromUserId, out var user); + extensions.TryGetValue(message.Id, out var messageExtensions); + return new AgentHistoryMessage { + DataId = message.Id, + DateTime = message.DateTime, + GroupId = message.GroupId, + MessageId = message.MessageId, + FromUserId = message.FromUserId, + ReplyToUserId = message.ReplyToUserId, + ReplyToMessageId = message.ReplyToMessageId, + Content = message.Content ?? string.Empty, + User = new AgentUserSnapshot { + UserId = user?.Id ?? message.FromUserId, + FirstName = user?.FirstName ?? string.Empty, + LastName = user?.LastName ?? string.Empty, + UserName = user?.UserName ?? string.Empty, + IsBot = user?.IsBot, + IsPremium = user?.IsPremium + }, + Extensions = messageExtensions ?? [] + }; + }).ToList(); + } + } +} diff --git a/TelegramSearchBot/Service/AI/LLM/LlmAgentProcessLauncher.cs b/TelegramSearchBot/Service/AI/LLM/LlmAgentProcessLauncher.cs new file mode 100644 index 00000000..e56c8f18 --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/LlmAgentProcessLauncher.cs @@ -0,0 +1,44 @@ +using System; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using TelegramSearchBot.AppBootstrap; +using TelegramSearchBot.Common; + +namespace TelegramSearchBot.Service.AI.LLM { + public sealed class LlmAgentProcessLauncher : IAgentProcessLauncher { + public Task StartAsync(long chatId, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + + var dllPath = Path.Combine(AppContext.BaseDirectory, "TelegramSearchBot.LLMAgent.dll"); + if (!File.Exists(dllPath)) { + throw new FileNotFoundException("LLMAgent executable not found.", dllPath); + } + + var process = AppBootstrap.AppBootstrap.Fork( + "dotnet", + [dllPath, chatId.ToString(), Env.SchedulerPort.ToString()], + GetMemoryLimitBytes()); + return Task.FromResult(process.Id); + } + + public bool TryKill(int processId) { + try { + using var process = Process.GetProcessById(processId); + process.Kill(true); + return true; + } catch { + return false; + } + } + + private static long? GetMemoryLimitBytes() { + if (Env.AgentProcessMemoryLimitMb <= 0) { + return null; + } + + return (long)Env.AgentProcessMemoryLimitMb * 1024L * 1024L; + } + } +} diff --git a/TelegramSearchBot/Service/AI/LLM/TelegramTaskConsumer.cs b/TelegramSearchBot/Service/AI/LLM/TelegramTaskConsumer.cs new file mode 100644 index 00000000..ea418ee2 --- /dev/null +++ b/TelegramSearchBot/Service/AI/LLM/TelegramTaskConsumer.cs @@ -0,0 +1,87 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using StackExchange.Redis; +using Telegram.Bot; +using TelegramSearchBot.Common; +using TelegramSearchBot.Manager; +using TelegramSearchBot.Model.AI; + +namespace TelegramSearchBot.Service.AI.LLM { + public sealed class TelegramTaskConsumer : BackgroundService { + private readonly IConnectionMultiplexer _redis; + private readonly ITelegramBotClient _botClient; + private readonly SendMessage _sendMessage; + private readonly ILogger _logger; + + public TelegramTaskConsumer( + IConnectionMultiplexer redis, + ITelegramBotClient botClient, + SendMessage sendMessage, + ILogger logger) { + _redis = redis; + _botClient = botClient; + _sendMessage = sendMessage; + _logger = logger; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { + while (!stoppingToken.IsCancellationRequested) { + var result = await _redis.GetDatabase().ExecuteAsync("BRPOP", LlmAgentRedisKeys.TelegramTaskQueue, 5); + if (result.IsNull) { + continue; + } + + var parts = (RedisResult[])result!; + if (parts.Length != 2) { + continue; + } + + var payload = parts[1].ToString(); + if (string.IsNullOrWhiteSpace(payload)) { + continue; + } + + var task = JsonConvert.DeserializeObject(payload); + if (task == null) { + continue; + } + + var response = new TelegramAgentToolResult { + RequestId = task.RequestId, + Success = false + }; + + try { + if (!task.ToolName.Equals("send_message", StringComparison.OrdinalIgnoreCase)) { + throw new InvalidOperationException($"Unsupported telegram tool: {task.ToolName}"); + } + + if (!task.Arguments.TryGetValue("text", out var text) || string.IsNullOrWhiteSpace(text)) { + throw new InvalidOperationException("send_message 缺少 text 参数。"); + } + + var chatId = task.Arguments.TryGetValue("chatId", out var chatIdString) && long.TryParse(chatIdString, out var parsedChatId) + ? parsedChatId + : task.ChatId; + + var sent = await _sendMessage.AddTaskWithResult(() => _botClient.SendMessage(chatId, text, cancellationToken: stoppingToken), chatId); + response.Success = true; + response.TelegramMessageId = sent.MessageId; + response.Result = sent.MessageId.ToString(); + } catch (Exception ex) { + _logger.LogError(ex, "Failed to execute telegram task {RequestId}", task.RequestId); + response.ErrorMessage = ex.Message; + } + + await _redis.GetDatabase().StringSetAsync( + LlmAgentRedisKeys.TelegramResult(task.RequestId), + JsonConvert.SerializeObject(response), + TimeSpan.FromMinutes(5)); + } + } + } +} diff --git a/TelegramSearchBot/TelegramSearchBot.csproj b/TelegramSearchBot/TelegramSearchBot.csproj index c4396be4..48e775c7 100644 --- a/TelegramSearchBot/TelegramSearchBot.csproj +++ b/TelegramSearchBot/TelegramSearchBot.csproj @@ -95,6 +95,8 @@ + +