Skip to content

[release/9.0-staging] [WinHTTP] Certificate caching on WinHttpHandler to eliminate extra call to Custom Certificate Validation #114678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,16 @@ public struct WINHTTP_ASYNC_RESULT
public uint dwError;
}

[StructLayout(LayoutKind.Sequential)]
public unsafe struct WINHTTP_CONNECTION_INFO
{
// This field is actually 4 bytes, but we use nuint to avoid alignment issues for x64.
// If we want to read this field in the future, we need to change type and make sure
// alignment is correct for necessary archs.
public nuint cbSize;
public fixed byte LocalAddress[128];
public fixed byte RemoteAddress[128];
}

[StructLayout(LayoutKind.Sequential)]
public struct tcp_keepalive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ System.Net.Http.WinHttpHandler</PackageDescription>
Link="Common\System\Runtime\ExceptionServices\ExceptionStackTrace.cs" />
<Compile Include="$(CommonPath)\System\Threading\Tasks\RendezvousAwaitable.cs"
Link="Common\System\Threading\Tasks\RendezvousAwaitable.cs" />
<Compile Include="System\Net\Http\CachedCertificateValue.cs" />
<Compile Include="System\Net\Http\NetEventSource.WinHttpHandler.cs" />
<Compile Include="System\Net\Http\NoWriteNoSeekStreamContent.cs" />
<Compile Include="System\Net\Http\WinHttpAuthHelper.cs" />
Expand Down Expand Up @@ -117,6 +118,7 @@ System.Net.Http.WinHttpHandler</PackageDescription>
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETFramework'">
<PackageReference Include="System.Buffers" Version="$(SystemBuffersVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="Microsoft.Bcl.HashCode" Version="$(MicrosoftBclHashCodeVersion)" />
<Reference Include="System.Net.Http" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Threading;

namespace System.Net.Http
{
internal sealed class CachedCertificateValue(byte[] rawCertificateData, long lastUsedTime)
{
private long _lastUsedTime = lastUsedTime;
public byte[] RawCertificateData { get; } = rawCertificateData;
public long LastUsedTime
{
get => Volatile.Read(ref _lastUsedTime);
set => Volatile.Write(ref _lastUsedTime, value);
}
}

internal readonly struct CachedCertificateKey : IEquatable<CachedCertificateKey>
{
public CachedCertificateKey(IPAddress address, HttpRequestMessage message)
{
Debug.Assert(message.RequestUri != null);
Address = address;
Host = message.Headers.Host ?? message.RequestUri.Host;
}
public IPAddress Address { get; }
public string Host { get; }

public bool Equals(CachedCertificateKey other) =>
Address.Equals(other.Address) &&
Host == other.Host;

public override bool Equals(object? obj)
{
throw new Exception("Unreachable");
}

public override int GetHashCode() => HashCode.Combine(Address, Host);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Http.Headers;
using System.Net.Security;
Expand Down Expand Up @@ -41,11 +43,14 @@ public class WinHttpHandler : HttpMessageHandler
internal static readonly Version HttpVersion20 = new Version(2, 0);
internal static readonly Version HttpVersion30 = new Version(3, 0);
internal static readonly Version HttpVersionUnknown = new Version(0, 0);
internal static bool CertificateCachingAppContextSwitchEnabled { get; } = AppContext.TryGetSwitch("System.Net.Http.UseWinHttpCertificateCaching", out bool enabled) && enabled;
private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue);

private static readonly StringWithQualityHeaderValue s_gzipHeaderValue = new StringWithQualityHeaderValue("gzip");
private static readonly StringWithQualityHeaderValue s_deflateHeaderValue = new StringWithQualityHeaderValue("deflate");
private static readonly Lazy<bool> s_supportsTls13 = new Lazy<bool>(CheckTls13Support);
private static readonly TimeSpan s_cleanCachedCertificateTimeout = TimeSpan.FromMilliseconds((int?)AppDomain.CurrentDomain.GetData("System.Net.Http.WinHttpCertificateCachingCleanupTimerInterval") ?? 60_000);
private static readonly long s_staleTimeout = (long)(s_cleanCachedCertificateTimeout.TotalSeconds * Stopwatch.Frequency / 1000);

[ThreadStatic]
private static StringBuilder? t_requestHeadersBuilder;
Expand Down Expand Up @@ -93,9 +98,44 @@ private Func<
private volatile bool _disposed;
private SafeWinHttpHandle? _sessionHandle;
private readonly WinHttpAuthHelper _authHelper = new WinHttpAuthHelper();
private readonly Timer? _certificateCleanupTimer;
private bool _isTimerRunning;
private readonly ConcurrentDictionary<CachedCertificateKey, CachedCertificateValue> _cachedCertificates = new();

public WinHttpHandler()
{
if (CertificateCachingAppContextSwitchEnabled)
{
WeakReference<WinHttpHandler> thisRef = new(this);
bool restoreFlow = false;
try
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

_certificateCleanupTimer = new Timer(
static s =>
{
if (((WeakReference<WinHttpHandler>)s!).TryGetTarget(out WinHttpHandler? thisRef))
{
thisRef.ClearStaleCertificates();
}
},
thisRef,
Timeout.Infinite,
Timeout.Infinite);
}
finally
{
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}
}
}

