Skip to content

Commit

Permalink
.Net: Refactor planners, memory config, and function extensions (micr…
Browse files Browse the repository at this point in the history
…osoft#2949)

Followup to microsoft#2912 and microsoft#2931.

Resolves microsoft#2848
Resolves microsoft#2074

This commit includes several updates and refactors to planners,
SemanticMemoryConfig, and function extension classes. Changes include
updating SequentialPlannerConfig to use SemanticMemory, refactoring
tests to use async methods, renaming and updating test cases, and
improving planner configurations. Additionally, FunctionViewExtensions
has been added and refactored, along with updates to method signatures
and code organization. The StepwisePlanner has also been refactored,
and PlannerConfigBase has been updated to improve memory usage and
function filtering.
  • Loading branch information
lemillermicrosoft authored and SOE-YoungS committed Oct 31, 2023
1 parent 28940df commit c18dbfb
Show file tree
Hide file tree
Showing 18 changed files with 265 additions and 255 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ private static async Task MemorySampleAsync()
var goal = "Create a book with 3 chapters about a group of kids in a club called 'The Thinking Caps.'";

// IMPORTANT: To use memory and embeddings to find relevant plugins in the planner, set the 'Memory' property on the planner config.
var planner = new SequentialPlanner(kernel, new SequentialPlannerConfig { RelevancyThreshold = 0.5, Memory = kernel.Memory });
var planner = new SequentialPlanner(kernel, new SequentialPlannerConfig { SemanticMemoryConfig = new() { RelevancyThreshold = 0.5, Memory = kernel.Memory } });

var plan = await planner.CreatePlanAsync(goal);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Planners;
using Microsoft.SemanticKernel.Planners.Sequential;
using Microsoft.SemanticKernel.Planning;
using SemanticKernel.IntegrationTests.Fakes;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void CanCallToPlanFromXml()
var goal = "Summarize an input, translate to french, and e-mail to John Doe";

// Act
var plan = planString.ToPlanFromXml(goal, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planString.ToPlanFromXml(goal, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public async Task CreatePlanGoalRelevantAsync(string prompt, string expectedFunc
TestHelpers.ImportAllSamplePlugins(kernel);

var planner = new Microsoft.SemanticKernel.Planners.SequentialPlanner(kernel,
new SequentialPlannerConfig { RelevancyThreshold = 0.65, MaxRelevantFunctions = 30, Memory = kernel.Memory });
new SequentialPlannerConfig { SemanticMemoryConfig = new() { RelevancyThreshold = 0.65, MaxRelevantFunctions = 30, Memory = kernel.Memory } });

// Act
var plan = await planner.CreatePlanAsync(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public async Task MalformedJsonThrowsAsync()
}

[Fact]
public void ListOfFunctionsIncludesNativeAndSemanticFunctions()
public async Task ListOfFunctionsIncludesNativeAndSemanticFunctionsAsync()
{
// Arrange
var plugins = this.CreateMockFunctionCollection();
Expand All @@ -109,15 +109,15 @@ public void ListOfFunctionsIncludesNativeAndSemanticFunctions()
var context = kernel.Object.CreateNewContext();

// Act
var result = planner.ListOfFunctions("goal", context);
var result = await planner.ListOfFunctionsAsync("goal", context);

// Assert
var expected = $"// Send an e-mail.{Environment.NewLine}email.SendEmail{Environment.NewLine}// List pull requests.{Environment.NewLine}GitHubPlugin.PullsList{Environment.NewLine}// List repositories.{Environment.NewLine}GitHubPlugin.RepoList{Environment.NewLine}";
Assert.Equal(expected, result);
}

[Fact]
public void ListOfFunctionsExcludesExcludedPlugins()
public async Task ListOfFunctionsExcludesExcludedPluginsAsync()
{
// Arrange
var plugins = this.CreateMockFunctionCollection();
Expand All @@ -128,15 +128,15 @@ public void ListOfFunctionsExcludesExcludedPlugins()
var context = kernel.Object.CreateNewContext();

// Act
var result = planner.ListOfFunctions("goal", context);
var result = await planner.ListOfFunctionsAsync("goal", context);

// Assert
var expected = $"// Send an e-mail.{Environment.NewLine}email.SendEmail{Environment.NewLine}";
Assert.Equal(expected, result);
}

[Fact]
public void ListOfFunctionsExcludesExcludedFunctions()
public async Task ListOfFunctionsExcludesExcludedFunctionsAsync()
{
// Arrange
var plugins = this.CreateMockFunctionCollection();
Expand All @@ -147,7 +147,7 @@ public void ListOfFunctionsExcludesExcludedFunctions()
var context = kernel.Object.CreateNewContext();

// Act
var result = planner.ListOfFunctions("goal", context);
var result = await planner.ListOfFunctionsAsync("goal", context);

// Assert
var expected = $"// Send an e-mail.{Environment.NewLine}email.SendEmail{Environment.NewLine}// List repositories.{Environment.NewLine}GitHubPlugin.RepoList{Environment.NewLine}";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,23 @@
using Xunit;

#pragma warning disable IDE0130 // Namespace does not match folder structure
namespace Microsoft.SemanticKernel.Planners.Sequential.UnitTests;
namespace Microsoft.SemanticKernel.Planners.UnitTests;
#pragma warning restore IDE0130 // Namespace does not match folder structure

public class SKContextExtensionsTests
public class ReadOnlyFunctionCollectionExtensionsTests
{
[Fact]
public async Task CanCallGetAvailableFunctionsWithNoFunctionsAsync()
private static PlannerConfigBase InitializeConfig(Type t)
{
PlannerConfigBase? config = Activator.CreateInstance(t) as PlannerConfigBase;
Assert.NotNull(config);
return config;
}

[Theory]
[InlineData(typeof(ActionPlannerConfig))]
[InlineData(typeof(SequentialPlannerConfig))]
[InlineData(typeof(StepwisePlannerConfig))]
public async Task CanCallGetAvailableFunctionsWithNoFunctionsAsync(Type t)
{
// Arrange
var kernel = new Mock<IKernel>();
Expand All @@ -40,21 +50,46 @@ public async Task CanCallGetAvailableFunctionsWithNoFunctionsAsync()

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(kernel.Object, variables, functions);
var config = new SequentialPlannerConfig() { Memory = memory.Object };
var config = InitializeConfig(t);
var semanticQuery = "test";

// Act
var result = await context.GetAvailableFunctionsAsync(config, semanticQuery, cancellationToken);
var result = await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken);

// Assert
Assert.NotNull(result);
memory.Verify(
x => x.SearchAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>(), It.IsAny<double>(), It.IsAny<bool>(), It.IsAny<CancellationToken>()),
Times.Never);

config.SemanticMemoryConfig = new();

// Act
result = await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken);

// Assert
Assert.NotNull(result);
memory.Verify(
x => x.SearchAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>(), It.IsAny<double>(), It.IsAny<bool>(), It.IsAny<CancellationToken>()),
Times.Never);

config.SemanticMemoryConfig = new() { Memory = memory.Object };

// Act
result = await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken);

// Assert
Assert.NotNull(result);
memory.Verify(
x => x.SearchAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<int>(), It.IsAny<double>(), It.IsAny<bool>(), It.IsAny<CancellationToken>()),
Times.Once);
}

