Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Azure.WebJobs.Host.Executors.Internal;
using Microsoft.Azure.WebJobs.Logging;
using Microsoft.Azure.WebJobs.Script.Config;
using Microsoft.Azure.WebJobs.Script.Description;
Expand Down Expand Up @@ -403,6 +404,9 @@ public async Task ShutdownAsync()

public async Task InvokeAsync(ScriptInvocationContext invocationContext)
{
// We have entered back into a system scope, ensure our logs are captured as such.
using FunctionInvoker.Scope scope = FunctionInvoker.BeginSystemScope();

// This could throw if no initialized workers are found. Shut down instance and retry.
IEnumerable<IRpcWorkerChannel> workerChannels = await GetInitializedWorkerChannelsAsync(invocationContext.FunctionMetadata.Language ?? _workerRuntime);
var rpcWorkerChannel = _functionDispatcherLoadBalancer.GetLanguageWorkerChannel(workerChannels);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Azure.WebJobs.Host.Executors.Internal;
using Microsoft.Azure.WebJobs.Script.Config;
using Microsoft.Azure.WebJobs.Script.Description;
using Microsoft.Azure.WebJobs.Script.Diagnostics;
Expand Down Expand Up @@ -631,8 +633,92 @@ public async Task FunctionDispatcher_ErroredWebHostChannel()
Assert.True(testLogs.Any(m => m.FormattedMessage.Contains("Removing errored webhost language worker channel for runtime")));
}

private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int maxProcessCountValue = 1, bool addWebhostChannel = false,
Mock<IWebHostRpcWorkerChannelManager> mockwebHostLanguageWorkerChannelManager = null, bool throwOnProcessStartUp = false, TimeSpan? startupIntervals = null, string runtime = null, bool workerIndexing = false, bool placeholder = false)
[Fact]
public async Task FunctionDispatcher_InvokeAsync_SystemScope()
{
FunctionMetadata func1 = new FunctionMetadata()
{
Name = "func1",
Language = "node"
};
var functions = new List<FunctionMetadata>()
{
func1
};

ScriptInvocationContext context = new()
{
FunctionMetadata = func1,
ExecutionContext = new()
{
InvocationId = Guid.NewGuid(),
FunctionName = "func1",
},
ResultSource = new(),
CancellationToken = default,
Logger = _testLogger,
AsyncExecutionContext = System.Threading.ExecutionContext.Capture(),
};

BufferBlock<ScriptInvocationContext> inputBuffer = new();
ActionBlock<ScriptInvocationContext> actionBlock = new(context =>
{
try
{
Assert.Equal(FunctionInvocationScope.System, FunctionInvoker.CurrentScope);
context.ResultSource.TrySetResult(null);
}
catch (Exception ex)
{
context.ResultSource.TrySetException(ex);
}
});

inputBuffer.LinkTo(actionBlock);

Mock<IDictionary<string, BufferBlock<ScriptInvocationContext>>> mockBufferBlocks = new();
mockBufferBlocks.Setup(m => m.TryGetValue(It.IsAny<string>(), out inputBuffer)).Returns(true);

Mock<IRpcWorkerChannel> mockChannel = new();
mockChannel.Setup(m => m.FunctionInputBuffers).Returns(mockBufferBlocks.Object);

IRpcWorkerChannel Create(
string workerRuntime, string language, IMetricsLogger metricsLogger, int attemptCount, IEnumerable<RpcWorkerConfig> workerConfigs)
{
var workerConfig = workerConfigs.SingleOrDefault(p => language.Equals(p.Description.Language, StringComparison.OrdinalIgnoreCase));
return new TestRpcWorkerChannel(
Guid.NewGuid().ToString(), language, null, _testLogger, false, workerConfig: workerConfig)
{
FunctionInputBuffers = mockBufferBlocks.Object,
};
}

Mock<IRpcWorkerChannelFactory> mockChannelFactory = new();
mockChannelFactory.Setup(
m => m.Create(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<IMetricsLogger>(), It.IsAny<int>(), It.IsAny<IEnumerable<RpcWorkerConfig>>()))
.Returns(Create);

RpcFunctionInvocationDispatcher functionDispatcher = GetTestFunctionDispatcher(
runtime: RpcWorkerConstants.NodeLanguageWorkerName, channelFactory: mockChannelFactory.Object);

await functionDispatcher.InitializeAsync(functions);
await WaitForFunctionDispactherStateInitialized(functionDispatcher);
using FunctionInvoker.Scope scope = FunctionInvoker.BeginUserScope();
await functionDispatcher.InvokeAsync(context);
await context.ResultSource.Task;
Assert.Equal(FunctionInvocationScope.User, FunctionInvoker.CurrentScope);
}

