Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using BotSharp.Abstraction.MLTasks.Settings;

namespace BotSharp.Abstraction.MLTasks.Filters;

public class LlmConfigFilter
{
public List<string>? Providers { get; set; }
public List<string>? ModelIds { get; set; }
public List<string>? ModelNames { get; set; }
public List<LlmModelType>? ModelTypes { get; set; }
public List<LlmModelCapability>? ModelCapabilities { get; set; }
public bool? MultiModal { get; set; }
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using BotSharp.Abstraction.MLTasks.Filters;
using BotSharp.Abstraction.MLTasks.Settings;

namespace BotSharp.Abstraction.MLTasks;

public interface ILlmProviderService
{
LlmModelSetting GetSetting(string provider, string model);
LlmModelSetting? GetSetting(string provider, string model);
List<string> GetProviders();
LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null, LlmModelType? modelType = null, bool imageGenerate = false);
LlmModelSetting? GetProviderModel(string provider, string id, bool? multiModal = null, LlmModelType? modelType = null, IEnumerable<LlmModelCapability>? capabilities = null);
List<LlmModelSetting> GetProviderModels(string provider);
List<LlmProviderSetting> GetLlmConfigs(LlmConfigOptions? options = null);
List<LlmProviderSetting> GetLlmConfigs(LlmConfigFilter? filter = null);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,13 @@ public class LlmModelSetting
public string ApiKey { get; set; } = null!;
public string? Endpoint { get; set; }
public LlmModelType Type { get; set; } = LlmModelType.Chat;
public List<LlmModelCapability> Capabilities { get; set; } = [];

/// <summary>
/// If true, allow sending images/vidoes to this model
/// If true, allow sending images/videos to this model
/// </summary>
public bool MultiModal { get; set; }

/// <summary>
/// If true, allow generating images
/// </summary>
public bool ImageGeneration { get; set; }

/// <summary>
/// Settings for embedding
/// </summary>
Expand Down Expand Up @@ -173,10 +169,29 @@ public class LlmCostSetting

public enum LlmModelType
{
All = 0,
Text = 1,
Chat = 2,
Image = 3,
Embedding = 4,
Audio = 5,
Realtime = 6,
Web = 7
}

public enum LlmModelCapability
{
All = 0,
Text = 1,
Chat = 2,
ImageReading = 3,
ImageGeneration = 4,
ImageEdit = 5,
ImageVariation = 6,
Embedding = 7,
AudioTranscription = 8,
AudioGeneration = 9,
Realtime = 10,
WebSearch = 11,
PdfReading = 12
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
await RunCronChecker(scope.ServiceProvider);
await Task.Delay(1000, stoppingToken);
});
if (isLocked == false)

