Skip to content

Commit

Permalink
Lazy load dlls #1637
Browse files Browse the repository at this point in the history
This discovers the binding types at runtime and just loads those bindings.
Even though the assembly has a static reference to a dll, that dll is not loaded until its used.
So this removed unused dlls from the runtime working set.

Add whitelist support. We let you #r to a whitelist of "builtin" dlls even if you don't have any binding to them.
See TwilioReferenceInvokeSucceeds test.

It would be good to find a way to remove the static reference.

Updated the ApplicationInsights_Succeeds because it's testing for log messages from extensions; but those extensions are no longer loaded.
  • Loading branch information
MikeStall committed Jul 26, 2017
1 parent 03e7c29 commit 4aa3ad6
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 33 deletions.
175 changes: 147 additions & 28 deletions src/WebJobs.Script/Host/ScriptHost.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,47 @@ public class ScriptHost : JobHost
private FileWatcherEventSource _fileEventSource;
private IDisposable _fileEventsSubscription;

// Specify the "builtin binding types". These are types that are directly accesible without needing an explicit load gesture.
// This is the set of bindings we shipped prior to binding extensibility.
// Map from BindingType to the Assembly Qualified Type name for its IExtensionConfigProvider object.
private static IReadOnlyDictionary<string, string> _builtinBindingTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
{
{ "bot", "Microsoft.Azure.WebJobs.Extensions.BotFramework.Config.BotFrameworkConfiguration, Microsoft.Azure.WebJobs.Extensions.BotFramework" },
{ "sendgrid", "Microsoft.Azure.WebJobs.Extensions.SendGrid.SendGridConfiguration, Microsoft.Azure.WebJobs.Extensions.SendGrid" }
};

private static IReadOnlyDictionary<string, string> _builtinScriptBindingTypes = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
{
{ "twilioSms", "Microsoft.Azure.WebJobs.Script.Binding.TwilioScriptBindingProvider" },
{ "notificationHub", "Microsoft.Azure.WebJobs.Script.Binding.NotificationHubScriptBindingProvider" },
{ "documentDB", "Microsoft.Azure.WebJobs.Script.Binding.DocumentDBScriptBindingProvider" },
{ "mobileTable", "Microsoft.Azure.WebJobs.Script.Binding.MobileAppsScriptBindingProvider" },
{ "apiHubFileTrigger", "Microsoft.Azure.WebJobs.Script.Binding.ApiHubScriptBindingProvider" },
{ "apiHubFile", "Microsoft.Azure.WebJobs.Script.Binding.ApiHubScriptBindingProvider" },
{ "apiHubTable", "Microsoft.Azure.WebJobs.Script.Binding.ApiHubScriptBindingProvider" },
{ "serviceBusTrigger", "Microsoft.Azure.WebJobs.Script.Binding.ServiceBusScriptBindingProvider" },
{ "serviceBus", "Microsoft.Azure.WebJobs.Script.Binding.ServiceBusScriptBindingProvider" },
{ "eventHubTrigger", "Microsoft.Azure.WebJobs.Script.Binding.ServiceBusScriptBindingProvider" },
{ "eventHub", "Microsoft.Azure.WebJobs.Script.Binding.ServiceBusScriptBindingProvider" },
};

// For backwards compat, we support a #r directly to these assemblies.
private static HashSet<string> _assemblyWhitelist = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
{
{ "Twilio.Api" },
{ "Microsoft.Azure.WebJobs.Extensions.Twilio" },
{ "Microsoft.Azure.NotificationHubs" },
{ "Microsoft.WindowsAzure.Mobile" },
{ "Microsoft.Azure.WebJobs.Extensions.MobileApps" },
{ "Microsoft.Azure.WebJobs.Extensions.NotificationHubs" },
{ "Microsoft.WindowsAzure.Mobile" },
{ "Microsoft.Azure.WebJobs.Extensions.MobileApps" },
{ "Microsoft.Azure.Documents.Client" },
{ "Microsoft.Azure.WebJobs.Extensions.DocumentDB" },
{ "Microsoft.Azure.ApiHub.Sdk" },
{ "Microsoft.Azure.WebJobs.Extensions.ApiHub" }
};

protected internal ScriptHost(IScriptHostEnvironment environment,
IScriptEventManager eventManager,
ScriptHostConfiguration scriptConfig = null,
Expand Down Expand Up @@ -398,9 +439,6 @@ protected virtual void Initialize()
hostConfig.StorageConnectionString = null;
}

var bindingProviders = LoadBindingProviders(ScriptConfig, hostConfigObject, TraceWriter, _startupLogger);
ScriptConfig.BindingProviders = bindingProviders;

