Commit 4180481

[OpenAI] Migrate azure-json serialization to replace jackson-databind (
mssfang committed Apr 30, 2024
1 parent 537c06d commit 4180481
Showing 158 changed files with 6,710 additions and 2,012 deletions.
1 change: 1 addition & 0 deletions sdk/openai/azure-ai-openai/CHANGELOG.md
@@ -21,6 +21,7 @@

### Breaking Changes

- Replaced Jackson Databind annotations with `azure-json` functionality for OpenAI service models.
- [AOAI] Added a new class `ContentFilterDetailedResults` to represent detailed content filter results; it replaces
  `List<ContentFilterBlocklistIdResult>` as the type of the `customBlocklists` response property in the
  `ContentFilterResultDetailsForPrompt` and `ContentFilterResultsForChoice` classes.
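As context for the `azure-json` migration noted in the CHANGELOG entry above: service models drop their Jackson Databind annotations and instead implement `com.azure.json.JsonSerializable`, writing and reading their own JSON. The sketch below is a minimal, hypothetical `ExampleOptions` model illustrating that pattern; it is not code from this commit.

import com.azure.json.JsonReader;
import com.azure.json.JsonSerializable;
import com.azure.json.JsonToken;
import com.azure.json.JsonWriter;
import java.io.IOException;

// Hypothetical model showing the azure-json pattern: no Jackson annotations,
// serialization is implemented explicitly via toJson/fromJson.
public final class ExampleOptions implements JsonSerializable<ExampleOptions> {
    private String model;

    public ExampleOptions setModel(String model) {
        this.model = model;
        return this;
    }

    @Override
    public JsonWriter toJson(JsonWriter jsonWriter) throws IOException {
        jsonWriter.writeStartObject();
        jsonWriter.writeStringField("model", this.model);
        return jsonWriter.writeEndObject();
    }

    public static ExampleOptions fromJson(JsonReader jsonReader) throws IOException {
        // readObject positions the reader inside the JSON object and invokes the callback.
        return jsonReader.readObject(reader -> {
            ExampleOptions options = new ExampleOptions();
            while (reader.nextToken() != JsonToken.END_OBJECT) {
                String fieldName = reader.getFieldName();
                reader.nextToken();
                if ("model".equals(fieldName)) {
                    options.model = reader.getString();
                } else {
                    reader.skipChildren();
                }
            }
            return options;
        });
    }
}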
2 changes: 1 addition & 1 deletion sdk/openai/azure-ai-openai/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "java",
"TagPrefix": "java/openai/azure-ai-openai",
"Tag": "java/openai/azure-ai-openai_8fd1810cfb"
"Tag": "java/openai/azure-ai-openai_589fab4377"
}
@@ -14,51 +14,17 @@
/**
* This class contains the customization code to customize the AutoRest generated code for OpenAI.
*/
public class ChatCompletionsToolCallCustomizations extends Customization {
public class OpenAICustomizations extends Customization {

@Override
public void customize(LibraryCustomization customization, Logger logger) {
customizeChatCompletionsToolCall(customization, logger);
customizeEmbeddingEncodingFormatClass(customization, logger);
customizeEmbeddingsOptions(customization, logger);
}

public void customizeChatCompletionsToolCall(LibraryCustomization customization, Logger logger) {
logger.info("Customizing the ChatCompletionsToolCall class");
PackageCustomization packageCustomization = customization.getPackage("com.azure.ai.openai.models");
ClassCustomization classCustomization = packageCustomization.getClass("ChatCompletionsToolCall");

// Replace JsonTypeInfo annotation
classCustomization.removeAnnotation("JsonTypeInfo");
classCustomization.addAnnotation("JsonTypeInfo(use = JsonTypeInfo.Id.DEDUCTION, defaultImpl = ChatCompletionsToolCall.class)");

// Edit constructor
ConstructorCustomization constructorCustomization = classCustomization.getConstructor("ChatCompletionsToolCall")
.replaceParameters("@JsonProperty(value = \"id\") String id, @JsonProperty(value = \"type\")String type")
.replaceBody(joinWithNewline(
"this.id = id;",
"this.type = type;"));
JavadocCustomization constructorJavadocCustomization = constructorCustomization.getJavadoc()
.setParam("type", "the type value to set.");

// Remove type and getter in ChatCompletionsFunctionToolCall
classCustomization = packageCustomization.getClass("ChatCompletionsFunctionToolCall");
classCustomization.removeMethod("getType");
classCustomization.customizeAst(compilationUnit -> {
ClassOrInterfaceDeclaration clazz = compilationUnit.getClassByName("ChatCompletionsFunctionToolCall").get();
clazz.getMembers().removeIf(node -> {
if (node.isFieldDeclaration()
&& node.asFieldDeclaration().getVariables() != null && !node.asFieldDeclaration().getVariables().isEmpty()) {
return "type".equals(node.asFieldDeclaration().getVariables().get(0).getName().asString());
}
return false;
});
});

// remove unused class (no reference to them, after partial-update)
customization.getRawEditor().removeFile("src/main/java/com/azure/ai/openai/models/FileDetails.java");
customization.getRawEditor().removeFile("src/main/java/com/azure/ai/openai/implementation/MultipartFormDataHelper.java");
customizeEmbeddingEncodingFormatClass(customization, logger);
customizeEmbeddingsOptions(customization, logger);
}

private void customizeEmbeddingEncodingFormatClass(LibraryCustomization customization, Logger logger) {
logger.info("Customizing the EmbeddingEncodingFormat class");
ClassCustomization embeddingEncodingFormatClass = customization.getPackage("com.azure.ai.openai.models").getClass("EmbeddingEncodingFormat");
2 changes: 0 additions & 2 deletions sdk/openai/azure-ai-openai/pom.xml
@@ -49,8 +49,6 @@
--add-exports com.azure.core/com.azure.core.implementation.http=ALL-UNNAMED
--add-exports com.azure.core/com.azure.core.implementation.util=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai.implementation=com.fasterxml.jackson.databind
--add-opens com.azure.ai.openai/com.azure.ai.openai.functions=com.fasterxml.jackson.databind
</javaModulesSurefireArgLine>
</properties>

@@ -32,23 +32,22 @@
import com.azure.core.util.FluxUtil;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscription;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscriptionText;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslation;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslationText;
import static com.azure.ai.openai.implementation.EmbeddingsUtils.addEncodingFormat;
import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.addModelIdJson;
import static com.azure.core.util.FluxUtil.monoError;
import com.azure.ai.openai.implementation.CompletionsUtils;
import com.azure.ai.openai.implementation.MultipartDataHelper;
import com.azure.ai.openai.implementation.MultipartDataSerializationResult;
import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl;
import com.azure.ai.openai.implementation.OpenAIServerSentEvents;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.logging.ClientLogger;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicReference;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscription;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscriptionText;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslation;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslationText;
import static com.azure.ai.openai.implementation.EmbeddingsUtils.addEncodingFormat;
import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.addModelIdJson;
import static com.azure.core.util.FluxUtil.monoError;

/**
* Initializes a new instance of the asynchronous OpenAIClient type.
@@ -140,15 +139,12 @@ public final class OpenAIAsyncClient {
@ServiceMethod(returns = ReturnType.SINGLE)
public Mono<Response<BinaryData>> getEmbeddingsWithResponse(String deploymentOrModelName,
BinaryData embeddingsOptions, RequestOptions requestOptions) {
try {
embeddingsOptions = addEncodingFormat(embeddingsOptions);
} catch (JsonProcessingException e) {
return Mono.error(new RuntimeException(e));
}
final BinaryData embeddingsOptionsUpdated = addEncodingFormat(embeddingsOptions);
return openAIServiceClient != null
? openAIServiceClient.getEmbeddingsWithResponseAsync(deploymentOrModelName, embeddingsOptions,
? openAIServiceClient.getEmbeddingsWithResponseAsync(deploymentOrModelName, embeddingsOptionsUpdated,
requestOptions)
: serviceClient.getEmbeddingsWithResponseAsync(deploymentOrModelName, embeddingsOptions, requestOptions);
: serviceClient.getEmbeddingsWithResponseAsync(deploymentOrModelName, embeddingsOptionsUpdated,
requestOptions);
}

/**
@@ -1482,17 +1478,13 @@ Mono<AudioTranslation> getAudioTranslationAsResponseObject(String deploymentOrMo
public Mono<Response<BinaryData>> generateSpeechFromTextWithResponse(String deploymentOrModelName,
BinaryData speechGenerationOptions, RequestOptions requestOptions) {
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData speechGenerationOptionsWithModelId
= addModelIdJson(speechGenerationOptions, deploymentOrModelName);
return this.openAIServiceClient != null
? this.openAIServiceClient.generateSpeechFromTextWithResponseAsync(deploymentOrModelName,
speechGenerationOptionsWithModelId, requestOptions)
: this.serviceClient.generateSpeechFromTextWithResponseAsync(deploymentOrModelName,
speechGenerationOptionsWithModelId, requestOptions);
} catch (JsonProcessingException e) {
return Mono.error(e);
}
final BinaryData speechGenerationOptionsWithModelId
= addModelIdJson(speechGenerationOptions, deploymentOrModelName);
return this.openAIServiceClient != null
? this.openAIServiceClient.generateSpeechFromTextWithResponseAsync(deploymentOrModelName,
speechGenerationOptionsWithModelId, requestOptions)
: this.serviceClient.generateSpeechFromTextWithResponseAsync(deploymentOrModelName,
speechGenerationOptionsWithModelId, requestOptions);
}

/**
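As an aside on the change above: `addModelIdJson` injects the model id into the request body for non-Azure OpenAI (per the inline comment), and after this commit it no longer throws Jackson's checked `JsonProcessingException`, which is why the surrounding try/catch blocks disappear. Below is a minimal sketch of how such a helper could look using the same `BinaryData`/`Map` approach this commit adopts in `EmbeddingsUtils.addEncodingFormat`; it is an assumption about the helper's shape, not the commit's actual implementation.

import com.azure.core.util.BinaryData;
import java.util.Map;

final class ModelIdJsonSketch {
    // Hypothetical equivalent of NonAzureOpenAIClientImpl.addModelIdJson: deserialize the
    // request payload into a Map, add the "model" field non-Azure OpenAI expects, re-serialize.
    @SuppressWarnings("unchecked")
    static BinaryData addModelIdJson(BinaryData inputJson, String modelId) {
        Map<String, Object> mapJson = inputJson.toObject(Map.class);
        mapJson.put("model", modelId);
        return BinaryData.fromObject(mapJson);
    }
}

Because this path no longer goes through Jackson's ObjectMapper, there is no checked exception to catch, and the callers can simply pass the updated BinaryData straight to the service client.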
@@ -30,22 +30,21 @@
import com.azure.core.http.rest.SimpleResponse;
import com.azure.core.util.BinaryData;
import com.azure.core.util.logging.ClientLogger;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscription;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscriptionText;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslation;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslationText;
import static com.azure.ai.openai.implementation.EmbeddingsUtils.addEncodingFormat;
import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.addModelIdJson;
import com.azure.ai.openai.implementation.CompletionsUtils;
import com.azure.ai.openai.implementation.MultipartDataHelper;
import com.azure.ai.openai.implementation.MultipartDataSerializationResult;
import com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl;
import com.azure.ai.openai.implementation.OpenAIServerSentEvents;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.IterableStream;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.nio.ByteBuffer;
import reactor.core.publisher.Flux;
import java.nio.ByteBuffer;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscription;
import static com.azure.ai.openai.implementation.AudioTranscriptionValidator.validateAudioResponseFormatForTranscriptionText;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslation;
import static com.azure.ai.openai.implementation.AudioTranslationValidator.validateAudioResponseFormatForTranslationText;
import static com.azure.ai.openai.implementation.EmbeddingsUtils.addEncodingFormat;
import static com.azure.ai.openai.implementation.NonAzureOpenAIClientImpl.addModelIdJson;

/**
* Initializes a new instance of the synchronous OpenAIClient type.
@@ -133,14 +132,11 @@ public final class OpenAIClient {
@ServiceMethod(returns = ReturnType.SINGLE)
public Response<BinaryData> getEmbeddingsWithResponse(String deploymentOrModelName, BinaryData embeddingsOptions,
RequestOptions requestOptions) {
try {
embeddingsOptions = addEncodingFormat(embeddingsOptions);
} catch (JsonProcessingException e) {
throw LOGGER.logExceptionAsWarning(new RuntimeException(e));
}
final BinaryData embeddingsOptionsUpdated = addEncodingFormat(embeddingsOptions);
return openAIServiceClient != null
? openAIServiceClient.getEmbeddingsWithResponse(deploymentOrModelName, embeddingsOptions, requestOptions)
: serviceClient.getEmbeddingsWithResponse(deploymentOrModelName, embeddingsOptions, requestOptions);
? openAIServiceClient.getEmbeddingsWithResponse(deploymentOrModelName, embeddingsOptionsUpdated,
requestOptions)
: serviceClient.getEmbeddingsWithResponse(deploymentOrModelName, embeddingsOptionsUpdated, requestOptions);
}

/**
@@ -1441,12 +1437,8 @@ AudioTranslation getAudioTranslationAsResponseObject(String deploymentOrModelNam
@ServiceMethod(returns = ReturnType.SINGLE)
public Response<BinaryData> generateSpeechFromTextWithResponse(String deploymentOrModelName,
BinaryData speechGenerationOptions, RequestOptions requestOptions) {
BinaryData speechGenerationOptionsWithModelId = null;
try {
speechGenerationOptionsWithModelId = addModelIdJson(speechGenerationOptions, deploymentOrModelName);
} catch (JsonProcessingException e) {
throw LOGGER.logExceptionAsError(new RuntimeException(e));
}
final BinaryData speechGenerationOptionsWithModelId
= addModelIdJson(speechGenerationOptions, deploymentOrModelName);
return openAIServiceClient != null
? this.openAIServiceClient.generateSpeechFromTextWithResponse(speechGenerationOptionsWithModelId,
requestOptions)
@@ -5,37 +5,28 @@

import com.azure.core.util.Base64Util;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** This class contains convenience methods and constants for operations related to Embeddings */
public final class EmbeddingsUtils {
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

// This method applies to both AOAI and OAI clients
public static BinaryData addEncodingFormat(BinaryData inputJson) throws JsonProcessingException {
@SuppressWarnings("unchecked")
public static BinaryData addEncodingFormat(BinaryData inputJson) {
Map<String, Object> mapJson = inputJson.toObject(Map.class);

JsonNode jsonNode = JSON_MAPPER.readTree(inputJson.toString());
if (jsonNode instanceof ObjectNode) {
ObjectNode objectNode = (ObjectNode) jsonNode;
if (objectNode.has("encoding_format")) {
return inputJson;
}

objectNode.put("encoding_format", "base64");
inputJson = BinaryData.fromBytes(objectNode.toString().getBytes(StandardCharsets.UTF_8));
if (mapJson.containsKey("encoding_format")) {
return inputJson;
}

return inputJson;
mapJson.put("encoding_format", "base64");
return BinaryData.fromObject(mapJson);
}

// This method converts a base64 string to a list of floats
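The last visible comment in `EmbeddingsUtils` refers to a method that converts a base64 string to a list of floats; its body is not shown in this diff. Judging from the `Base64Util`, `ByteBuffer`, `ByteOrder`, and `FloatBuffer` imports retained above, it presumably decodes the string and reinterprets the bytes as little-endian float32 values. The sketch below is a rough illustration under that assumption, not the commit's actual code.

import com.azure.core.util.Base64Util;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;

final class EmbeddingBase64Sketch {
    // Hypothetical base64 -> List<Float> conversion, assuming little-endian IEEE-754 float32 packing.
    static List<Float> convertBase64ToFloatList(String base64) {
        byte[] bytes = Base64Util.decodeString(base64);
        FloatBuffer floatBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
        List<Float> floats = new ArrayList<>(floatBuffer.remaining());
        while (floatBuffer.hasRemaining()) {
            floats.add(floatBuffer.get());
        }
        return floats;
    }
}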

0 comments on commit 4180481
