Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch TLS exceptions during connect #2928

Merged
merged 3 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions iothub/device/src/Exceptions/IotHubClientErrorCode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,15 @@ public enum IotHubClientErrorCode
/// guide for more information.
/// </remarks>
Unauthorized,

/// <summary>
/// The request failed because of TLS authentication error.
brycewang-microsoft marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
/// <remarks>
/// This error may happen when the remote certificate presented could not be validated, the TLS version is different between client requested
/// auth and service minimum requirement, cipher suites to be used could not be agreed upon, etc. The best course of action is to check your
/// device certificates and ensure they are up-to-date.
/// </remarks>
TlsAuthenticationError,
}
}
2 changes: 1 addition & 1 deletion iothub/device/src/Pipeline/ErrorDelegatingHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private async Task<T> ExecuteWithErrorHandlingAsync<T>(Func<Task<T>> asyncOperat
if (IsSecurityExceptionChain(ex))
{
Exception innerException = (ex is IotHubClientException) ? ex.InnerException : ex;
brycewang-microsoft marked this conversation as resolved.
Show resolved Hide resolved
throw new AuthenticationException("TLS authentication error.", innerException);
throw new IotHubClientException("TLS authentication error.", IotHubClientErrorCode.TlsAuthenticationError, innerException);
}
// For historic reasons, part of the Error handling is done within the transport handlers.
else if (ex is IotHubClientException hubEx
Expand Down
164 changes: 73 additions & 91 deletions iothub/device/tests/Pipeline/ErrorDelegatingHandlerTests.cs
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Amqp.Framing;
using Microsoft.Azure.Devices.Client.Transport;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NSubstitute;

namespace Microsoft.Azure.Devices.Client.Test
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Net.WebSockets;
using System.Security.Authentication;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
using Microsoft.Azure.Devices.Client.Transport;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NSubstitute;

[TestClass]
[TestCategory("Unit")]
public class ErrorDelegatingHandlerTests
{
internal static readonly HashSet<Type> NonTransientExceptions = new HashSet<Type>
internal static readonly HashSet<Type> s_nonTransientExceptions = new HashSet<Type>
{
typeof(IotHubClientException),
};

private const string ErrorMessage = "Error occurred.";

private static readonly Dictionary<Type, Func<Exception>> ExceptionFactory = new Dictionary<Type, Func<Exception>>
private static readonly Dictionary<Type, Func<Exception>> s_exceptionFactory = new Dictionary<Type, Func<Exception>>
{
{ typeof(IotHubClientException), () => new IotHubClientException(ErrorMessage) },
{ typeof(IOException), () => new IOException(ErrorMessage) },
Expand All @@ -50,7 +53,6 @@ public class ErrorDelegatingHandlerTests
typeof(SocketException),
typeof(HttpRequestException),
typeof(WebException),
typeof(IotHubClientException),
typeof(WebSocketException),
typeof(TestDerivedException),
};
Expand Down Expand Up @@ -96,95 +98,71 @@ public async Task ErrorHandler_TransientErrorOccuredChannelIsAlive_ChannelIsTheS
{
foreach (Type exceptionType in s_networkExceptions)
{
await TestExceptionThrown(exceptionType, typeof(IotHubClientException)).ConfigureAwait(false);
List<Exception> exceptionList = await TestExceptionThrown(exceptionType, typeof(IotHubClientException)).ConfigureAwait(false);

foreach (Exception ex in exceptionList)
{
if (ex is IotHubClientException hubEx)
{
hubEx.ErrorCode.Should().Be(IotHubClientErrorCode.NetworkErrors);
hubEx.IsTransient.Should().BeTrue();
}
}
}
}

[TestMethod]
public async Task ErrorHandler_SecurityErrorOccured_ChannelIsAborted()
{
await TestExceptionThrown(typeof(TestSecurityException), typeof(AuthenticationException)).ConfigureAwait(false);
List<Exception> exceptionList = await TestExceptionThrown(typeof(TestSecurityException), typeof(IotHubClientException)).ConfigureAwait(false);

foreach (Exception ex in exceptionList)
{
if (ex is IotHubClientException hubEx)
{
hubEx.ErrorCode.Should().Be(IotHubClientErrorCode.TlsAuthenticationError);
hubEx.IsTransient.Should().BeFalse();
}
}
}

[TestMethod]
public async Task ErrorHandler_NonTransientErrorOccured_ChannelIsRecreated()
{
foreach (Type exceptionType in NonTransientExceptions)
foreach (Type exceptionType in s_nonTransientExceptions)
{
await TestExceptionThrown(exceptionType, exceptionType).ConfigureAwait(false);
}
}

private static async Task TestExceptionThrown(Type thrownExceptionType, Type expectedExceptionType)
private static async Task<List<Exception>> TestExceptionThrown(Type thrownExceptionType, Type expectedExceptionType)
{
var message = new TelemetryMessage(new byte[0]);
var cancellationToken = new CancellationToken();

await OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
di => di.SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
di => di.SendTelemetryAsync(message, cancellationToken),
di => di.Received(2).SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
thrownExceptionType, expectedExceptionType).ConfigureAwait(false);

IEnumerable<Message> messages = new[] { new Message(new byte[0]) };

await OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
di => di.SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
di => di.SendTelemetryAsync(message, cancellationToken),
di => di.Received(2).SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
thrownExceptionType, expectedExceptionType).ConfigureAwait(false);

await OpenAsync_ExceptionThrownAndThenSucceed_SuccessfullyOpened(
di => di.OpenAsync(Arg.Any<CancellationToken>()),
di => di.OpenAsync(cancellationToken),
di => di.Received(2).OpenAsync(Arg.Any<CancellationToken>()),
thrownExceptionType, expectedExceptionType).ConfigureAwait(false);
}

private static async Task OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
Func<IDelegatingHandler, Task<Message>> mockSetup,
Func<IDelegatingHandler, Task<Message>> act,
Func<IDelegatingHandler, Task<Message>> assert,
Type thrownExceptionType,
Type expectedExceptionType)
Comment on lines -144 to -149
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused private method

