diff --git a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs index ab48b8f8cc..a4de3c1012 100644 --- a/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs +++ b/src/WebJobs.Script/Workers/Rpc/FunctionRegistration/RpcFunctionInvocationDispatcher.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using System.Threading.Tasks.Dataflow; using Microsoft.AspNetCore.Hosting; +using Microsoft.Azure.WebJobs.Host.Executors.Internal; using Microsoft.Azure.WebJobs.Logging; using Microsoft.Azure.WebJobs.Script.Config; using Microsoft.Azure.WebJobs.Script.Description; @@ -403,6 +404,9 @@ public async Task ShutdownAsync() public async Task InvokeAsync(ScriptInvocationContext invocationContext) { + // We have entered back into a system scope, ensure our logs are captured as such. + using FunctionInvoker.Scope scope = FunctionInvoker.BeginSystemScope(); + // This could throw if no initialized workers are found. Shut down instance and retry. IEnumerable workerChannels = await GetInitializedWorkerChannelsAsync(invocationContext.FunctionMetadata.Language ?? _workerRuntime); var rpcWorkerChannel = _functionDispatcherLoadBalancer.GetLanguageWorkerChannel(workerChannels); diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/RpcFunctionInvocationDispatcherTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/RpcFunctionInvocationDispatcherTests.cs index 6d1d72d4dc..95aec57d2b 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/RpcFunctionInvocationDispatcherTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/RpcFunctionInvocationDispatcherTests.cs @@ -7,7 +7,9 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; using Microsoft.AspNetCore.Hosting; +using Microsoft.Azure.WebJobs.Host.Executors.Internal; using Microsoft.Azure.WebJobs.Script.Config; using Microsoft.Azure.WebJobs.Script.Description; using Microsoft.Azure.WebJobs.Script.Diagnostics; @@ -631,8 +633,92 @@ public async Task FunctionDispatcher_ErroredWebHostChannel() Assert.True(testLogs.Any(m => m.FormattedMessage.Contains("Removing errored webhost language worker channel for runtime"))); } - private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int maxProcessCountValue = 1, bool addWebhostChannel = false, - Mock mockwebHostLanguageWorkerChannelManager = null, bool throwOnProcessStartUp = false, TimeSpan? startupIntervals = null, string runtime = null, bool workerIndexing = false, bool placeholder = false) + [Fact] + public async Task FunctionDispatcher_InvokeAsync_SystemScope() + { + FunctionMetadata func1 = new FunctionMetadata() + { + Name = "func1", + Language = "node" + }; + var functions = new List() + { + func1 + }; + + ScriptInvocationContext context = new() + { + FunctionMetadata = func1, + ExecutionContext = new() + { + InvocationId = Guid.NewGuid(), + FunctionName = "func1", + }, + ResultSource = new(), + CancellationToken = default, + Logger = _testLogger, + AsyncExecutionContext = System.Threading.ExecutionContext.Capture(), + }; + + BufferBlock inputBuffer = new(); + ActionBlock actionBlock = new(context => + { + try + { + Assert.Equal(FunctionInvocationScope.System, FunctionInvoker.CurrentScope); + context.ResultSource.TrySetResult(null); + } + catch (Exception ex) + { + context.ResultSource.TrySetException(ex); + } + }); + + inputBuffer.LinkTo(actionBlock); + + Mock>> mockBufferBlocks = new(); + mockBufferBlocks.Setup(m => m.TryGetValue(It.IsAny(), out inputBuffer)).Returns(true); + + Mock mockChannel = new(); + mockChannel.Setup(m => m.FunctionInputBuffers).Returns(mockBufferBlocks.Object); + + IRpcWorkerChannel Create( + string workerRuntime, string language, IMetricsLogger metricsLogger, int attemptCount, IEnumerable workerConfigs) + { + var workerConfig = workerConfigs.SingleOrDefault(p => language.Equals(p.Description.Language, StringComparison.OrdinalIgnoreCase)); + return new TestRpcWorkerChannel( + Guid.NewGuid().ToString(), language, null, _testLogger, false, workerConfig: workerConfig) + { + FunctionInputBuffers = mockBufferBlocks.Object, + }; + } + + Mock mockChannelFactory = new(); + mockChannelFactory.Setup( + m => m.Create(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Returns(Create); + + RpcFunctionInvocationDispatcher functionDispatcher = GetTestFunctionDispatcher( + runtime: RpcWorkerConstants.NodeLanguageWorkerName, channelFactory: mockChannelFactory.Object); + + await functionDispatcher.InitializeAsync(functions); + await WaitForFunctionDispactherStateInitialized(functionDispatcher); + using FunctionInvoker.Scope scope = FunctionInvoker.BeginUserScope(); + await functionDispatcher.InvokeAsync(context); + await context.ResultSource.Task; + Assert.Equal(FunctionInvocationScope.User, FunctionInvoker.CurrentScope); + } + + private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher( + int maxProcessCountValue = 1, + bool addWebhostChannel = false, + Mock mockwebHostLanguageWorkerChannelManager = null, + bool throwOnProcessStartUp = false, + TimeSpan? startupIntervals = null, + string runtime = null, + bool workerIndexing = false, + bool placeholder = false, + IRpcWorkerChannelFactory channelFactory = null) { var eventManager = new ScriptEventManager(); var metricsLogger = new Mock(); @@ -667,8 +753,10 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max WorkerConfigs = TestHelpers.GetTestWorkerConfigs(processCountValue: maxProcessCountValue, processStartupInterval: intervals, processRestartInterval: intervals, processShutdownTimeout: TimeSpan.FromSeconds(1), workerIndexing: workerIndexing) }; - IRpcWorkerChannelFactory testLanguageWorkerChannelFactory = new TestRpcWorkerChannelFactory(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, throwOnProcessStartUp); - IWebHostRpcWorkerChannelManager testWebHostLanguageWorkerChannelManager = new TestRpcWorkerChannelManager(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, testLanguageWorkerChannelFactory); + + channelFactory ??= new TestRpcWorkerChannelFactory(eventManager, _testLogger, scriptOptions.Value.RootScriptPath, throwOnProcessStartUp); + IWebHostRpcWorkerChannelManager testWebHostLanguageWorkerChannelManager = new TestRpcWorkerChannelManager( + eventManager, _testLogger, scriptOptions.Value.RootScriptPath, channelFactory); IJobHostRpcWorkerChannelManager jobHostLanguageWorkerChannelManager = new JobHostRpcWorkerChannelManager(_testLoggerFactory); if (addWebhostChannel) @@ -681,6 +769,8 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max } var mockFunctionDispatcherLoadBalancer = new Mock(); + mockFunctionDispatcherLoadBalancer.Setup(m => m.GetLanguageWorkerChannel(It.IsAny>())) + .Returns((IEnumerable channels) => channels.FirstOrDefault()); var mockHostMetrics = new Mock(); _javaTestChannel = new TestRpcWorkerChannel(Guid.NewGuid().ToString(), "java", eventManager, _testLogger, false); @@ -693,7 +783,7 @@ private static RpcFunctionInvocationDispatcher GetTestFunctionDispatcher(int max mockApplicationLifetime.Object, eventManager, _testLoggerFactory, - testLanguageWorkerChannelFactory, + channelFactory, optionsMonitor, testWebHostLanguageWorkerChannelManager, jobHostLanguageWorkerChannelManager, diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs index c7b6145565..c1244380d4 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/TestRpcWorkerChannel.cs @@ -54,7 +54,7 @@ public TestRpcWorkerChannel(string workerId, string runtime = null, IScriptEvent public RpcWorkerConfig WorkerConfig => _workerConfig; - public IDictionary> FunctionInputBuffers => throw new NotImplementedException(); + public IDictionary> FunctionInputBuffers { get; set; } public List ExecutionContexts => _executionContexts;