From 8af9552a908293fc559e263282dc5439b153157c Mon Sep 17 00:00:00 2001
From: Wenbo Cao <104199@smsassist.com>
Date: Tue, 18 Nov 2025 11:10:06 -0600
Subject: [PATCH] Add MMP embedding
---
BotSharp.sln | 11 ++
.../BotSharp.Plugin.MMPEmbedding.csproj | 18 ++
.../MMPEmbeddingPlugin.cs | 19 ++
.../ProviderHelper.cs | 70 ++++++++
.../Providers/MMPEmbeddingProvider.cs | 167 ++++++++++++++++++
.../BotSharp.Plugin.MMPEmbedding/Using.cs | 10 ++
6 files changed, 295 insertions(+)
create mode 100644 src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj
create mode 100644 src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs
create mode 100644 src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs
create mode 100644 src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs
create mode 100644 src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs
diff --git a/BotSharp.sln b/BotSharp.sln
index e992d26ad..22b188693 100644
--- a/BotSharp.sln
+++ b/BotSharp.sln
@@ -149,6 +149,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.ExcelHandle
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.ImageHandler", "src\Plugins\BotSharp.Plugin.ImageHandler\BotSharp.Plugin.ImageHandler.csproj", "{242F2D93-FCCE-4982-8075-F3052ECCA92C}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.MMPEmbedding", "src\Plugins\BotSharp.Plugin.MMPEmbedding\BotSharp.Plugin.MMPEmbedding.csproj", "{394B858B-9C26-B977-A2DA-8CC7BE5914CB}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -629,6 +631,14 @@ Global
{242F2D93-FCCE-4982-8075-F3052ECCA92C}.Release|Any CPU.Build.0 = Release|Any CPU
{242F2D93-FCCE-4982-8075-F3052ECCA92C}.Release|x64.ActiveCfg = Release|Any CPU
{242F2D93-FCCE-4982-8075-F3052ECCA92C}.Release|x64.Build.0 = Release|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.Build.0 = Debug|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.Build.0 = Release|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.ActiveCfg = Release|Any CPU
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -701,6 +711,7 @@ Global
{0428DEAA-E4FE-4259-A6D8-6EDD1A9D0702} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
{FC63C875-E880-D8BB-B8B5-978AB7B62983} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
{242F2D93-FCCE-4982-8075-F3052ECCA92C} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
+ {394B858B-9C26-B977-A2DA-8CC7BE5914CB} = {2635EC9B-2E5F-4313-AC21-0B847F31F36C}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {A9969D89-C98B-40A5-A12B-FC87E55B3A19}
diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj b/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj
new file mode 100644
index 000000000..901ebb687
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj
@@ -0,0 +1,18 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>$(TargetFramework)</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <PackageReference Include="Azure.AI.OpenAI" />
+  </ItemGroup>
+
+  <ItemGroup>
+    <ProjectReference Include="..\..\Infrastructure\BotSharp.Abstraction\BotSharp.Abstraction.csproj" />
+    <ProjectReference Include="..\..\Infrastructure\BotSharp.Core\BotSharp.Core.csproj" />
+  </ItemGroup>
+
+</Project>
diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs
new file mode 100644
index 000000000..26a5c1538
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs
@@ -0,0 +1,19 @@
+using BotSharp.Abstraction.Plugins;
+using BotSharp.Plugin.MMPEmbedding.Providers;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace BotSharp.Plugin.MMPEmbedding
+{
+    public class MMPEmbeddingPlugin : IBotSharpPlugin
+    {
+        public string Id => "54d04e10-fc84-493e-a8c9-39da1c83f45a";
+        public string Name => "MMPEmbedding";
+        public string Description => "MMP Embedding Service";
+
+        public void RegisterDI(IServiceCollection services, IConfiguration config)
+        {
+            // Register the mean-max-pooling provider as an ITextEmbedding implementation.
+            services.AddScoped<ITextEmbedding, MMPEmbeddingProvider>();
+        }
+    }
+}
diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs
new file mode 100644
index 000000000..4fa7f6f3c
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs
@@ -0,0 +1,70 @@
+using OpenAI;
+using Azure.AI.OpenAI;
+using System.ClientModel;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace BotSharp.Plugin.MMPEmbedding;
+
+/// <summary>
+/// Helper class to get the appropriate client based on provider type.
+/// Supports multiple providers: OpenAI, Azure OpenAI, DeepSeek, etc.
+/// </summary>
+public static class ProviderHelper
+{
+    /// <summary>
+    /// Gets an OpenAI-compatible client based on the provider name.
+    /// </summary>
+    /// <param name="provider">Provider name (e.g., "openai", "azure-openai")</param>
+    /// <param name="model">Model name</param>
+    /// <param name="services">Service provider for dependency injection</param>
+    /// <returns>OpenAIClient instance configured for the specified provider</returns>
+    public static OpenAIClient GetClient(string provider, string model, IServiceProvider services)
+    {
+        var settingsService = services.GetRequiredService<ILlmProviderService>();
+        var settings = settingsService.GetSetting(provider, model);
+
+        if (settings == null)
+        {
+            throw new InvalidOperationException($"Cannot find settings for provider '{provider}' and model '{model}'");
+        }
+
+        // Handle Azure OpenAI separately as it uses AzureOpenAIClient
+        if (provider.Equals("azure-openai", StringComparison.OrdinalIgnoreCase))
+        {
+            return GetAzureOpenAIClient(settings);
+        }
+
+        // For OpenAI, DeepSeek, and other OpenAI-compatible providers
+        return GetOpenAICompatibleClient(settings);
+    }
+
+    /// <summary>
+    /// Gets an Azure OpenAI client.
+    /// </summary>
+    private static OpenAIClient GetAzureOpenAIClient(LlmModelSetting settings)
+    {
+        if (string.IsNullOrEmpty(settings.Endpoint))
+        {
+            throw new InvalidOperationException("Azure OpenAI endpoint is required");
+        }
+
+        var client = new AzureOpenAIClient(
+            new Uri(settings.Endpoint),
+            new ApiKeyCredential(settings.ApiKey)
+        );
+
+        return client;
+    }
+
+    /// <summary>
+    /// Gets an OpenAI-compatible client (OpenAI, DeepSeek, etc.).
+    /// </summary>
+    private static OpenAIClient GetOpenAICompatibleClient(LlmModelSetting settings)
+    {
+        var options = !string.IsNullOrEmpty(settings.Endpoint)
+            ? new OpenAIClientOptions { Endpoint = new Uri(settings.Endpoint) }
+            : null;
+
+        return new OpenAIClient(new ApiKeyCredential(settings.ApiKey), options);
+    }
+}
diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs
new file mode 100644
index 000000000..9d1054096
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Providers/MMPEmbeddingProvider.cs
@@ -0,0 +1,167 @@
+using System.Collections.Generic;
+using System.Text.RegularExpressions;
+using BotSharp.Plugin.MMPEmbedding;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using OpenAI.Embeddings;
+
+namespace BotSharp.Plugin.MMPEmbedding.Providers;
+
+/// <summary>
+/// Text embedding provider that uses a Mean-Max Pooling strategy.
+/// This provider gets embeddings for individual tokens and combines them using mean and max pooling.
+/// </summary>
+public class MMPEmbeddingProvider : ITextEmbedding
+{
+    protected readonly IServiceProvider _serviceProvider;
+    protected readonly ILogger<MMPEmbeddingProvider> _logger;
+
+    private const int DEFAULT_DIMENSION = 1536;
+    protected string _model = "text-embedding-3-small";
+    protected int _dimension = DEFAULT_DIMENSION;
+
+    // The underlying provider to use (e.g., "openai", "azure-openai", "deepseek-ai")
+    protected string _underlyingProvider = "openai";
+
+    public string Provider => "mmp-embedding";
+    public string Model => _model;
+
+    private static readonly Regex WordRegex = new(@"\b\w+\b", RegexOptions.Compiled);
+
+    public MMPEmbeddingProvider(IServiceProvider serviceProvider, ILogger<MMPEmbeddingProvider> logger)
+    {
+        _serviceProvider = serviceProvider;
+        _logger = logger;
+    }
+
+    /// <summary>
+    /// Gets a single embedding vector using mean-max pooling.
+    /// </summary>
+    public async Task<float[]> GetVectorAsync(string text)
+    {
+        if (string.IsNullOrWhiteSpace(text))
+        {
+            return new float[_dimension];
+        }
+
+        var tokens = Tokenize(text).ToList();
+
+        if (tokens.Count == 0)
+        {
+            return new float[_dimension];
+        }
+
+        // Get embeddings for all tokens
+        var tokenEmbeddings = await GetTokenEmbeddingsAsync(tokens);
+
+        // Apply mean-max pooling
+        var pooledEmbedding = MeanMaxPooling(tokenEmbeddings);
+
+        return pooledEmbedding;
+    }
+
+    /// <summary>
+    /// Gets multiple embedding vectors using mean-max pooling.
+    /// </summary>
+    public async Task<List<float[]>> GetVectorsAsync(List<string> texts)
+    {
+        var results = new List<float[]>();
+
+        foreach (var text in texts)
+        {
+            var embedding = await GetVectorAsync(text);
+            results.Add(embedding);
+        }
+
+        return results;
+    }
+
+    /// <summary>
+    /// Gets embeddings for individual tokens using the underlying provider.
+    /// </summary>
+    private async Task<List<float[]>> GetTokenEmbeddingsAsync(List<string> tokens)
+    {
+        try
+        {
+            // Get the appropriate client based on the underlying provider
+            var client = ProviderHelper.GetClient(_underlyingProvider, _model, _serviceProvider);
+            var embeddingClient = client.GetEmbeddingClient(_model);
+
+            // Prepare options
+            var options = new EmbeddingGenerationOptions
+            {
+                Dimensions = _dimension > 0 ? _dimension : null
+            };
+
+            // Get embeddings for all tokens in batch
+            var response = await embeddingClient.GenerateEmbeddingsAsync(tokens, options);
+            var embeddings = response.Value;
+
+            return embeddings.Select(e => e.ToFloats().ToArray()).ToList();
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError(ex, "Error getting token embeddings from provider {Provider} with model {Model}",
+                _underlyingProvider, _model);
+            throw;
+        }
+    }
+
+    /// <summary>
+    /// Applies mean-max pooling to combine token embeddings.
+    /// Mean pooling: average of all token embeddings.
+    /// Max pooling: element-wise maximum of all token embeddings.
+    /// Result: element-wise weighted sum of the mean- and max-pooled vectors (meanWeight/maxWeight).
+    /// </summary>
+    private float[] MeanMaxPooling(IReadOnlyList<float[]> vectors, double meanWeight = 0.5, double maxWeight = 0.5)
+    {
+        var numTokens = vectors.Count;
+
+        if (numTokens == 0)
+            return [];
+
+        var meanPooled = Enumerable.Range(0, _dimension)
+            .Select(i => vectors.Average(v => v[i]))
+            .ToArray();
+        var maxPooled = Enumerable.Range(0, _dimension)
+            .Select(i => vectors.Max(v => v[i]))
+            .ToArray();
+
+        return Enumerable.Range(0, _dimension)
+            .Select(i => (float)meanWeight * meanPooled[i] + (float)maxWeight * maxPooled[i])
+            .ToArray();
+    }
+
+    public void SetDimension(int dimension)
+    {
+        _dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION;
+    }
+
+    public int GetDimension()
+    {
+        return _dimension;
+    }
+
+    public void SetModelName(string model)
+    {
+        _model = model;
+    }
+
+    /// <summary>
+    /// Sets the underlying provider to use for getting token embeddings.
+    /// </summary>
+    /// <param name="provider">Provider name (e.g., "openai", "azure-openai", "deepseek-ai")</param>
+    public void SetUnderlyingProvider(string provider)
+    {
+        _underlyingProvider = provider;
+    }
+
+    /// <summary>
+    /// Tokenizes text into individual words.
+    /// </summary>
+    public static IEnumerable<string> Tokenize(string text, string? pattern = null)
+    {
+        Regex patternRegex = string.IsNullOrEmpty(pattern) ? WordRegex : new(pattern, RegexOptions.Compiled);
+        return patternRegex.Matches(text).Cast<Match>().Select(m => m.Value);
+    }
+}
diff --git a/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs
new file mode 100644
index 000000000..70cbe34f5
--- /dev/null
+++ b/src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs
@@ -0,0 +1,10 @@
+global using System;
+global using System.Collections.Generic;
+global using System.Linq;
+global using System.Text;
+global using System.Threading.Tasks;
+
+global using BotSharp.Abstraction.MLTasks;
+global using BotSharp.Abstraction.MLTasks.Settings;
+global using Microsoft.Extensions.Logging;
+