[Fact]
public async Task CanCallGetAvailableFunctionsWithFunctionsAsync()
[Theory]
[InlineData(typeof(ActionPlannerConfig))]
[InlineData(typeof(SequentialPlannerConfig))]
[InlineData(typeof(StepwisePlannerConfig))]
public async Task CanCallGetAvailableFunctionsWithFunctionsAsync(Type t)
{
// Arrange
var kernel = new Mock<IKernel>();
Expand Down Expand Up @@ -92,22 +127,23 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsAsync()

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(kernel.Object, variables, functions.Object);
var config = new SequentialPlannerConfig() { Memory = memory.Object };
var config = InitializeConfig(t);
var semanticQuery = "test";

// Act
var result = (await context.GetAvailableFunctionsAsync(config, semanticQuery, cancellationToken)).ToList();
var result = (await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken)).ToList();

// Assert
Assert.NotNull(result);
Assert.Equal(2, result.Count);
Assert.Equal(functionView, result[0]);

// Arrange update IncludedFunctions
config.IncludedFunctions.UnionWith(new List<(string, string)> { ("pluginName", "nativeFunctionName") });
config.SemanticMemoryConfig = new() { Memory = memory.Object };
config.SemanticMemoryConfig.IncludedFunctions.UnionWith(new List<(string, string)> { ("pluginName", "nativeFunctionName") });

// Act
result = (await context.GetAvailableFunctionsAsync(config, semanticQuery)).ToList();
result = (await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery)).ToList();

// Assert
Assert.NotNull(result);
Expand All @@ -116,8 +152,11 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsAsync()
Assert.Equal(nativeFunctionView, result[1]);
}

[Fact]
public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync()
[Theory]
[InlineData(typeof(ActionPlannerConfig))]
[InlineData(typeof(SequentialPlannerConfig))]
[InlineData(typeof(StepwisePlannerConfig))]
public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync(Type t)
{
// Arrange
var kernel = new Mock<IKernel>();
Expand Down Expand Up @@ -157,22 +196,23 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync()

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(kernel.Object, variables, functions.Object);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78, Memory = memory.Object };
var config = InitializeConfig(t);
config.SemanticMemoryConfig = new() { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";

// Act
var result = (await context.GetAvailableFunctionsAsync(config, semanticQuery, cancellationToken)).ToList();
var result = (await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken)).ToList();

