Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Azure OpenAI double ?, and JSON test issues #44457

Merged
merged 6 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions sdk/openai/Azure.AI.OpenAI/tests/BatchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ public class BatchTests : AoaiTestBase<BatchClient>
{
private static readonly JsonSerializerOptions JSON_OPTIONS = new()
{
// PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
PropertyNamingPolicy = JsonHelpers.SnakeCaseLower,
PropertyNameCaseInsensitive = true,
// DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
#if NETFRAMEWORK
IgnoreNullValues = true,
#else
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
#endif
Converters =
{
new ModelReaderWriterConverter()
Expand Down Expand Up @@ -139,7 +143,7 @@ private class BatchRequest
public BinaryContent ToBinaryContent()
{
using MemoryStream stream = new MemoryStream();
JsonSerializer.Serialize(stream, typeof(BatchRequest), JSON_OPTIONS);
JsonHelpers.Serialize(stream, this, JSON_OPTIONS);

stream.Seek(0, SeekOrigin.Begin);
var data = BinaryData.FromStream(stream);
Expand Down Expand Up @@ -181,7 +185,6 @@ private class BatchObject
{
public static BatchObject From(BinaryData data)
{
using var stream = data.ToStream();
return JsonSerializer.Deserialize<BatchObject>(data, JSON_OPTIONS)
?? throw new InvalidOperationException("Response was null JSON");
}
Expand Down Expand Up @@ -239,7 +242,7 @@ public async Task<string> UploadBatchFileAsync()
}

using MemoryStream stream = new MemoryStream();
// JsonSerializer.Serialize(stream, _operations, JSON_OPTIONS);
JsonHelpers.Serialize(stream, _operations, JSON_OPTIONS);
stream.Seek(0, SeekOrigin.Begin);
var data = BinaryData.FromStream(stream);

Expand Down
185 changes: 185 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/tests/Utils/JsonHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
using System;
using System.Buffers;
using System.Globalization;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text.Json;

#nullable enable

namespace Azure.AI.OpenAI.Tests.Utils;

/// <summary>
/// A helper class to make working with older versions of System.Text.Json simpler
/// </summary>
internal static class JsonHelpers
{
// TODO FIXME once we update to newer versions of System.Text.JSon we should switch to using
// JsonNamingPolicy.SnakeCaseLower
public static JsonNamingPolicy SnakeCaseLower { get; } =
new SnakeCaseNamingPolicy();

// TODO FIXME once we move to newer versions of System.Text.Json we can directly call
// JsonSerializer.Serialize(...) with a stream
public static void Serialize<T>(Stream stream, T value, JsonSerializerOptions? options = null)
{
#if NETFRAMEWORK
using Utf8JsonWriter writer = new(stream, new JsonWriterOptions()
{
Encoder = options?.Encoder,
Indented = options?.WriteIndented == true,
SkipValidation = false
});

JsonSerializer.Serialize(writer, value, options);
#else
JsonSerializer.Serialize(stream, value, options);
#endif
}

#if NET6_0_OR_GREATER
// .Net 6 and newer already have the extension method we need defined in JsonsSerializer
#else
// TODO FIXME once we move to newer versions of System.Text.Json we can directly use the
// JsonSerializer extension method for elements
public static T? Deserialize<T>(this JsonElement element, JsonSerializerOptions? options = null)
{
using MemoryStream stream = new();
using Utf8JsonWriter writer = new(stream, new()
{
Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
Indented = false,
SkipValidation = true
});
element.WriteTo(writer);
writer.Flush();

stream.Seek(0, SeekOrigin.Begin);
if (((ulong)stream.Length & 0xffffffff00000000) != 0ul)
{
throw new ArgumentOutOfRangeException("JsonElement is too large");
}

ReadOnlySpan<byte> span = new(stream.GetBuffer(), 0, (int)stream.Length);
return JsonSerializer.Deserialize<T>(span, options);
}
#endif

// Ported over from the source code for newer versions of System.Text.Json
internal class SnakeCaseNamingPolicy : JsonNamingPolicy
{
private enum SeparatorState
{
NotStarted,
UppercaseLetter,
LowercaseLetterOrDigit,
SpaceSeparator
}

public override string ConvertName(string name)
{
if (string.IsNullOrEmpty(name))
{
return string.Empty;
}

return ConvertName('_', name.AsSpan());
}

internal static string ConvertName(char separator, ReadOnlySpan<char> chars)
{
char[]? rentedBuffer = null;

int num = (int)(1.2 * chars.Length);
Span<char> output = num > 128
? (rentedBuffer = ArrayPool<char>.Shared.Rent(num))!
: stackalloc char[128];

SeparatorState separatorState = SeparatorState.NotStarted;
int charsWritten = 0;

for (int i = 0; i < chars.Length; i++)
{
char c = chars[i];
UnicodeCategory unicodeCategory = char.GetUnicodeCategory(c);
switch (unicodeCategory)
{
case UnicodeCategory.UppercaseLetter:
switch (separatorState)
{
case SeparatorState.LowercaseLetterOrDigit:
case SeparatorState.SpaceSeparator:
WriteChar(separator, ref output);
break;
case SeparatorState.UppercaseLetter:
if (i + 1 < chars.Length && char.IsLower(chars[i + 1]))
{
WriteChar(separator, ref output);
}
break;
}

c = char.ToLowerInvariant(c);
WriteChar(c, ref output);
separatorState = SeparatorState.UppercaseLetter;
break;

case UnicodeCategory.LowercaseLetter:
case UnicodeCategory.DecimalDigitNumber:
if (separatorState == SeparatorState.SpaceSeparator)
{
WriteChar(separator, ref output);
}

WriteChar(c, ref output);
separatorState = SeparatorState.LowercaseLetterOrDigit;
break;

case UnicodeCategory.SpaceSeparator:
if (separatorState != 0)
{
separatorState = SeparatorState.SpaceSeparator;
}
break;

default:
WriteChar(c, ref output);
separatorState = SeparatorState.NotStarted;
break;
}
}

string result = output.Slice(0, charsWritten).ToString();
if (rentedBuffer != null)
{
output.Slice(0, charsWritten).Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
return result;

void ExpandBuffer(ref Span<char> destination)
{
int minimumLength = checked(destination.Length * 2);
char[] array = ArrayPool<char>.Shared.Rent(minimumLength);
destination.CopyTo(array);
if (rentedBuffer != null)
{
destination.Slice(0, charsWritten).Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
rentedBuffer = array;
destination = rentedBuffer;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
void WriteChar(char value, ref Span<char> destination)
{
if (charsWritten == destination.Length)
{
ExpandBuffer(ref destination);
}
destination[charsWritten++] = value;
}
}
}
}
50 changes: 41 additions & 9 deletions sdk/openai/Azure.AI.OpenAI/tests/Utils/TestConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using Azure.AI.OpenAI.Tests.Utils;

namespace Azure.AI.OpenAI.Tests;

Expand Down Expand Up @@ -48,11 +49,11 @@ public TestConfig()
return JsonSerializer.Deserialize<Dictionary<string, Config>>(json, new JsonSerializerOptions()
{
PropertyNameCaseInsensitive = true,
// PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
// DictionaryKeyPolicy = JsonNamingPolicy.SnakeCaseLower,
PropertyNamingPolicy = JsonHelpers.SnakeCaseLower,
DictionaryKeyPolicy = JsonHelpers.SnakeCaseLower,
Converters =
{
new UnSnakeCaseDictConverter()
new UnSnakeCaseDictKeyConverter()
}
});
}
Expand Down Expand Up @@ -203,14 +204,14 @@ public T GetValueOrDefault<T>(string name)

if (ExtensionData?.TryGetValue(name, out JsonElement element) == true)
{
// val = element.Deserialize<T>()!;
val = element.Deserialize<T>()!;
}

return val ?? default(T)!;
}
}

private class UnSnakeCaseDictConverter : JsonConverterFactory
private class UnSnakeCaseDictKeyConverter : JsonConverterFactory
{
public override bool CanConvert(Type typeToConvert)
{
Expand All @@ -232,12 +233,43 @@ public override bool CanConvert(Type typeToConvert)

private class InnerConverter<TValue> : JsonConverter<Dictionary<string, TValue>>
{
private readonly JsonConverter<TValue> _converter;
private readonly Type _valueType = typeof(TValue);
private JsonSerializerOptions _options;

public InnerConverter(JsonSerializerOptions options)
{
_converter = (JsonConverter<TValue>)options.GetConverter(typeof(TValue));
#if NETFRAMEWORK
_options = new()
{
AllowTrailingCommas = options.AllowTrailingCommas,
DefaultBufferSize = options.DefaultBufferSize,
DictionaryKeyPolicy = options.DictionaryKeyPolicy,
Encoder = options.Encoder,
IgnoreReadOnlyProperties = options.IgnoreReadOnlyProperties,
MaxDepth = options.MaxDepth,
PropertyNameCaseInsensitive = options.PropertyNameCaseInsensitive,
PropertyNamingPolicy = options.PropertyNamingPolicy,
ReadCommentHandling = options.ReadCommentHandling,
WriteIndented = options.WriteIndented,
IgnoreNullValues = options.IgnoreNullValues,
};
#else
_options = new(options);
_options.Converters.Clear();
#endif

if (options.Converters?.Count > 0)
{
var thisType = GetType();

foreach (var conv in options.Converters)
{
if (conv is not UnSnakeCaseDictKeyConverter
&& !thisType.IsAssignableFrom(conv.GetType()))
{
_options.Converters.Add(conv);
}
}
}
}

public override Dictionary<string, TValue> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
Expand Down Expand Up @@ -285,7 +317,7 @@ public InnerConverter(JsonSerializerOptions options)
builder.Clear();

reader.Read();
TValue? val = _converter.Read(ref reader, _valueType, options);
TValue? val = JsonSerializer.Deserialize<TValue>(ref reader, _options);

dict[propertyName] = val!;
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public async Task CanAssociateFiles()
Assert.That(association.VectorStoreId, Is.EqualTo(vectorStore.Id));
Assert.That(association.LastError, Is.Null);
Assert.That(association.CreatedAt, Is.GreaterThan(s_2024));
Assert.That(association.Status, Is.EqualTo(VectorStoreFileAssociationStatus.InProgress));
Assert.That(association.Status, Is.AnyOf(VectorStoreFileAssociationStatus.InProgress, VectorStoreFileAssociationStatus.Completed));
});
}

Expand Down