Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add unit test for RequestFailedException #153

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 44 additions & 1 deletion Notation.Plugin.AzureKeyVault.Tests/ProgramTests.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using Xunit;
using Moq;
using Notation.Plugin.AzureKeyVault.Command;
using Azure;
using Notation.Plugin.Protocol;
using System.IO;
using System.Threading.Tasks;
using System;
using Moq.Protected;

namespace Notation.Plugin.AzureKeyVault.Tests
{
Expand Down Expand Up @@ -84,5 +85,47 @@ public async Task ExecuteAsync_HandlesInvalidCommands(string command)
await Assert.ThrowsAsync<ValidationException>(() => Program.ExecuteAsync(args));
}
}
// we need this because of method being protected
internal interface IResponseMock
{
bool TryGetHeader(string name, out string value);
}

// we need this to be able to define the callback with out parameter
delegate bool TryGetHeaderCallback(string name, ref string value);


[Theory]
[InlineData(200, "{\"error\":{\"message\":\"TestErrorMessage\"}}", "TestErrorMessage")]
[InlineData(500, "{\"error\":{\"message\":\"TestErrorMessage\"}", "Service request failed.\nStatus: 500\n\nHeaders:\n")]
[InlineData(500, "{\"error2\":{\"message\":\"TestErrorMessage\"}}", "Service request failed.\nStatus: 500\n\nHeaders:\n")]
[InlineData(500, "{\"error\":{\"message2\":\"TestErrorMessage\"}}", "Service request failed.\nStatus: 500\n\nHeaders:\n")]
[InlineData(500, "{\"error\":{\"message\":\"\"}}", "\nStatus: 500\n\nHeaders:\n")]
public void HandleAzureException(int code, string content, string expectedErrorMessage)
{
// Arrange
Mock<Response> responseMock = new Mock<Response>();
responseMock.SetupGet(r => r.Status).Returns(code);
responseMock.SetupGet(r => r.Content).Returns(BinaryData.FromString(content));

// mock headers
responseMock.CallBase = true;
responseMock.Protected().As<IResponseMock>().Setup(m => m.TryGetHeader(It.IsAny<string>(), out It.Ref<string>.IsAny))
.Returns(new TryGetHeaderCallback((string name, ref string value) =>
{
value = "ETAG";
Console.WriteLine(name);
return true;
}));

var exception = new RequestFailedException(responseMock.Object);

// Act
var errorResponse = Program.HandleAzureException(exception);

// Assert exit code 1
Assert.Equal(expectedErrorMessage, errorResponse.ErrorMessage);
Assert.Equal("ERROR", errorResponse.ErrorCode);
}
}
}
45 changes: 31 additions & 14 deletions Notation.Plugin.AzureKeyVault/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,29 @@
Error.PrintError(e.Code, e.Message);
Environment.Exit(1);
}
catch (Azure.RequestFailedException e)
{
// wrap azure exception to notation plugin error response
var rawResponse = e.GetRawResponse();
if (rawResponse != null)
Console.Error.WriteLine(HandleAzureException(e).ToJson());
Environment.Exit(1);
}

Check warning on line 24 in Notation.Plugin.AzureKeyVault/Program.cs

View check run for this annotation

Codecov / codecov/patch

Notation.Plugin.AzureKeyVault/Program.cs#L20-L24

Added lines #L20 - L24 were not covered by tests
catch (Exception e)
{
Error.PrintError(Error.ERROR, e.Message);
Environment.Exit(1);
}
}

/// <summary>
/// Handles Azure.RequestFailedException and returns ErrorResponse.
/// </summary>
/// <param name="e"></param>
/// <returns></returns>
public static ErrorResponse HandleAzureException(Azure.RequestFailedException e)
{
var rawResponse = e.GetRawResponse();
if (rawResponse != null)
{
try
{
var content = JsonDocument.Parse(rawResponse.Content);
if (content.RootElement.TryGetProperty("error", out var errorInfo) &&
Expand All @@ -30,23 +48,22 @@
var errorMessage = errMsg.GetString();
if (!string.IsNullOrEmpty(errorMessage))
{
Error.PrintError(
return new ErrorResponse(
errorCode: e.ErrorCode ?? Error.ERROR,
errorMessage: errorMessage);
Environment.Exit(1);
}
}
}

// fallback to default error message
Error.PrintError(Error.ERROR, e.Message);
Environment.Exit(1);
}
catch (Exception e)
{
Error.PrintError(Error.ERROR, e.Message);
Environment.Exit(1);
catch (Exception)
{
// ignore
}
}

// fallback to default error message
return new ErrorResponse(
errorCode: e.ErrorCode ?? Error.ERROR,
errorMessage: e.Message);
}

public static async Task ExecuteAsync(string[] args)
Expand Down