Skip to content

Commit

Permalink
fixing race condition in worker shutdown (#9738)
Browse files Browse the repository at this point in the history
  • Loading branch information
brettsam authored and v-imohammad committed Dec 13, 2023
1 parent b85b598 commit 9289209
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ internal async void ShutdownWebhostLanguageWorkerChannels()
await _webHostLanguageWorkerChannelManager?.ShutdownChannelsAsync();
}

private void SetDispatcherStateToInitialized(Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannel = null)
private void SetDispatcherStateToInitialized(IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannel = null)
{
// RanToCompletion indicates successful process startup
if (State != FunctionInvocationDispatcherState.Initialized
Expand All @@ -198,7 +198,7 @@ private void SetDispatcherStateToInitialized(Dictionary<string, TaskCompletionSo
}
}

private void StartWorkerProcesses(int startIndex, Func<IEnumerable<string>, Task> startAction, bool initializeDispatcher = false, Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannel = null, IEnumerable<string> functionLanguages = null)
private void StartWorkerProcesses(int startIndex, Func<IEnumerable<string>, Task> startAction, bool initializeDispatcher = false, IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannel = null, IEnumerable<string> functionLanguages = null)
{
Task.Run(async () =>
{
Expand Down Expand Up @@ -309,7 +309,7 @@ public async Task InitializeAsync(IEnumerable<FunctionMetadata> functions, Cance
if (Utility.IsSupportedRuntime(_workerRuntime, _workerConfigs) || _environment.IsMultiLanguageRuntimeEnvironment())
{
State = FunctionInvocationDispatcherState.Initializing;
Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannels = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> webhostLanguageWorkerChannels = _webHostLanguageWorkerChannelManager.GetChannels(_workerRuntime);
if (webhostLanguageWorkerChannels != null)
{
int workerProcessCount = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public interface IWebHostRpcWorkerChannelManager
{
Task<IRpcWorkerChannel> InitializeChannelAsync(IEnumerable<RpcWorkerConfig> workerConfigs, string language);

Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language);
IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language);

Task SpecializeAsync();

Expand Down
28 changes: 13 additions & 15 deletions src/WebJobs.Script/Workers/Rpc/WebHostRpcWorkerChannelManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Microsoft.Azure.AppService.Proxy.Common.Extensions;
using Microsoft.Azure.AppService.Proxy.Common.Infra;
using Microsoft.Azure.AppService.Proxy.Runtime;
using Microsoft.Azure.WebJobs.Script.Config;
using Microsoft.Azure.WebJobs.Script.Diagnostics;
using Microsoft.Azure.WebJobs.Script.Eventing;
Expand All @@ -36,7 +34,7 @@ public class WebHostRpcWorkerChannelManager : IWebHostRpcWorkerChannelManager
private Action _shutdownStandbyWorkerChannels;
private IConfiguration _config;

private ConcurrentDictionary<string, Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>>> _workerChannels = new ConcurrentDictionary<string, Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>>>(StringComparer.OrdinalIgnoreCase);
private ConcurrentDictionary<string, ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>>> _workerChannels = new(StringComparer.OrdinalIgnoreCase);

public WebHostRpcWorkerChannelManager(IScriptEventManager eventManager,
IEnvironment environment,
Expand Down Expand Up @@ -101,7 +99,7 @@ internal async Task<IRpcWorkerChannel> InitializeLanguageWorkerChannel(IEnumerab

internal Task<IRpcWorkerChannel> GetChannelAsync(string language)
{
if (!string.IsNullOrEmpty(language) && _workerChannels.TryGetValue(language, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> workerChannels))
if (!string.IsNullOrEmpty(language) && _workerChannels.TryGetValue(language, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> workerChannels))
{
if (workerChannels.Count > 0 && workerChannels.TryGetValue(workerChannels.Keys.First(), out TaskCompletionSource<IRpcWorkerChannel> valueTask))
{
Expand All @@ -111,9 +109,9 @@ internal Task<IRpcWorkerChannel> GetChannelAsync(string language)
return Task.FromResult<IRpcWorkerChannel>(null);
}

public Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language)
public IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language)
{
if (!string.IsNullOrEmpty(language) && _workerChannels.TryGetValue(language, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> workerChannels))
if (!string.IsNullOrEmpty(language) && _workerChannels.TryGetValue(language, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> workerChannels))
{
return workerChannels;
}
Expand Down Expand Up @@ -237,7 +235,7 @@ public Task<bool> ShutdownChannelIfExistsAsync(string language, string workerId,

if (_hostingConfigOptions.Value.RevertWorkerShutdownBehaviour)
{
if (_workerChannels.TryRemove(language, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> rpcWorkerChannels))
if (_workerChannels.TryRemove(language, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> rpcWorkerChannels))
{
if (rpcWorkerChannels.TryGetValue(workerId, out TaskCompletionSource<IRpcWorkerChannel> value))
{
Expand All @@ -264,7 +262,7 @@ public Task<bool> ShutdownChannelIfExistsAsync(string language, string workerId,
}
else
{
if (_workerChannels.TryGetValue(language, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> rpcWorkerChannels)
if (_workerChannels.TryGetValue(language, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> rpcWorkerChannels)
&& rpcWorkerChannels.TryRemove(workerId, out TaskCompletionSource<IRpcWorkerChannel> value))
{
value?.Task.ContinueWith(channelTask =>
Expand Down Expand Up @@ -304,7 +302,7 @@ internal void ScheduleShutdownStandbyChannels()
using (_metricsLogger.LatencyEvent(string.Format(MetricEventNames.SpecializationShutdownStandbyChannels, runtime.Key)))
{
_logger.LogInformation("Disposing standby channel for runtime:{language}", runtime.Key);
if (_workerChannels.TryRemove(runtime.Key, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> standbyChannels))
if (_workerChannels.TryRemove(runtime.Key, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> standbyChannels))
{
foreach (string workerId in standbyChannels.Keys)
{
Expand Down Expand Up @@ -338,7 +336,7 @@ public async Task ShutdownChannelsAsync()
foreach (string runtime in _workerChannels.Keys)
{
_logger.LogInformation("Shutting down language worker channels for runtime:{runtime}", runtime);
if (_workerChannels.TryRemove(runtime, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> standbyChannels))
if (_workerChannels.TryRemove(runtime, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> standbyChannels))
{
foreach (string workerId in standbyChannels.Keys)
{
Expand Down Expand Up @@ -378,21 +376,21 @@ internal void AddOrUpdateWorkerChannels(string initializedRuntime, IRpcWorkerCha
_workerChannels.AddOrUpdate(initializedRuntime,
(runtime) =>
{
Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> newLanguageWorkerChannels = new Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>>();
newLanguageWorkerChannels.Add(initializedLanguageWorkerChannel.Id, new TaskCompletionSource<IRpcWorkerChannel>());
ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> newLanguageWorkerChannels = new(StringComparer.OrdinalIgnoreCase);
newLanguageWorkerChannels.TryAdd(initializedLanguageWorkerChannel.Id, new TaskCompletionSource<IRpcWorkerChannel>());
return newLanguageWorkerChannels;
},
(runtime, existingLanguageWorkerChannels) =>
{
existingLanguageWorkerChannels.Add(initializedLanguageWorkerChannel.Id, new TaskCompletionSource<IRpcWorkerChannel>());
existingLanguageWorkerChannels.TryAdd(initializedLanguageWorkerChannel.Id, new TaskCompletionSource<IRpcWorkerChannel>());
return existingLanguageWorkerChannels;
});
}

internal void SetInitializedWorkerChannel(string initializedRuntime, IRpcWorkerChannel initializedLanguageWorkerChannel)
{
_logger.LogDebug("Adding webhost language worker channel for runtime: {language}. workerId:{id}", initializedRuntime, initializedLanguageWorkerChannel.Id);
if (_workerChannels.TryGetValue(initializedRuntime, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> channel))
if (_workerChannels.TryGetValue(initializedRuntime, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> channel))
{
if (channel.TryGetValue(initializedLanguageWorkerChannel.Id, out TaskCompletionSource<IRpcWorkerChannel> value))
{
Expand All @@ -404,7 +402,7 @@ internal void SetInitializedWorkerChannel(string initializedRuntime, IRpcWorkerC
internal void SetExceptionOnInitializedWorkerChannel(string initializedRuntime, IRpcWorkerChannel initializedLanguageWorkerChannel, Exception exception)
{
_logger.LogDebug("Failed to initialize webhost language worker channel for runtime: {language}. workerId:{id}", initializedRuntime, initializedLanguageWorkerChannel.Id);
if (_workerChannels.TryGetValue(initializedRuntime, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> channel))
if (_workerChannels.TryGetValue(initializedRuntime, out ConcurrentDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> channel))
{
if (channel.TryGetValue(initializedLanguageWorkerChannel.Id, out TaskCompletionSource<IRpcWorkerChannel> value))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public IRpcWorkerChannel GetChannel(string language)
throw new System.NotImplementedException();
}

public Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language)
public IDictionary<string, TaskCompletionSource<IRpcWorkerChannel>> GetChannels(string language)
{
if (_workerChannels.TryGetValue(language, out Dictionary<string, TaskCompletionSource<IRpcWorkerChannel>> workerChannels))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.WebJobs.Script.Config;
using Microsoft.Azure.WebJobs.Script.Description;
Expand Down Expand Up @@ -407,6 +408,38 @@ public async Task ShutdownChannelsIfExist_Succeeds()
Assert.Null(initializedChannel);
}

[Fact]
public void ShutdownChannelsIfExist_Race_Succeeds()
{
var channel = CreateTestChannel(RpcWorkerConstants.JavaLanguageWorkerName);
string id = channel.Id;

List<Task<bool>> tasks = new();
List<Thread> threads = new();
for (int i = 0; i < 2; i++)
{
Thread t = new(static (state) =>
{
var (channelManager, tasks, id) = ((WebHostRpcWorkerChannelManager, List<Task<bool>>, string))state;
tasks.Add(channelManager.ShutdownChannelIfExistsAsync(RpcWorkerConstants.JavaLanguageWorkerName, id));
});
threads.Add(t);
}

foreach (Thread t in threads)
{
t.Start((_rpcWorkerChannelManager, tasks, id));
}

foreach (Thread t in threads)
{
t.Join();
}

// only one should successfully shut down
Assert.Single(tasks, t => t.Result == true);
}

[Fact]
public async Task ShutdownChannelsIfExistsAsync_StopsWorkerInvocations()
{
Expand Down

0 comments on commit 9289209

Please sign in to comment.