Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GatewayClientStore: Fixes an issue with dealing with invalid JSON HTTP responses #4229

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
99 changes: 54 additions & 45 deletions Microsoft.Azure.Cosmos/src/GatewayStoreClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ namespace Microsoft.Azure.Cosmos
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Handlers;
using Microsoft.Azure.Cosmos.Tracing.TraceData;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Collections;
Expand Down Expand Up @@ -132,69 +131,79 @@ internal static INameValueCollection ExtractResponseHeaders(HttpResponseMessage
return headers;
}

/// <summary>
/// Creating a new DocumentClientException using the Gateway response message.
/// </summary>
/// <param name="responseMessage"></param>
/// <param name="requestStatistics"></param>
internal static async Task<DocumentClientException> CreateDocumentClientExceptionAsync(
HttpResponseMessage responseMessage,
IClientSideRequestStatistics requestStatistics)
{
bool isNameBased = false;
bool isFeed = false;
string resourceTypeString;
string resourceIdOrFullName;

string resourceLink = responseMessage.RequestMessage.RequestUri.LocalPath;
if (!PathsHelper.TryParsePathSegments(resourceLink, out isFeed, out resourceTypeString, out resourceIdOrFullName, out isNameBased))
if (!PathsHelper.TryParsePathSegments(
resourceUrl: responseMessage.RequestMessage.RequestUri.LocalPath,
isFeed: out _,
resourcePath: out _,
resourceIdOrFullName: out string resourceIdOrFullName,
isNameBased: out _))
{
// if resourceLink is invalid - we will not set resourceAddress in exception.
}

// If service rejects the initial payload like header is to large it will return an HTML error instead of JSON.
if (string.Equals(responseMessage.Content?.Headers?.ContentType?.MediaType, "application/json", StringComparison.OrdinalIgnoreCase))
if (string.Equals(responseMessage.Content?.Headers?.ContentType?.MediaType, "application/json", StringComparison.OrdinalIgnoreCase) &&
responseMessage.Content?.Headers.ContentLength > 0)
{
Stream readStream = await responseMessage.Content.ReadAsStreamAsync();
Error error = Documents.Resource.LoadFrom<Error>(readStream);
return new DocumentClientException(
error,
responseMessage.Headers,
responseMessage.StatusCode)
try
{
Stream contentAsStream = await responseMessage.Content.ReadAsStreamAsync();
Error error = JsonSerializable.LoadFrom<Error>(stream: contentAsStream);

return new DocumentClientException(
errorResource: error,
responseHeaders: responseMessage.Headers,
statusCode: responseMessage.StatusCode)
{
StatusDescription = responseMessage.ReasonPhrase,
ResourceAddress = resourceIdOrFullName,
RequestStatistics = requestStatistics
};
}
catch
{
StatusDescription = responseMessage.ReasonPhrase,
ResourceAddress = resourceIdOrFullName,
RequestStatistics = requestStatistics
};
}
}
else

StringBuilder contextBuilder = new StringBuilder();
contextBuilder.AppendLine(await responseMessage.Content.ReadAsStringAsync());

HttpRequestMessage requestMessage = responseMessage.RequestMessage;

if (requestMessage != null)
{
StringBuilder context = new StringBuilder();
context.AppendLine(await responseMessage.Content.ReadAsStringAsync());
contextBuilder.AppendLine($"RequestUri: {requestMessage.RequestUri};");
contextBuilder.AppendLine($"RequestMethod: {requestMessage.Method.Method};");

HttpRequestMessage requestMessage = responseMessage.RequestMessage;
if (requestMessage != null)
if (requestMessage.Headers != null)
{
context.AppendLine($"RequestUri: {requestMessage.RequestUri.ToString()};");
context.AppendLine($"RequestMethod: {requestMessage.Method.Method};");

if (requestMessage.Headers != null)
foreach (KeyValuePair<string, IEnumerable<string>> header in requestMessage.Headers)
{
foreach (KeyValuePair<string, IEnumerable<string>> header in requestMessage.Headers)
{
context.AppendLine($"Header: {header.Key} Length: {string.Join(",", header.Value).Length};");
}
contextBuilder.AppendLine($"Header: {header.Key} Length: {string.Join(",", header.Value).Length};");
}
}

String message = await responseMessage.Content.ReadAsStringAsync();
philipthomas-MSFT marked this conversation as resolved.
Show resolved Hide resolved
return new DocumentClientException(
message: context.ToString(),
innerException: null,
responseHeaders: responseMessage.Headers,
statusCode: responseMessage.StatusCode,
requestUri: responseMessage.RequestMessage.RequestUri)
{
StatusDescription = responseMessage.ReasonPhrase,
ResourceAddress = resourceIdOrFullName,
RequestStatistics = requestStatistics
};
}