private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(
int maxProcessCountValue = 1,
bool addWebhostChannel = false,
Mock<IWebHostRpcWorkerChannelManager> mockwebHostLanguageWorkerChannelManager = null,
bool throwOnProcessStartUp = false,
TimeSpan? startupIntervals = null,
string runtime = null,
bool workerIndexing = false,
bool placeholder = false,
IRpcWorkerChannelFactory channelFactory = null)
{
var eventManager = new ScriptEventManager();
var metricsLogger = new Mock<IMetricsLogger>();
Expand Down Expand Up @@ -667,8 +753,10 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max
WorkerConfigs = TestHelpers.GetTestWorkerConfigs(processCountValue: maxProcessCountValue, processStartupInterval: intervals,
processRestartInterval: intervals, processShutdownTimeout: TimeSpan.FromSeconds(1), workerIndexing: workerIndexing)
};
IRpcWorkerChannelFactory testLanguageWorkerChannelFactory = new TestRpcWorkerChannelFactory(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, throwOnProcessStartUp);
IWebHostRpcWorkerChannelManager testWebHostLanguageWorkerChannelManager = new TestRpcWorkerChannelManager(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, testLanguageWorkerChannelFactory);

channelFactory ??= new TestRpcWorkerChannelFactory(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, throwOnProcessStartUp);
IWebHostRpcWorkerChannelManager testWebHostLanguageWorkerChannelManager = new TestRpcWorkerChannelManager(
eventManager, _testLogger, scriptOptions.Value.RootScriptPath, channelFactory);
IJobHostRpcWorkerChannelManager jobHostLanguageWorkerChannelManager = new JobHostRpcWorkerChannelManager(_testLoggerFactory);

if (addWebhostChannel)
Expand All @@ -681,6 +769,8 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max
}

var mockFunctionDispatcherLoadBalancer = new Mock<IRpcFunctionInvocationDispatcherLoadBalancer>();
mockFunctionDispatcherLoadBalancer.Setup(m => m.GetLanguageWorkerChannel(It.IsAny<IEnumerable<IRpcWorkerChannel>>()))
.Returns((IEnumerable<IRpcWorkerChannel> channels) => channels.FirstOrDefault());
var mockHostMetrics = new Mock<IHostMetrics>();

_javaTestChannel = new TestRpcWorkerChannel(Guid.NewGuid().ToString(), "java", eventManager, _testLogger, false);
Expand All @@ -693,7 +783,7 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max
mockApplicationLifetime.Object,
eventManager,
_testLoggerFactory,
testLanguageWorkerChannelFactory,
channelFactory,
optionsMonitor,
testWebHostLanguageWorkerChannelManager,
jobHostLanguageWorkerChannelManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public TestRpcWorkerChannel(string workerId, string runtime = null, IScriptEvent

public RpcWorkerConfig WorkerConfig => _workerConfig;

public IDictionary<string, BufferBlock<ScriptInvocationContext>> FunctionInputBuffers => throw new NotImplementedException();
public IDictionary<string, BufferBlock<ScriptInvocationContext>> FunctionInputBuffers { get; set; }

public List<Task> ExecutionContexts => _executionContexts;

Expand Down