{
var contextMock = Substitute.For<PipelineContext>();
var innerHandler = Substitute.For<IDelegatingHandler>();
var sut = new ErrorDelegatingHandler(contextMock, innerHandler);

//initial OpenAsync to emulate Gatekeeper behavior
var cancellationToken = new CancellationToken();
innerHandler.OpenAsync(Arg.Any<CancellationToken>()).Returns(Task.CompletedTask);
await sut.OpenAsync(cancellationToken).ConfigureAwait(false);

//set initial operation result that throws

bool[] setup = { false };
mockSetup(innerHandler).Returns(ci =>
{
if (setup[0])
{
return Task.FromResult(new Message());
}
throw ExceptionFactory[thrownExceptionType]();
});

//act
await ((Func<Task>)(() => act(sut))).ExpectedAsync(expectedExceptionType).ConfigureAwait(false);

//override outcome
setup[0] = true;//otherwise previously setup call will happen and throw;
mockSetup(innerHandler).Returns(new Message());

//act
await act(sut).ConfigureAwait(false);

//assert
await innerHandler.Received(1).OpenAsync(Arg.Any<CancellationToken>()).ConfigureAwait(false);
await assert(innerHandler).ConfigureAwait(false);
var exceptionList = new List<Exception>();

exceptionList.Add(
await OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
di => di.SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
di => di.SendTelemetryAsync(message, cancellationToken),
di => di.Received(2).SendTelemetryAsync(Arg.Is(message), Arg.Any<CancellationToken>()),
thrownExceptionType,
expectedExceptionType)
.ConfigureAwait(false));

exceptionList.Add(
await OpenAsync_ExceptionThrownAndThenSucceed_SuccessfullyOpened(
di => di.OpenAsync(Arg.Any<CancellationToken>()),
di => di.OpenAsync(cancellationToken),
di => di.Received(2).OpenAsync(Arg.Any<CancellationToken>()),
thrownExceptionType,
expectedExceptionType)
.ConfigureAwait(false));

return exceptionList;
}

private static async Task OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
private static async Task<Exception> OperationAsync_ExceptionThrownAndThenSucceed_OperationSuccessfullyCompleted(
Func<IDelegatingHandler, Task> mockSetup,
Func<IDelegatingHandler, Task> act,
Func<IDelegatingHandler, Task> assert,
Expand All @@ -209,11 +187,11 @@ private static async Task TestExceptionThrown(Type thrownExceptionType, Type exp
{
return Task.CompletedTask; ;
}
throw ExceptionFactory[thrownExceptionType]();
throw s_exceptionFactory[thrownExceptionType]();
});

//act
await ((Func<Task>)(() => act(sut))).ExpectedAsync(expectedExceptionType).ConfigureAwait(false);
Exception ex = await ((Func<Task>)(() => act(sut))).ExpectedAsync(expectedExceptionType).ConfigureAwait(false);

//override outcome
setup[0] = true;//otherwise previously setup call will happen and throw;
Expand All @@ -225,9 +203,11 @@ private static async Task TestExceptionThrown(Type thrownExceptionType, Type exp
//assert
await innerHandler.Received(1).OpenAsync(Arg.Any<CancellationToken>()).ConfigureAwait(false);
await assert(innerHandler).ConfigureAwait(false);

return ex;
}

private static async Task OpenAsync_ExceptionThrownAndThenSucceed_SuccessfullyOpened(
private static async Task<Exception> OpenAsync_ExceptionThrownAndThenSucceed_SuccessfullyOpened(
Func<IDelegatingHandler, Task> mockSetup,
Func<IDelegatingHandler, Task> act,
Func<IDelegatingHandler, Task> assert,
Expand All @@ -247,11 +227,11 @@ private static async Task TestExceptionThrown(Type thrownExceptionType, Type exp
{
return Task.FromResult(Guid.NewGuid());
}
throw ExceptionFactory[thrownExceptionType]();
throw s_exceptionFactory[thrownExceptionType]();
});

//act
await ((Func<Task>)(() => act(sut))).ExpectedAsync(expectedExceptionType).ConfigureAwait(false);
Exception ex = await ((Func<Task>)(() => act(sut))).ExpectedAsync(expectedExceptionType).ConfigureAwait(false);

//override outcome
setup[0] = true;//otherwise previously setup call will happen and throw;
Expand All @@ -262,6 +242,8 @@ private static async Task TestExceptionThrown(Type thrownExceptionType, Type exp

//assert
await assert(innerHandler).ConfigureAwait(false);

return ex;
}
}
}
2 changes: 0 additions & 2 deletions iothub/service/src/Amqp/AmqpConnectionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

using System;
using System.Globalization;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Amqp;
Expand Down