diff --git a/BotSharp.sln b/BotSharp.sln index 102137084..93f90cae1 100644 --- a/BotSharp.sln +++ b/BotSharp.sln @@ -151,6 +151,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.ImageHandle EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.FuzzySharp", "src\Plugins\BotSharp.Plugin.FuzzySharp\BotSharp.Plugin.FuzzySharp.csproj", "{E7C243B9-E751-B3B4-8F16-95C76CA90D31}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.MMPEmbedding", "src\Plugins\BotSharp.Plugin.MMPEmbedding\BotSharp.Plugin.MMPEmbedding.csproj", "{394B858B-9C26-B977-A2DA-8CC7BE5914CB}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -639,6 +641,14 @@ Global {E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|Any CPU.Build.0 = Release|Any CPU {E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.ActiveCfg = Release|Any CPU {E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.Build.0 = Release|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.ActiveCfg = Debug|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.Build.0 = Debug|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.Build.0 = Release|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.ActiveCfg = Release|Any CPU + {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -712,6 +722,7 @@ Global {FC63C875-E880-D8BB-B8B5-978AB7B62983} = {51AFE054-AE99-497D-A593-69BAEFB5106F} {242F2D93-FCCE-4982-8075-F3052ECCA92C} = {51AFE054-AE99-497D-A593-69BAEFB5106F} {E7C243B9-E751-B3B4-8F16-95C76CA90D31} = {51AFE054-AE99-497D-A593-69BAEFB5106F} + {394B858B-9C26-B977-A2DA-8CC7BE5914CB} = {2635EC9B-2E5F-4313-AC21-0B847F31F36C} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {A9969D89-C98B-40A5-A12B-FC87E55B3A19} diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj b/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj new file mode 100644 index 000000000..901ebb687 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj @@ -0,0 +1,18 @@ + + + + $(TargetFramework) + enable + enable + + + + + + + + + + + + diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs new file mode 100644 index 000000000..26a5c1538 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs @@ -0,0 +1,19 @@ +using BotSharp.Abstraction.Plugins; +using BotSharp.Plugin.MMPEmbedding.Providers; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; + +namespace BotSharp.Plugin.MMPEmbedding +{ + public class MMPEmbeddingPlugin : IBotSharpPlugin + { + public string Id => "54d04e10-fc84-493e-a8c9-39da1c83f45a"; + public string Name => "MMPEmbedding"; + public string Description => "MMP Embedding Service"; + + public void RegisterDI(IServiceCollection services, IConfiguration config) + { + services.AddScoped(); + } + } +} diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs new file mode 100644 index 000000000..4fa7f6f3c --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs @@ -0,0 +1,70 @@ +using OpenAI; +using Azure.AI.OpenAI; +using System.ClientModel; +using Microsoft.Extensions.DependencyInjection; + +namespace BotSharp.Plugin.MMPEmbedding; + +/// +/// Helper class to get the appropriate client based on provider type +/// Supports multiple providers: OpenAI, Azure OpenAI, DeepSeek, etc. +/// +public static class ProviderHelper +{ + /// + /// Gets an OpenAI-compatible client based on the provider name + /// + /// Provider name (e.g., "openai", "azure-openai") + /// Model name + /// Service provider for dependency injection + /// OpenAIClient instance configured for the specified provider + public static OpenAIClient GetClient(string provider, string model, IServiceProvider services) + { + var settingsService = services.GetRequiredService(); + var settings = settingsService.GetSetting(provider, model); + + if (settings == null) + { + throw new InvalidOperationException($"Cannot find settings for provider '{provider}' and model '{model}'"); + } + + // Handle Azure OpenAI separately as it uses AzureOpenAIClient + if (provider.Equals("azure-openai", StringComparison.OrdinalIgnoreCase)) + { + return GetAzureOpenAIClient(settings); + } + + // For OpenAI, DeepSeek, and other OpenAI-compatible providers + return GetOpenAICompatibleClient(settings); + } + + /// + /// Gets an Azure OpenAI client + /// + private static OpenAIClient GetAzureOpenAIClient(LlmModelSetting settings) + { + if (string.IsNullOrEmpty(settings.Endpoint)) + { + throw new InvalidOperationException("Azure OpenAI endpoint is required"); + } + + var client = new AzureOpenAIClient( + new Uri(settings.Endpoint), + new ApiKeyCredential(settings.ApiKey) + ); + + return client; + } + + /// + /// Gets an OpenAI-compatible client (OpenAI, DeepSeek, etc.) + /// + private static OpenAIClient GetOpenAICompatibleClient(LlmModelSetting settings) + { + var options = !string.IsNullOrEmpty(settings.Endpoint) + ? new OpenAIClientOptions { Endpoint = new Uri(settings.Endpoint) } + : null; + + return new OpenAIClient(new ApiKeyCredential(settings.ApiKey), options); + } +} diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs new file mode 100644 index 000000000..9d1054096 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs @@ -0,0 +1,167 @@ +using System.Collections.Generic; +using System.Text.RegularExpressions; +using BotSharp.Plugin.MMPEmbedding; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using OpenAI.Embeddings; + +namespace BotSharp.Plugin.MMPEmbedding.Providers; + +/// +/// Text embedding provider that uses Mean-Max Pooling strategy +/// This provider gets embeddings for individual tokens and combines them using mean and max pooling +/// +public class MMPEmbeddingProvider : ITextEmbedding +{ + protected readonly IServiceProvider _serviceProvider; + protected readonly ILogger _logger; + + private const int DEFAULT_DIMENSION = 1536; + protected string _model = "text-embedding-3-small"; + protected int _dimension = DEFAULT_DIMENSION; + + // The underlying provider to use (e.g., "openai", "azure-openai", "deepseek-ai") + protected string _underlyingProvider = "openai"; + + public string Provider => "mmp-embedding"; + public string Model => _model; + + private static readonly Regex WordRegex = new(@"\b\w+\b", RegexOptions.Compiled); + + public MMPEmbeddingProvider(IServiceProvider serviceProvider, ILogger logger) + { + _serviceProvider = serviceProvider; + _logger = logger; + } + + /// + /// Gets a single embedding vector using mean-max pooling + /// + public async Task GetVectorAsync(string text) + { + if (string.IsNullOrWhiteSpace(text)) + { + return new float[_dimension]; + } + + var tokens = Tokenize(text).ToList(); + + if (tokens.Count == 0) + { + return new float[_dimension]; + } + + // Get embeddings for all tokens + var tokenEmbeddings = await GetTokenEmbeddingsAsync(tokens); + + // Apply mean-max pooling + var pooledEmbedding = MeanMaxPooling(tokenEmbeddings); + + return pooledEmbedding; + } + + /// + /// Gets multiple embedding vectors using mean-max pooling + /// + public async Task> GetVectorsAsync(List texts) + { + var results = new List(); + + foreach (var text in texts) + { + var embedding = await GetVectorAsync(text); + results.Add(embedding); + } + + return results; + } + + /// + /// Gets embeddings for individual tokens using the underlying provider + /// + private async Task> GetTokenEmbeddingsAsync(List tokens) + { + try + { + // Get the appropriate client based on the underlying provider + var client = ProviderHelper.GetClient(_underlyingProvider, _model, _serviceProvider); + var embeddingClient = client.GetEmbeddingClient(_model); + + // Prepare options + var options = new EmbeddingGenerationOptions + { + Dimensions = _dimension > 0 ? _dimension : null + }; + + // Get embeddings for all tokens in batch + var response = await embeddingClient.GenerateEmbeddingsAsync(tokens, options); + var embeddings = response.Value; + + return embeddings.Select(e => e.ToFloats().ToArray()).ToList(); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error getting token embeddings from provider {Provider} with model {Model}", + _underlyingProvider, _model); + throw; + } + } + + /// + /// Applies mean-max pooling to combine token embeddings + /// Mean pooling: average of all token embeddings + /// Max pooling: element-wise maximum of all token embeddings + /// Result: concatenation of mean and max pooled vectors + /// + private float[] MeanMaxPooling(IReadOnlyList vectors, double meanWeight = 0.5, double maxWeight = 0.5) + { + var numTokens = vectors.Count; + + if (numTokens == 0) + return []; + + var meanPooled = Enumerable.Range(0, _dimension) + .Select(i => vectors.Average(v => v[i])) + .ToArray(); + var maxPooled = Enumerable.Range(0, _dimension) + .Select(i => vectors.Max(v => v[i])) + .ToArray(); + + return Enumerable.Range(0, _dimension) + .Select(i => (float)meanWeight * meanPooled[i] + (float)maxWeight * maxPooled[i]) + .ToArray(); + } + + public void SetDimension(int dimension) + { + _dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION; + } + + public int GetDimension() + { + return _dimension; + } + + public void SetModelName(string model) + { + _model = model; + } + + /// + /// Sets the underlying provider to use for getting token embeddings + /// + /// Provider name (e.g., "openai", "azure-openai", "deepseek-ai") + public void SetUnderlyingProvider(string provider) + { + _underlyingProvider = provider; + } + + /// + /// Tokenizes text into individual words + /// + public static IEnumerable Tokenize(string text, string? pattern = null) + { + var patternRegex = string.IsNullOrEmpty(pattern) ? WordRegex : new(pattern, RegexOptions.Compiled); + return patternRegex.Matches(text).Cast().Select(m => m.Value); + } +} diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs new file mode 100644 index 000000000..70cbe34f5 --- /dev/null +++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs @@ -0,0 +1,10 @@ +global using System; +global using System.Collections.Generic; +global using System.Linq; +global using System.Text; +global using System.Threading.Tasks; + +global using BotSharp.Abstraction.MLTasks; +global using BotSharp.Abstraction.MLTasks.Settings; +global using Microsoft.Extensions.Logging; +