#region Properties
Expand Down Expand Up @@ -543,9 +583,12 @@ protected override void Dispose(bool disposing)
{
_disposed = true;

if (disposing && _sessionHandle != null)
if (disposing)
{
SafeWinHttpHandle.DisposeAndClearHandle(ref _sessionHandle);
if (_sessionHandle is not null) {
SafeWinHttpHandle.DisposeAndClearHandle(ref _sessionHandle);
}
_certificateCleanupTimer?.Dispose();
}
}

Expand Down Expand Up @@ -1644,7 +1687,8 @@ private void SetStatusCallback(
Interop.WinHttp.WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS |
Interop.WinHttp.WINHTTP_CALLBACK_FLAG_HANDLES |
Interop.WinHttp.WINHTTP_CALLBACK_FLAG_REDIRECT |
Interop.WinHttp.WINHTTP_CALLBACK_FLAG_SEND_REQUEST;
Interop.WinHttp.WINHTTP_CALLBACK_FLAG_SEND_REQUEST |
Interop.WinHttp.WINHTTP_CALLBACK_STATUS_CONNECTED_TO_SERVER;

IntPtr oldCallback = Interop.WinHttp.WinHttpSetStatusCallback(
requestHandle,
Expand Down Expand Up @@ -1730,5 +1774,90 @@ private RendezvousAwaitable<int> InternalReceiveResponseHeadersAsync(WinHttpRequ

return state.LifecycleAwaitable;
}

internal bool GetCertificateFromCache(CachedCertificateKey key, [NotNullWhen(true)] out byte[]? rawCertificateBytes)
{
if (_cachedCertificates.TryGetValue(key, out CachedCertificateValue? cachedValue))
{
cachedValue.LastUsedTime = Stopwatch.GetTimestamp();
rawCertificateBytes = cachedValue.RawCertificateData;
return true;
}

rawCertificateBytes = null;
return false;
}

internal void AddCertificateToCache(CachedCertificateKey key, byte[] rawCertificateData)
{
if (_cachedCertificates.TryAdd(key, new CachedCertificateValue(rawCertificateData, Stopwatch.GetTimestamp())))
{
EnsureCleanupTimerRunning();
}
}

internal bool TryRemoveCertificateFromCache(CachedCertificateKey key)
{
bool result = _cachedCertificates.TryRemove(key, out _);
if (result)
{
StopCleanupTimerIfEmpty();
}
return result;
}

private void ChangeCleanerTimer(TimeSpan timeout)
{
Debug.Assert(Monitor.IsEntered(_lockObject));
Debug.Assert(_certificateCleanupTimer != null);
if (_certificateCleanupTimer!.Change(timeout, Timeout.InfiniteTimeSpan))
{
_isTimerRunning = timeout != Timeout.InfiniteTimeSpan;
}
}

private void ClearStaleCertificates()
{
foreach (KeyValuePair<CachedCertificateKey, CachedCertificateValue> kvPair in _cachedCertificates)
{
if (IsStale(kvPair.Value.LastUsedTime))
{
_cachedCertificates.TryRemove(kvPair.Key, out _);
}
}

lock (_lockObject)
{
ChangeCleanerTimer(_cachedCertificates.IsEmpty ? Timeout.InfiniteTimeSpan : s_cleanCachedCertificateTimeout);
}

static bool IsStale(long lastUsedTime)
{
long now = Stopwatch.GetTimestamp();
return (now - lastUsedTime) > s_staleTimeout;
}
}

private void EnsureCleanupTimerRunning()
{
lock (_lockObject)
{
if (!_cachedCertificates.IsEmpty && !_isTimerRunning)
{
ChangeCleanerTimer(s_cleanCachedCertificateTimeout);
}
}
}

private void StopCleanupTimerIfEmpty()
{
lock (_lockObject)
{
if (_cachedCertificates.IsEmpty && _isTimerRunning)
{
ChangeCleanerTimer(Timeout.InfiniteTimeSpan);
}
}
}
}
}
Loading
Loading