Skip to content

Commit

Permalink
Fix Azure OpenAI double ?, and JSON test issues (#44457)
Browse files Browse the repository at this point in the history
* Fix double ? issues with generated URIs
* Test fixes
  - Fixes JSON issues due to moving to older JSON version
  - Fixes vector store test that was subject to race condition in status
  • Loading branch information
ralph-msft committed Jun 13, 2024
1 parent 410dc90 commit a4b9ab1
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 15 deletions.
13 changes: 8 additions & 5 deletions sdk/openai/Azure.AI.OpenAI/tests/BatchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ public class BatchTests : AoaiTestBase<BatchClient>
{
private static readonly JsonSerializerOptions JSON_OPTIONS = new()
{
// PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
PropertyNamingPolicy = JsonHelpers.SnakeCaseLower,
PropertyNameCaseInsensitive = true,
// DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
#if NETFRAMEWORK
IgnoreNullValues = true,
#else
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
#endif
Converters =
{
new ModelReaderWriterConverter()
Expand Down Expand Up @@ -139,7 +143,7 @@ private class BatchRequest
public BinaryContent ToBinaryContent()
{
using MemoryStream stream = new MemoryStream();
JsonSerializer.Serialize(stream, typeof(BatchRequest), JSON_OPTIONS);
JsonHelpers.Serialize(stream, this, JSON_OPTIONS);

stream.Seek(0, SeekOrigin.Begin);
var data = BinaryData.FromStream(stream);
Expand Down Expand Up @@ -181,7 +185,6 @@ private class BatchObject
{
public static BatchObject From(BinaryData data)
{
using var stream = data.ToStream();
return JsonSerializer.Deserialize<BatchObject>(data, JSON_OPTIONS)
?? throw new InvalidOperationException("Response was null JSON");
}
Expand Down Expand Up @@ -239,7 +242,7 @@ public async Task<string> UploadBatchFileAsync()
}

using MemoryStream stream = new MemoryStream();
// JsonSerializer.Serialize(stream, _operations, JSON_OPTIONS);
JsonHelpers.Serialize(stream, _operations, JSON_OPTIONS);
stream.Seek(0, SeekOrigin.Begin);
var data = BinaryData.FromStream(stream);

Expand Down
185 changes: 185 additions & 0 deletions sdk/openai/Azure.AI.OpenAI/tests/Utils/JsonHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
using System;
using System.Buffers;
using System.Globalization;
using System.IO;
using System.Runtime.CompilerServices;
using System.Text.Json;

#nullable enable

namespace Azure.AI.OpenAI.Tests.Utils;

/// <summary>
/// A helper class to make working with older versions of System.Text.Json simpler
/// </summary>
internal static class JsonHelpers
{
// TODO FIXME once we update to newer versions of System.Text.JSon we should switch to using
// JsonNamingPolicy.SnakeCaseLower
public static JsonNamingPolicy SnakeCaseLower { get; } =
new SnakeCaseNamingPolicy();

// TODO FIXME once we move to newer versions of System.Text.Json we can directly call
// JsonSerializer.Serialize(...) with a stream
public static void Serialize<T>(Stream stream, T value, JsonSerializerOptions? options = null)
{
#if NETFRAMEWORK
using Utf8JsonWriter writer = new(stream, new JsonWriterOptions()
{
Encoder = options?.Encoder,
Indented = options?.WriteIndented == true,
SkipValidation = false
});

JsonSerializer.Serialize(writer, value, options);
#else
JsonSerializer.Serialize(stream, value, options);
#endif
}

#if NET6_0_OR_GREATER
// .Net 6 and newer already have the extension method we need defined in JsonsSerializer
#else
// TODO FIXME once we move to newer versions of System.Text.Json we can directly use the
// JsonSerializer extension method for elements
public static T? Deserialize<T>(this JsonElement element, JsonSerializerOptions? options = null)
{
using MemoryStream stream = new();
using Utf8JsonWriter writer = new(stream, new()
{
Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
Indented = false,
SkipValidation = true
});
element.WriteTo(writer);
writer.Flush();

stream.Seek(0, SeekOrigin.Begin);
if (((ulong)stream.Length & 0xffffffff00000000) != 0ul)
{
throw new ArgumentOutOfRangeException("JsonElement is too large");
}

ReadOnlySpan<byte> span = new(stream.GetBuffer(), 0, (int)stream.Length);
return JsonSerializer.Deserialize<T>(span, options);
}
#endif

// Ported over from the source code for newer versions of System.Text.Json
internal class SnakeCaseNamingPolicy : JsonNamingPolicy
{
private enum SeparatorState
{
NotStarted,
UppercaseLetter,
LowercaseLetterOrDigit,
SpaceSeparator
}

public override string ConvertName(string name)
{
if (string.IsNullOrEmpty(name))
{
return string.Empty;
}

return ConvertName('_', name.AsSpan());
}

internal static string ConvertName(char separator, ReadOnlySpan<char> chars)
{
char[]? rentedBuffer = null;

int num = (int)(1.2 * chars.Length);
Span<char> output = num > 128
? (rentedBuffer = ArrayPool<char>.Shared.Rent(num))!
: stackalloc char[128];

SeparatorState separatorState = SeparatorState.NotStarted;
int charsWritten = 0;

for (int i = 0; i < chars.Length; i++)
{
char c = chars[i];
UnicodeCategory unicodeCategory = char.GetUnicodeCategory(c);
switch (unicodeCategory)
{
case UnicodeCategory.UppercaseLetter:
switch (separatorState)
{
case SeparatorState.LowercaseLetterOrDigit:
case SeparatorState.SpaceSeparator:
WriteChar(separator, ref output);
break;
case SeparatorState.UppercaseLetter:
if (i + 1 < chars.Length && char.IsLower(chars[i + 1]))
{
WriteChar(separator, ref output);
}
break;
}

c = char.ToLowerInvariant(c);
WriteChar(c, ref output);
separatorState = SeparatorState.UppercaseLetter;
break;

case UnicodeCategory.LowercaseLetter:
case UnicodeCategory.DecimalDigitNumber:
if (separatorState == SeparatorState.SpaceSeparator)
{
WriteChar(separator, ref output);
}

WriteChar(c, ref output);
separatorState = SeparatorState.LowercaseLetterOrDigit;
break;

case UnicodeCategory.SpaceSeparator:
if (separatorState != 0)
{
separatorState = SeparatorState.SpaceSeparator;
}
break;

default:
WriteChar(c, ref output);
separatorState = SeparatorState.NotStarted;
break;
}
}

string result = output.Slice(0, charsWritten).ToString();
if (rentedBuffer != null)
{
output.Slice(0, charsWritten).Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
return result;

void ExpandBuffer(ref Span<char> destination)
{
int minimumLength = checked(destination.Length * 2);
char[] array = ArrayPool<char>.Shared.Rent(minimumLength);
destination.CopyTo(array);
if (rentedBuffer != null)
{
destination.Slice(0, charsWritten).Clear();
ArrayPool<char>.Shared.Return(rentedBuffer);
}
rentedBuffer = array;
destination = rentedBuffer;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
void WriteChar(char value, ref Span<char> destination)
{
if (charsWritten == destination.Length)
{
ExpandBuffer(ref destination);
}
destination[charsWritten++] = value;
}
}
}
}
50 changes: 41 additions & 9 deletions sdk/openai/Azure.AI.OpenAI/tests/Utils/TestConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using Azure.AI.OpenAI.Tests.Utils;

namespace Azure.AI.OpenAI.Tests;

Expand Down Expand Up @@ -48,11 +49,11 @@ public TestConfig()
return JsonSerializer.Deserialize<Dictionary<string, Config>>(json, new JsonSerializerOptions()
{
PropertyNameCaseInsensitive = true,
// PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
// DictionaryKeyPolicy = JsonNamingPolicy.SnakeCaseLower,
PropertyNamingPolicy = JsonHelpers.SnakeCaseLower,
DictionaryKeyPolicy = JsonHelpers.SnakeCaseLower,
Converters =
{
new UnSnakeCaseDictConverter()
new UnSnakeCaseDictKeyConverter()
}
});
}
Expand Down Expand Up @@ -203,14 +204,14 @@ public T GetValueOrDefault<T>(string name)

if (ExtensionData?.TryGetValue(name, out JsonElement element) == true)
{
// val = element.Deserialize<T>()!;
val = element.Deserialize<T>()!;
}

return val ?? default(T)!;
}
}

private class UnSnakeCaseDictConverter : JsonConverterFactory
private class UnSnakeCaseDictKeyConverter : JsonConverterFactory
{
public override bool CanConvert(Type typeToConvert)
{
Expand All @@ -232,12 +233,43 @@ public override bool CanConvert(Type typeToConvert)

private class InnerConverter<TValue> : JsonConverter<Dictionary<string, TValue>>
{
private readonly JsonConverter<TValue> _converter;
private readonly Type _valueType = typeof(TValue);
private JsonSerializerOptions _options;

public InnerConverter(JsonSerializerOptions options)
{
_converter = (JsonConverter<TValue>)options.GetConverter(typeof(TValue));
#if NETFRAMEWORK
_options = new()
{
AllowTrailingCommas = options.AllowTrailingCommas,
DefaultBufferSize = options.DefaultBufferSize,
DictionaryKeyPolicy = options.DictionaryKeyPolicy,
Encoder = options.Encoder,
IgnoreReadOnlyProperties = options.IgnoreReadOnlyProperties,
MaxDepth = options.MaxDepth,
PropertyNameCaseInsensitive = options.PropertyNameCaseInsensitive,
PropertyNamingPolicy = options.PropertyNamingPolicy,
ReadCommentHandling = options.ReadCommentHandling,
WriteIndented = options.WriteIndented,
IgnoreNullValues = options.IgnoreNullValues,
};
#else
_options = new(options);
_options.Converters.Clear();
#endif

if (options.Converters?.Count > 0)
{
var thisType = GetType();

foreach (var conv in options.Converters)
{
if (conv is not UnSnakeCaseDictKeyConverter
&& !thisType.IsAssignableFrom(conv.GetType()))
{
_options.Converters.Add(conv);
}
}
}
}

public override Dictionary<string, TValue> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
Expand Down Expand Up @@ -285,7 +317,7 @@ public InnerConverter(JsonSerializerOptions options)
builder.Clear();

reader.Read();
TValue? val = _converter.Read(ref reader, _valueType, options);
TValue? val = JsonSerializer.Deserialize<TValue>(ref reader, _options);

dict[propertyName] = val!;
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/openai/Azure.AI.OpenAI/tests/VectorStoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public async Task CanAssociateFiles()
Assert.That(association.VectorStoreId, Is.EqualTo(vectorStore.Id));
Assert.That(association.LastError, Is.Null);
Assert.That(association.CreatedAt, Is.GreaterThan(s_2024));
Assert.That(association.Status, Is.EqualTo(VectorStoreFileAssociationStatus.InProgress));
Assert.That(association.Status, Is.AnyOf(VectorStoreFileAssociationStatus.InProgress, VectorStoreFileAssociationStatus.Completed));
});
}

Expand Down

0 comments on commit a4b9ab1

Please sign in to comment.