if (ScriptConfig.FileWatchingEnabled)
{
_fileEventSource = new FileWatcherEventSource(EventManager, EventSources.ScriptFiles, ScriptConfig.RootScriptPath);
Expand All @@ -424,6 +462,13 @@ protected virtual void Initialize()
// take a snapshot so we can detect function additions/removals
_directorySnapshot = Directory.EnumerateDirectories(ScriptConfig.RootScriptPath).ToImmutableArray();

// Scan the function.json early to determine the requirements.
var functionMetadata = ReadFunctionMetadata(ScriptConfig, TraceWriter, _startupLogger, FunctionErrors, _settingsManager);
var usedBindingTypes = DiscoverBindingTypes(functionMetadata);

var bindingProviders = LoadBindingProviders(ScriptConfig, hostConfigObject, TraceWriter, _startupLogger, usedBindingTypes);
ScriptConfig.BindingProviders = bindingProviders;

// Allow BindingProviders to initialize
foreach (var bindingProvider in ScriptConfig.BindingProviders)
{
Expand All @@ -440,16 +485,7 @@ protected virtual void Initialize()
_startupLogger?.LogError(0, ex, errorMsg);
}
}

// Load builtin extensions
{
var botExtension = new Extensions.BotFramework.Config.BotFrameworkConfiguration();
LoadExtension(botExtension);

var sendGridExtension = new Extensions.SendGrid.SendGridConfiguration();
LoadExtension(sendGridExtension);
}

LoadBuiltinBindings(usedBindingTypes);
LoadCustomExtensions();

// Do this after we've loaded the custom extensions. That gives an extension an opportunity to plug in their own implementations.
Expand All @@ -467,7 +503,7 @@ protected virtual void Initialize()
}

// read all script functions and apply to JobHostConfiguration
Collection<FunctionDescriptor> functions = GetFunctionDescriptors();
Collection<FunctionDescriptor> functions = GetFunctionDescriptors(functionMetadata);
Collection<CustomAttributeBuilder> typeAttributes = CreateTypeAttributes(ScriptConfig);
string typeName = string.Format(CultureInfo.InvariantCulture, "{0}.{1}", GeneratedTypeNamespace, GeneratedTypeName);

Expand All @@ -490,6 +526,43 @@ protected virtual void Initialize()
}
}

private void LoadBuiltinBindings(IEnumerable<string> bindingTypes)
{
foreach (var bindingType in bindingTypes)
{
string assemblyQualifiedTypeName;
if (_builtinBindingTypes.TryGetValue(bindingType, out assemblyQualifiedTypeName))
{
Type typeExtension = Type.GetType(assemblyQualifiedTypeName);
if (typeExtension == null)
{
string errorMsg = $"Can't find builtin provider '{assemblyQualifiedTypeName}' for '{bindingType}'";
TraceWriter.Error(errorMsg);
_startupLogger?.LogError(errorMsg);
}
else
{
IExtensionConfigProvider extension = (IExtensionConfigProvider)Activator.CreateInstance(typeExtension);
LoadExtension(extension);
}
}
}
}

private static IEnumerable<string> DiscoverBindingTypes(IEnumerable<FunctionMetadata> functions)
{
HashSet<string> bindingTypes = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
foreach (var function in functions)
{
foreach (var binding in function.InputBindings.Concat(function.OutputBindings))
{
string bindingType = binding.Type;
bindingTypes.Add(bindingType);
}
}
return bindingTypes;
}