if (!isLocked)
{
await Task.Delay(1000, stoppingToken);
}
Expand Down
10 changes: 8 additions & 2 deletions src/Infrastructure/BotSharp.Core/Agents/Hooks/BasicAgentHook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,18 @@ public override void OnAgentUtilityLoaded(Agent agent)
foreach (var utility in innerUtilities)
{
var isVisible = agentService.RenderVisibility(utility.VisibilityExpression, renderDict);
if (!isVisible || utility.Items.IsNullOrEmpty()) continue;
if (!isVisible || utility.Items.IsNullOrEmpty())
{
continue;
}

foreach (var item in utility.Items)
{
isVisible = agentService.RenderVisibility(item.VisibilityExpression, renderDict);
if (!isVisible) continue;
if (!isVisible)
{
continue;
}

if (item.FunctionName?.StartsWith(UTIL_PREFIX) == true)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public bool RenderVisibility(string? visibilityExpression, IDictionary<string, o
}

var render = _services.GetRequiredService<ITemplateRender>();
var copy = new Dictionary<string, object>(dict);
var copy = dict != null ? new Dictionary<string, object>(dict) : [];
var result = render.Render(visibilityExpression, new Dictionary<string, object>
{
{ "states", copy }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ namespace BotSharp.Core.Infrastructures;

public class CompletionProvider
{
public static object GetCompletion(IServiceProvider services,
public static object GetCompletion(
IServiceProvider services,
string? provider = null,
string? model = null,
AgentLlmConfig? agentConfig = null)
Expand Down Expand Up @@ -42,7 +43,8 @@ public static object GetCompletion(IServiceProvider services,
}
}

public static IChatCompletion GetChatCompletion(IServiceProvider services,
public static IChatCompletion GetChatCompletion(
IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
Expand All @@ -66,7 +68,8 @@ public static IChatCompletion GetChatCompletion(IServiceProvider services,
return completer;
}

public static ITextCompletion GetTextCompletion(IServiceProvider services,
public static ITextCompletion GetTextCompletion(
IServiceProvider services,
string? provider = null,
string? model = null,
AgentLlmConfig? agentConfig = null)
Expand All @@ -86,15 +89,16 @@ public static ITextCompletion GetTextCompletion(IServiceProvider services,
return completer;
}

public static IImageCompletion GetImageCompletion(IServiceProvider services,
public static IImageCompletion GetImageCompletion(
IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool imageGenerate = false)
IEnumerable<LlmModelCapability>? capabilities = null)
{
var completions = services.GetServices<IImageCompletion>();
(provider, model) = GetProviderAndModel(services, provider: provider,
model: model, modelId: modelId, imageGenerate: imageGenerate);
model: model, modelId: modelId, capabilities: capabilities);

var completer = completions.FirstOrDefault(x => x.Provider == provider);
if (completer == null)
Expand All @@ -107,7 +111,8 @@ public static IImageCompletion GetImageCompletion(IServiceProvider services,
return completer;
}

public static ITextEmbedding GetTextEmbedding(IServiceProvider services,
public static ITextEmbedding GetTextEmbedding(
IServiceProvider services,
string? provider = null,
string? model = null)
{
Expand Down Expand Up @@ -166,7 +171,8 @@ public static IAudioSynthesis GetAudioSynthesizer(
return completer;
}

public static IRealTimeCompletion GetRealTimeCompletion(IServiceProvider services,
public static IRealTimeCompletion GetRealTimeCompletion(
IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
Expand All @@ -176,7 +182,7 @@ public static IRealTimeCompletion GetRealTimeCompletion(IServiceProvider service
var completions = services.GetServices<IRealTimeCompletion>();
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId,
multiModal: multiModal,
modelType: LlmModelType.Realtime,
modelType: LlmModelType.Realtime,
agentConfig: agentConfig);

var completer = completions.FirstOrDefault(x => x.Provider == provider);
Expand All @@ -190,13 +196,14 @@ public static IRealTimeCompletion GetRealTimeCompletion(IServiceProvider service
return completer;
}

private static (string, string) GetProviderAndModel(IServiceProvider services,
private static (string, string) GetProviderAndModel(
IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool? multiModal = null,
LlmModelType? modelType = null,
bool imageGenerate = false,
IEnumerable<LlmModelCapability>? capabilities = null,
AgentLlmConfig? agentConfig = null)
{
var agentSetting = services.GetRequiredService<AgentSettings>();
Expand All @@ -220,9 +227,9 @@ private static (string, string) GetProviderAndModel(IServiceProvider services,
var modelIdentity = state.ContainsState("model_id") ? state.GetState("model_id") : modelId;
var llmProviderService = services.GetRequiredService<ILlmProviderService>();
model = llmProviderService.GetProviderModel(provider, modelIdentity,
multiModal: multiModal,
multiModal: multiModal,
modelType: modelType,
imageGenerate: imageGenerate)?.Name;
capabilities: capabilities)?.Name;
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BotSharp.Abstraction.MLTasks;
using BotSharp.Abstraction.MLTasks.Filters;
using BotSharp.Abstraction.MLTasks.Settings;
using BotSharp.Abstraction.Settings;

Expand Down Expand Up @@ -40,14 +41,12 @@ public List<LlmModelSetting> GetProviderModels(string provider)
{
var settingService = _services.GetRequiredService<ISettingService>();
return settingService.Bind<List<LlmProviderSetting>>($"LlmProviders")
.FirstOrDefault(x => x.Provider.Equals(provider))
?.Models ?? new List<LlmModelSetting>();
.FirstOrDefault(x => x.Provider.Equals(provider))?.Models ?? [];
}

public LlmModelSetting GetProviderModel(string provider, string id, bool? multiModal = null, LlmModelType? modelType = null, bool imageGenerate = false)
public LlmModelSetting? GetProviderModel(string provider, string id, bool? multiModal = null, LlmModelType? modelType = null, IEnumerable<LlmModelCapability>? capabilities = null)
{
var models = GetProviderModels(provider)
.Where(x => x.Id == id);
var models = GetProviderModels(provider).Where(x => x.Id == id);

if (multiModal.HasValue)
{
Expand All @@ -59,7 +58,15 @@ public LlmModelSetting GetProviderModel(string provider, string id, bool? multiM
models = models.Where(x => x.Type == modelType.Value);
}

models = models.Where(x => x.ImageGeneration == imageGenerate);
if (capabilities != null)
{
models = models.Where(x => x.Capabilities != null && capabilities.Any(y => x.Capabilities.Contains(y)));
}

if (models.IsNullOrEmpty())
{
return null;
}

var random = new Random();
var index = random.Next(0, models.Count());
Expand All @@ -72,14 +79,14 @@ public LlmModelSetting GetProviderModel(string provider, string id, bool? multiM
var settings = _services.GetRequiredService<List<LlmProviderSetting>>();
var providerSetting = settings.FirstOrDefault(p =>
p.Provider.Equals(provider, StringComparison.CurrentCultureIgnoreCase));

if (providerSetting == null)
{
_logger.LogError($"Can't find provider settings for {provider}");
return null;
}

var modelSetting = providerSetting.Models.FirstOrDefault(m =>
m.Name.Equals(model, StringComparison.CurrentCultureIgnoreCase));
var modelSetting = providerSetting.Models.FirstOrDefault(m => m.Name.Equals(model, StringComparison.CurrentCultureIgnoreCase));
if (modelSetting == null)
{
_logger.LogError($"Can't find model settings for {provider}.{model}");
Expand All @@ -95,50 +102,75 @@ public LlmModelSetting GetProviderModel(string provider, string id, bool? multiM
m.Group.Equals(modelSetting.Group, StringComparison.CurrentCultureIgnoreCase))
.ToList();

// pick one model randomly
var random = new Random();
var index = random.Next(0, models.Count());
modelSetting = models.ElementAt(index);
if (!models.IsNullOrEmpty())
{
// pick one model randomly
var random = new Random();
var index = random.Next(0, models.Count());
modelSetting = models.ElementAt(index);
}
}

return modelSetting;
}


public List<LlmProviderSetting> GetLlmConfigs(LlmConfigOptions? options = null)
public List<LlmProviderSetting> GetLlmConfigs(LlmConfigFilter? filter = null)
{
var settingService = _services.GetRequiredService<ISettingService>();
var providers = settingService.Bind<List<LlmProviderSetting>>($"LlmProviders");
var configs = new List<LlmProviderSetting>();
var comparer = StringComparer.OrdinalIgnoreCase;

if (providers.IsNullOrEmpty())
{
return configs;
}

if (providers.IsNullOrEmpty()) return configs;
if (filter == null)
{
return providers ?? [];
}

if (options == null) return providers ?? [];
if (filter.Providers != null)
{
providers = providers.Where(x => filter.Providers.Contains(x.Provider, comparer)).ToList();
}

foreach (var provider in providers)
{
var models = provider.Models ?? [];
if (options.Type.HasValue)
IEnumerable<LlmModelSetting> models = provider.Models ?? [];
if (filter.ModelTypes != null)
{
models = models.Where(x => filter.ModelTypes.Contains(x.Type));
}

if (filter.ModelIds != null)
{
models = models.Where(x => filter.ModelIds.Contains(x.Id, comparer));
}

if (filter.ModelNames != null)
{
models = models.Where(x => x.Type == options.Type.Value).ToList();
models = models.Where(x => filter.ModelNames.Contains(x.Name, comparer));
}

if (options.MultiModal.HasValue)
if (filter.ModelCapabilities != null)
{
models = models.Where(x => x.MultiModal == options.MultiModal.Value).ToList();
models = models.Where(x => x.Capabilities != null && filter.ModelCapabilities.Any(y => x.Capabilities.Contains(y)));
}

if (options.ImageGeneration.HasValue)
if (filter.MultiModal.HasValue)
{
models = models.Where(x => x.ImageGeneration == options.ImageGeneration.Value).ToList();
models = models.Where(x => x.MultiModal == filter.MultiModal.Value);
}

if (models.IsNullOrEmpty())
{
continue;
}

provider.Models = models;
provider.Models = models.ToList();
configs.Add(provider);
}

Expand Down
Loading
Loading