return new DocumentClientException(
message: contextBuilder.ToString(),
innerException: null,
responseHeaders: responseMessage.Headers,
statusCode: responseMessage.StatusCode,
requestUri: responseMessage.RequestMessage.RequestUri)
{
StatusDescription = responseMessage.ReasonPhrase,
ResourceAddress = resourceIdOrFullName,
RequestStatistics = requestStatistics
};
}

internal static bool IsAllowedRequestHeader(string headerName)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------
namespace Microsoft.Azure.Cosmos
{
using System;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Cosmos.Tracing.TraceData;
using Microsoft.Azure.Documents;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json;

/// <summary>
/// Tests for <see cref="GatewayStoreClient"/>.
/// </summary>
[TestClass]
public class GatewayStoreClientTests
{
/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is NOT application/json and the error message has a length that is not zero.
/// This is not meant to be an exhaustive test for all legitimate content media types.
/// <see cref="GatewayStoreClient.CreateDocumentClientExceptionAsync(HttpResponseMessage, IClientSideRequestStatistics)"/>
/// </summary>
[TestMethod]
[DataRow("text/html", "<!DOCTYPE html><html><body></body></html>")]
[DataRow("text/plain", "This is a test error message.")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsNotApplicationJsonAndErrorMessageLengthIsNotZeroAsync(
string mediaType,
string errorMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: JsonConvert.SerializeObject(
value: new Error() { Code = HttpStatusCode.NotFound.ToString(), Message = errorMessage })),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsTrue(condition: documentClientException.Message.Contains(errorMessage));

Assert.IsNotNull(value: documentClientException.Error);
Assert.AreEqual(expected: HttpStatusCode.NotFound.ToString(), actual: documentClientException.Error.Code);
Assert.IsTrue(documentClientException.Error.Message.Contains(errorMessage));
}

/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is NOT application/json and the error message has a length that is zero.
/// This is not meant to be an exhaustive test for all legitimate content media types.
/// <see cref="GatewayStoreClient.CreateDocumentClientExceptionAsync(HttpResponseMessage, IClientSideRequestStatistics)"/>
/// </summary>
[TestMethod]
[DataRow("text/html", "")]
[DataRow("text/html", " ")]
[DataRow("text/plain", "")]
[DataRow("text/plain", " ")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsNotApplicationJsonAndErrorMessageLengthIsZeroAsync(
string mediaType,
string errorMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: JsonConvert.SerializeObject(
value: new Error() { Code = HttpStatusCode.NotFound.ToString(), Message = errorMessage })),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsNotNull(value: documentClientException.Message);

Assert.IsNotNull(value: documentClientException.Error);
Assert.AreEqual(expected: HttpStatusCode.NotFound.ToString(), actual: documentClientException.Error.Code);
Assert.IsNotNull(value: documentClientException.Error.Message);
}

/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is NOT application/json and the header content length is zero.
/// This is not meant to be an exhaustive test for all legitimate content media types.
/// <see cref="GatewayStoreClient.CreateDocumentClientExceptionAsync(HttpResponseMessage, IClientSideRequestStatistics)"/>
/// </summary>
[TestMethod]
[DataRow("text/plain", @"")]
[DataRow("text/plain", @" ")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsNotApplicationJsonAndHeaderContentLengthIsZeroAsync(
string mediaType,
string contentMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: contentMessage),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsNotNull(value: documentClientException.Message);

Assert.IsNotNull(value: documentClientException.Error);
Assert.AreEqual(expected: HttpStatusCode.NotFound.ToString(), actual: documentClientException.Error.Code);
Assert.IsNotNull(value: documentClientException.Error.Message);
}

/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is application/json and the error message length is zero.
/// <see cref="GatewayStoreClient.CreateDocumentClientExceptionAsync(HttpResponseMessage, IClientSideRequestStatistics)"/>
/// </summary>
[TestMethod]
[DataRow("application/json", "")]
[DataRow("application/json", " ")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsApplicationJsonAndErrorMessageLengthIsZeroAsync(
string mediaType,
string errorMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: JsonConvert.SerializeObject(
value: new Error() { Code = HttpStatusCode.NotFound.ToString(), Message = errorMessage })),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsNotNull(value: documentClientException.Message);

Assert.IsNotNull(value: documentClientException.Error);
Assert.AreEqual(expected: HttpStatusCode.NotFound.ToString(), actual: documentClientException.Error.Code);
Assert.IsNotNull(value: documentClientException.Error.Message);
}

/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is application/json and the content message is not valid json.
/// and has a content length that is not zero after trim.
/// <see cref="GatewayStoreClient.CreateDocumentClientExceptionAsync(HttpResponseMessage, IClientSideRequestStatistics)"/>
/// </summary>
[TestMethod]
[DataRow("application/json", @"<!DOCTYPE html><html><body></body></html>")]
[DataRow("application/json", @" <!DOCTYPE html><html><body></body></html>")]
[DataRow("application/json", @"<!DOCTYPE html><html><body></body></html> ")]
[DataRow("application/json", @" <!DOCTYPE html><html><body></body></html> ")]
[DataRow("application/json", @"ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")]
[DataRow("application/json", @" ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")]
[DataRow("application/json", @"ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 ")]
[DataRow("application/json", @" ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 ")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsApplicationJsonAndContentMessageIsNotValidJsonAsync(
string mediaType,
string contentMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: contentMessage),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsTrue(condition: documentClientException.Message.Contains(contentMessage));
}

/// <summary>
/// Testing CreateDocumentClientExceptionAsync when media type is application/json and the header content length is zero.
/// </summary>
[TestMethod]
[DataRow("application/json", @"")]
[DataRow("application/json", @" ")]
[Owner("philipthomas-MSFT")]
public async Task TestCreateDocumentClientExceptionWhenMediaTypeIsApplicationJsonAndHeaderContentLengthIsZeroAsync(
string mediaType,
string contentMessage)
{
HttpResponseMessage responseMessage = new(statusCode: HttpStatusCode.NotFound)
{
RequestMessage = new HttpRequestMessage(
method: HttpMethod.Get,
requestUri: @"https://pt_ac_test_uri.com/"),
Content = new StringContent(
mediaType: mediaType,
encoding: Encoding.UTF8,
content: contentMessage),
};

DocumentClientException documentClientException = await GatewayStoreClient.CreateDocumentClientExceptionAsync(
responseMessage: responseMessage,
requestStatistics: GatewayStoreClientTests.CreateClientSideRequestStatistics());

Assert.IsNotNull(value: documentClientException);
Assert.AreEqual(expected: HttpStatusCode.NotFound, actual: documentClientException.StatusCode);
Assert.IsNotNull(value: documentClientException.Message);

Assert.IsNotNull(value: documentClientException.Error);
Assert.AreEqual(expected: HttpStatusCode.NotFound.ToString(), actual: documentClientException.Error.Code);
Assert.IsNotNull(value: documentClientException.Error.Message);
}

private static IClientSideRequestStatistics CreateClientSideRequestStatistics()
{
return new ClientSideRequestStatisticsTraceDatum(
startTime: DateTime.UtcNow,
trace: NoOpTrace.Singleton);
}
}
}