private IMetricsLogger CreateMetricsLogger()
{
IMetricsLogger metricsLogger = ScriptConfig.HostConfig.GetService<IMetricsLogger>();
Expand Down Expand Up @@ -748,7 +821,20 @@ public static ScriptHost Create(IScriptHostEnvironment environment, IScriptEvent
return scriptHost;
}

private static Collection<ScriptBindingProvider> LoadBindingProviders(ScriptHostConfiguration config, JObject hostMetadata, TraceWriter traceWriter, ILogger logger)
// Get the ScriptBindingProviderType for a given binding type.
// Null if no match.
private static Type GetScriptBindingProvider(string bindingType)
{
string assemblyQualifiedTypeName;
if (_builtinScriptBindingTypes.TryGetValue(bindingType, out assemblyQualifiedTypeName))
{
var type = Type.GetType(assemblyQualifiedTypeName);
return type;
}
return null;
}

private static Collection<ScriptBindingProvider> LoadBindingProviders(ScriptHostConfiguration config, JObject hostMetadata, TraceWriter traceWriter, ILogger logger, IEnumerable<string> usedBindingTypes)
{
JobHostConfiguration hostConfig = config.HostConfig;

Expand All @@ -757,21 +843,29 @@ private static Collection<ScriptBindingProvider> LoadBindingProviders(ScriptHost
{
// binding providers defined in this assembly
typeof(WebJobsCoreScriptBindingProvider),
typeof(ServiceBusScriptBindingProvider),

// binding providers defined in known extension assemblies
typeof(CoreExtensionsScriptBindingProvider),
typeof(ApiHubScriptBindingProvider),
typeof(DocumentDBScriptBindingProvider),
typeof(MobileAppsScriptBindingProvider),
typeof(NotificationHubScriptBindingProvider),
typeof(TwilioScriptBindingProvider),

// General purpose binder that works directly against SDK.
// This should eventually replace all other ScriptBindingProvider
typeof(GeneralScriptBindingProvider)
};

HashSet<Type> existingTypes = new HashSet<Type>();

// Add custom providers for any other types being used from function.json
foreach (var usedType in usedBindingTypes)
{
var type = GetScriptBindingProvider(usedType);
if (type != null && existingTypes.Add(type))
{
bindingProviderTypes.Add(type);
}
}

// General purpose binder that works directly against SDK.
// This should eventually replace all other ScriptBindingProvider
bindingProviderTypes.Add(typeof(GeneralScriptBindingProvider));

bindingProviderTypes.Add(typeof(DllWhitelistBindingProvider));

// Create the binding providers
var bindingProviders = new Collection<ScriptBindingProvider>();
foreach (var bindingProviderType in bindingProviderTypes)
Expand Down Expand Up @@ -1064,10 +1158,8 @@ private static ScriptType ParseScriptType(string scriptFilePath)
}
}

private Collection<FunctionDescriptor> GetFunctionDescriptors()
private Collection<FunctionDescriptor> GetFunctionDescriptors(Collection<FunctionMetadata> functions)
{
var functions = ReadFunctionMetadata(ScriptConfig, TraceWriter, _startupLogger, FunctionErrors, _settingsManager);

var descriptorProviders = new List<FunctionDescriptorProvider>()
{
new ScriptFunctionDescriptorProvider(this, ScriptConfig),
Expand Down Expand Up @@ -1587,5 +1679,32 @@ protected override void Dispose(bool disposing)
// cause us to not dispose ourselves
base.Dispose(disposing);
}

// We have a backwards compat requirement to whitelist #r references to certain "builtin" dlls.
// Hook into #r resolution pipeline and apply the whitelist.
private class DllWhitelistBindingProvider : ScriptBindingProvider
{
public DllWhitelistBindingProvider(JobHostConfiguration config, JObject hostMetadata, TraceWriter traceWriter)
: base(config, hostMetadata, traceWriter)
{
}

public override bool TryCreate(ScriptBindingContext context, out ScriptBinding binding)
{
binding = null;
return false;
}

public override bool TryResolveAssembly(string assemblyName, out Assembly assembly)
{
if (_assemblyWhitelist.Contains(assemblyName))
{
assembly = Assembly.Load(assemblyName);
return true;
}

return base.TryResolveAssembly(assemblyName, out assembly);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ await TestHelpers.Await(() =>
// No need for assert; this will throw if there's not one and only one
logs.Single(p => p.EndsWith(functionTrace));

Assert.Equal(12, _fixture.TelemetryItems.Count);

// Validate the traces. Order by message string as the requests may come in
// slightly out-of-order or on different threads
TelemetryPayload[] telemetries = _fixture.TelemetryItems
Expand All @@ -70,9 +68,7 @@ await TestHelpers.Await(() =>
ValidateTrace(telemetries[5], "Host configuration file read:", LogCategories.Startup);
ValidateTrace(telemetries[6], "Host lock lease acquired by instance ID", ScriptConstants.LogCategoryHostGeneral);
ValidateTrace(telemetries[7], "Job host started", LogCategories.Startup);
ValidateTrace(telemetries[8], "Loaded custom extension: BotFrameworkConfiguration from ''", LogCategories.Startup);
ValidateTrace(telemetries[9], "Loaded custom extension: SendGridConfiguration from ''", LogCategories.Startup);
ValidateTrace(telemetries[10], "Reading host configuration file", LogCategories.Startup);
ValidateTrace(telemetries[8], "Reading host configuration file", LogCategories.Startup);

// Finally, validate the request
TelemetryPayload request = _fixture.TelemetryItems
Expand Down

0 comments on commit 4aa3ad6

Please sign in to comment.