// Assert
Assert.NotNull(result);
Assert.Single(result);
Assert.Equal(functionView, result[0]);

// Arrange update IncludedFunctions
config.IncludedFunctions.UnionWith(new List<(string, string)> { ("pluginName", "nativeFunctionName") });
config.SemanticMemoryConfig.IncludedFunctions.UnionWith(new List<(string, string)> { ("pluginName", "nativeFunctionName") });

// Act
result = (await context.GetAvailableFunctionsAsync(config, semanticQuery)).ToList();
result = (await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery)).ToList();

// Assert
Assert.NotNull(result);
Expand All @@ -181,8 +221,11 @@ public async Task CanCallGetAvailableFunctionsWithFunctionsWithRelevancyAsync()
Assert.Equal(nativeFunctionView, result[1]);
}

[Fact]
public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync()
[Theory]
[InlineData(typeof(ActionPlannerConfig))]
[InlineData(typeof(SequentialPlannerConfig))]
[InlineData(typeof(StepwisePlannerConfig))]
public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync(Type t)
{
// Arrange
var kernel = new Mock<IKernel>();
Expand Down Expand Up @@ -210,11 +253,12 @@ public async Task CanCallGetAvailableFunctionsAsyncWithDefaultRelevancyAsync()

// Arrange GetAvailableFunctionsAsync parameters
var context = new SKContext(kernel.Object, variables, functions);
var config = new SequentialPlannerConfig { RelevancyThreshold = 0.78, Memory = memory.Object };
var config = InitializeConfig(t);
config.SemanticMemoryConfig = new() { RelevancyThreshold = 0.78, Memory = memory.Object };
var semanticQuery = "test";

// Act
var result = await context.GetAvailableFunctionsAsync(config, semanticQuery, cancellationToken);
var result = await context.Functions.GetAvailableFunctionsAsync(config, semanticQuery, null, cancellationToken);

// Assert
Assert.NotNull(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public void CanCallToPlanFromXml()
var goal = "Summarize an input, translate to french, and e-mail to John Doe";

// Act
var plan = planString.ToPlanFromXml(goal, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planString.ToPlanFromXml(goal, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand Down Expand Up @@ -167,7 +167,7 @@ public void InvalidPlanExecutePlanReturnsInvalidResult()
var planString = "<someTag>";

// Act
Assert.Throws<SKException>(() => planString.ToPlanFromXml(GoalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions)));
Assert.Throws<SKException>(() => planString.ToPlanFromXml(GoalText, kernel.Functions.GetFunctionCallback()));
}

// Test that contains a #text node in the plan
Expand All @@ -187,7 +187,7 @@ public void CanCreatePlanWithTextNodes(string goalText, string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(goalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(goalText, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand All @@ -211,7 +211,7 @@ public void CanCreatePlanWithPartialXml(string goalText, string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(goalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(goalText, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand All @@ -236,7 +236,7 @@ public void CanCreatePlanWithFunctionName(string goalText, string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(goalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(goalText, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand Down Expand Up @@ -271,7 +271,7 @@ public void CanCreatePlanWithInvalidFunctionNodes(string planText, bool allowMis
if (allowMissingFunctions)
{
// it should not throw
var plan = planText.ToPlanFromXml(string.Empty, SequentialPlanParser.GetFunctionCallback(kernel.Functions), allowMissingFunctions);
var plan = planText.ToPlanFromXml(string.Empty, kernel.Functions.GetFunctionCallback(), allowMissingFunctions);

// Assert
Assert.NotNull(plan);
Expand All @@ -287,7 +287,7 @@ public void CanCreatePlanWithInvalidFunctionNodes(string planText, bool allowMis
}
else
{
Assert.Throws<SKException>(() => planText.ToPlanFromXml(string.Empty, SequentialPlanParser.GetFunctionCallback(kernel.Functions), allowMissingFunctions));
Assert.Throws<SKException>(() => planText.ToPlanFromXml(string.Empty, kernel.Functions.GetFunctionCallback(), allowMissingFunctions));
}
}

Expand Down Expand Up @@ -321,7 +321,7 @@ public void CanCreatePlanWithOtherText(string goalText, string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(goalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(goalText, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand All @@ -345,7 +345,7 @@ public void CanCreatePlanWithOpenApiPlugin(string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(string.Empty, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(string.Empty, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand All @@ -371,7 +371,7 @@ public void CanCreatePlanWithIgnoredNodes(string goalText, string planText)
this.CreateKernelAndFunctionCreateMocks(functions, out var kernel);

// Act
var plan = planText.ToPlanFromXml(goalText, SequentialPlanParser.GetFunctionCallback(kernel.Functions));
var plan = planText.ToPlanFromXml(goalText, kernel.Functions.GetFunctionCallback());

// Assert
Assert.NotNull(plan);
Expand Down

0 comments on commit c18dbfb

Please sign in to comment.