) responseMap.get("error");
+ if (error != null) {
+ var message = (String) error.get("message");
+ if (message != null) {
+ return new AzureAndOpenAiErrorResponseEntity(message);
+ }
+ }
+ } catch (Exception e) {
+ // swallow the error
+ }
+
+ return null;
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java
new file mode 100644
index 0000000000000..5f803ad6fe74e
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java
@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response;
+
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
+import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
+
+/**
+ * A base class to use for external response handling.
+ *
+ * This currently covers response handling for Azure AI Studio, however this pattern
+ * can be used to simplify and refactor handling for Azure OpenAI and OpenAI responses.
+ */
+public class AzureAndOpenAiExternalResponseHandler extends BaseResponseHandler {
+
+ // The maximum number of requests that are permitted before exhausting the rate limit.
+ static final String REQUESTS_LIMIT = "x-ratelimit-limit-requests";
+ // The maximum number of tokens that are permitted before exhausting the rate limit.
+ static final String TOKENS_LIMIT = "x-ratelimit-limit-tokens";
+ // The remaining number of requests that are permitted before exhausting the rate limit.
+ static final String REMAINING_REQUESTS = "x-ratelimit-remaining-requests";
+ // The remaining number of tokens that are permitted before exhausting the rate limit.
+ static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens";
+
+ static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length.";
+ static final String SERVER_BUSY_ERROR = "Received a server busy error status code";
+
+ public AzureAndOpenAiExternalResponseHandler(
+ String requestType,
+ ResponseParser parseFunction,
+ Function<HttpResult, ErrorMessage> errorParseFunction
+ ) {
+ super(requestType, parseFunction, errorParseFunction);
+ }
+
+ @Override
+ public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+ throws RetryException {
+ checkForFailureStatusCode(request, result);
+ checkForEmptyBody(throttlerManager, logger, request, result);
+ }
+
+ public void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
+ int statusCode = result.response().getStatusLine().getStatusCode();
+ if (statusCode >= 200 && statusCode < 300) {
+ return;
+ }
+
+ // handle error codes
+ if (statusCode == 500) {
+ throw handle500Error(request, result);
+ } else if (statusCode == 503) {
+ throw handle503Error(request, result);
+ } else if (statusCode > 500) {
+ throw handleOther500Error(request, result);
+ } else if (statusCode == 429) {
+ throw handleRateLimitingError(request, result);
+ } else if (isContentTooLarge(result)) {
+ throw new ContentTooLargeException(buildError(CONTENT_TOO_LARGE, request, result));
+ } else if (statusCode == 401) {
+ throw handleAuthenticationError(request, result);
+ } else if (statusCode >= 300 && statusCode < 400) {
+ throw handleRedirectionStatusCode(request, result);
+ } else {
+ throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
+ }
+ }
+
+ protected RetryException handle500Error(Request request, HttpResult result) {
+ return new RetryException(true, buildError(SERVER_ERROR, request, result));
+ }
+
+ protected RetryException handle503Error(Request request, HttpResult result) {
+ return new RetryException(true, buildError(SERVER_BUSY_ERROR, request, result));
+ }
+
+ protected RetryException handleOther500Error(Request request, HttpResult result) {
+ return new RetryException(false, buildError(SERVER_ERROR, request, result));
+ }
+
+ protected RetryException handleAuthenticationError(Request request, HttpResult result) {
+ return new RetryException(false, buildError(AUTHENTICATION, request, result));
+ }
+
+ protected RetryException handleRateLimitingError(Request request, HttpResult result) {
+ return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+ }
+
+ protected RetryException handleRedirectionStatusCode(Request request, HttpResult result) {
+ throw new RetryException(false, buildError(REDIRECTION, request, result));
+ }
+
+ public static boolean isContentTooLarge(HttpResult result) {
+ int statusCode = result.response().getStatusLine().getStatusCode();
+
+ if (statusCode == 413) {
+ return true;
+ }
+
+ if (statusCode == 400) {
+ var errorEntity = AzureAndOpenAiErrorResponseEntity.fromResponse(result);
+ return errorEntity != null && errorEntity.getErrorMessage().contains(CONTENT_TOO_LARGE_MESSAGE);
+ }
+
+ return false;
+ }
+
+ public static String buildRateLimitErrorMessage(HttpResult result) {
+ var response = result.response();
+ var tokenLimit = getFirstHeaderOrUnknown(response, TOKENS_LIMIT);
+ var remainingTokens = getFirstHeaderOrUnknown(response, REMAINING_TOKENS);
+ var requestLimit = getFirstHeaderOrUnknown(response, REQUESTS_LIMIT);
+ var remainingRequests = getFirstHeaderOrUnknown(response, REMAINING_REQUESTS);
+
+ if (tokenLimit.equals("unknown") && requestLimit.equals("unknown")) {
+ var usageMessage = Strings.format("Remaining tokens [%s]. Remaining requests [%s].", remainingTokens, remainingRequests);
+ return RATE_LIMIT + ". " + usageMessage;
+ }
+
+ var usageMessage = Strings.format(
+ "Token limit [%s], remaining tokens [%s]. Request limit [%s], remaining requests [%s]",
+ tokenLimit,
+ remainingTokens,
+ requestLimit,
+ remainingRequests
+ );
+
+ return RATE_LIMIT + ". " + usageMessage;
+ }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/BaseResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/BaseResponseEntity.java
new file mode 100644
index 0000000000000..7c3c7a9645cf3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/BaseResponseEntity.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response;
+
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+
+/**
+ * A base class for providing InferenceServiceResults from a response. This is a lightweight wrapper
+ * to be able to override the `fromResponse` method to avoid using a static reference to the method.
+ */
+public abstract class BaseResponseEntity implements ResponseParser {
+ protected abstract InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException;
+
+ public InferenceServiceResults apply(Request request, HttpResult response) throws IOException {
+ return fromResponse(request, response);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntity.java
new file mode 100644
index 0000000000000..18f5923353960
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntity.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureaistudio;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest;
+import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
+import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+
+public class AzureAiStudioChatCompletionResponseEntity extends BaseResponseEntity {
+
+ @Override
+ protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+ if (request instanceof AzureAiStudioChatCompletionRequest asChatCompletionRequest) {
+ if (asChatCompletionRequest.isRealtimeEndpoint()) {
+ return parseRealtimeEndpointResponse(response);
+ }
+
+ // we can use the OpenAI chat completion type if it's not a realtime endpoint
+ return OpenAiChatCompletionResponseEntity.fromResponse(request, response);
+ }
+
+ return null;
+ }
+
+ private ChatCompletionResults parseRealtimeEndpointResponse(HttpResult response) throws IOException {
+ var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+ try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+ moveToFirstToken(jsonParser);
+
+ XContentParser.Token token = jsonParser.currentToken();
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+ while (token != null && token != XContentParser.Token.END_OBJECT) {
+ if (token != XContentParser.Token.FIELD_NAME) {
+ token = jsonParser.nextToken();
+ continue;
+ }
+
+ var currentName = jsonParser.currentName();
+ if (currentName == null || currentName.equalsIgnoreCase("output") == false) {
+ token = jsonParser.nextToken();
+ continue;
+ }
+
+ token = jsonParser.nextToken();
+ ensureExpectedToken(XContentParser.Token.VALUE_STRING, token, jsonParser);
+ String content = jsonParser.text();
+
+ return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
+ }
+
+ throw new IllegalStateException("Reached an invalid state while parsing the Azure AI Studio completion response");
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntity.java
new file mode 100644
index 0000000000000..3fce1ec7920f5
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntity.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureaistudio;
+
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity;
+import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
+
+import java.io.IOException;
+
+public class AzureAiStudioEmbeddingsResponseEntity extends BaseResponseEntity {
+ @Override
+ protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException {
+ // expected response type is the same as the Open AI Embeddings
+ return OpenAiEmbeddingsResponseEntity.fromResponse(request, response);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntity.java
new file mode 100644
index 0000000000000..4740c93ea6c03
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntity.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+
+public class CohereCompletionResponseEntity {
+
+ private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere chat response";
+
+ /**
+ * Parses the Cohere chat json response.
+ * For a request like:
+ *
+ * <pre>
+ * <code>
+ * {
+ * "message": "What is Elastic?"
+ * }
+ * </code>
+ * </pre>
+ *
+ * The response would look like:
+ *
+ * <pre>
+ * <code>
+ * {
+ * "response_id": "some id",
+ * "text": "response",
+ * "generation_id": "some id",
+ * "chat_history": [
+ * {
+ * "role": "USER",
+ * "message": "What is Elastic?"
+ * },
+ * {
+ * "role": "CHATBOT",
+ * "message": "response"
+ * }
+ * ],
+ * "finish_reason": "COMPLETE",
+ * "meta": {
+ * "api_version": {
+ * "version": "1"
+ * },
+ * "billed_units": {
+ * "input_tokens": 4,
+ * "output_tokens": 229
+ * },
+ * "tokens": {
+ * "input_tokens": 70,
+ * "output_tokens": 229
+ * }
+ * }
+ * }
+ * </code>
+ * </pre>
+ */
+
+ public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
+ var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+
+ try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+ moveToFirstToken(jsonParser);
+
+ XContentParser.Token token = jsonParser.currentToken();
+ XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+ positionParserAtTokenAfterField(jsonParser, "text", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+ XContentParser.Token contentToken = jsonParser.currentToken();
+ ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser);
+ String content = jsonParser.text();
+
+ return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
index 47c7cc0fce015..25e8afbe1d16c 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
@@ -34,6 +34,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
+import java.util.stream.Collectors;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
@@ -91,6 +92,40 @@ public static T removeAsType(Map sourceMap, String key, Clas
}
}
+ /**
+ * Remove the object from the map and cast to first assignable type in the expected types list.
+ * If the object cannot be cast to one of the types an error is added to the
+ * {@code validationException} parameter
+ *
+ * @param sourceMap Map containing fields
+ * @param key The key of the object to remove
+ * @param types The expected types of the removed object
+ * @param validationException If the value is not of type {@code type}
+ * @return {@code null} if not present else the object cast to the first assignable type in the types list
+ */
+ public static Object removeAsOneOfTypes(
+ Map<String, Object> sourceMap,
+ String key,
+ List<Class<?>> types,
+ ValidationException validationException
+ ) {
+ Object o = sourceMap.remove(key);
+ if (o == null) {
+ return null;
+ }
+
+ for (Class<?> type : types) {
+ if (type.isAssignableFrom(o.getClass())) {
+ return type.cast(o);
+ }
+ }
+
+ validationException.addValidationError(
+ invalidTypesErrorMsg(key, o, types.stream().map(Class::getSimpleName).collect(Collectors.toList()))
+ );
+ return null;
+ }
+
@SuppressWarnings("unchecked")
+ public static Map<String, Object> removeFromMap(Map<String, Object> sourceMap, String fieldName) {
+ return (Map<String, Object>) sourceMap.remove(fieldName);
@@ -151,6 +186,16 @@ public static String invalidTypeErrorMsg(String settingName, Object foundObject,
);
}
+ public static String invalidTypesErrorMsg(String settingName, Object foundObject, List<String> expectedTypes) {
+ return Strings.format(
+ // omitting [ ] for the last string as this will be added, if you convert the list to a string anyway
+ "field [%s] is not of one of the expected types. The value [%s] cannot be converted to one of %s",
+ settingName,
+ foundObject,
+ expectedTypes
+ );
+ }
+
public static String invalidUrlErrorMsg(String url, String settingName, String settingScope) {
return Strings.format("[%s] Invalid url [%s] received for field [%s]", settingScope, url, settingName);
}
@@ -325,7 +370,7 @@ public static Integer extractOptionalPositiveInteger(
}
if (optionalField != null && optionalField <= 0) {
- validationException.addValidationError(ServiceUtils.mustBeAPositiveNumberErrorMessage(settingName, scope, optionalField));
+ validationException.addValidationError(ServiceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, optionalField));
}
if (validationException.validationErrors().size() > initialValidationErrorCount) {
@@ -335,6 +380,99 @@ public static Integer extractOptionalPositiveInteger(
return optionalField;
}
+ public static Float extractOptionalFloat(Map<String, Object> map, String settingName) {
+ return ServiceUtils.removeAsType(map, settingName, Float.class);
+ }
+
+ public static Double extractOptionalDoubleInRange(
+ Map map,
+ String settingName,
+ @Nullable Double minValue,
+ @Nullable Double maxValue,
+ String scope,
+ ValidationException validationException
+ ) {
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ var doubleReturn = ServiceUtils.removeAsType(map, settingName, Double.class, validationException);
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ if (doubleReturn != null && minValue != null && doubleReturn < minValue) {
+ validationException.addValidationError(
+ ServiceUtils.mustBeGreaterThanOrEqualNumberErrorMessage(settingName, scope, doubleReturn, minValue)
+ );
+ }
+
+ if (doubleReturn != null && maxValue != null && doubleReturn > maxValue) {
+ validationException.addValidationError(
+ ServiceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, doubleReturn, maxValue)
+ );
+ }
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ return doubleReturn;
+ }
+
+ public static <E extends Enum<E>> E extractRequiredEnum(
+ Map<String, Object> map,
+ String settingName,
+ String scope,
+ EnumConstructor<E> constructor,
+ EnumSet<E> validValues,
+ ValidationException validationException
+ ) {
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ var enumReturn = extractOptionalEnum(map, settingName, scope, constructor, validValues, validationException);
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ if (enumReturn == null) {
+ validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(settingName, scope));
+ }
+
+ return enumReturn;
+ }
+
+ public static Long extractOptionalPositiveLong(
+ Map<String, Object> map,
+ String settingName,
+ String scope,
+ ValidationException validationException
+ ) {
+ // We don't want callers to handle the implementation detail that a long is expected (also treat integers like a long)
+ List<Class<?>> types = List.of(Integer.class, Long.class);
+ int initialValidationErrorCount = validationException.validationErrors().size();
+ var optionalField = ServiceUtils.removeAsOneOfTypes(map, settingName, types, validationException);
+
+ if (optionalField != null) {
+ try {
+ // Use String.valueOf first as there's no Long.valueOf(Object o)
+ Long longValue = Long.valueOf(String.valueOf(optionalField));
+
+ if (longValue <= 0L) {
+ validationException.addValidationError(ServiceUtils.mustBeAPositiveLongErrorMessage(settingName, scope, longValue));
+ }
+
+ if (validationException.validationErrors().size() > initialValidationErrorCount) {
+ return null;
+ }
+
+ return longValue;
+ } catch (NumberFormatException e) {
+ validationException.addValidationError(format("unable to parse long [%s]", e));
+ }
+ }
+
+ return null;
+ }
+
+ public static <E extends Enum<E>> E extractOptionalEnum(
+ Map<String, Object> map,
String settingName,
@@ -391,10 +529,26 @@ private static <E extends Enum<E>> void validateEnumValue(E enumValue, EnumSet
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java
new file mode 100644
index 0000000000000..296b8cf09f8c0
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioConstants.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+public class AzureAiStudioConstants {
+ public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
+ public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
+
+ // common service settings fields
+ public static final String TARGET_FIELD = "target";
+ public static final String ENDPOINT_TYPE_FIELD = "endpoint_type";
+ public static final String PROVIDER_FIELD = "provider";
+ public static final String API_KEY_FIELD = "api_key";
+
+ // embeddings service and request settings
+ public static final String INPUT_FIELD = "input";
+ public static final String DIMENSIONS_FIELD = "dimensions";
+ public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
+
+ // embeddings task settings fields
+ public static final String USER_FIELD = "user";
+
+ // completion task settings fields
+ public static final String TEMPERATURE_FIELD = "temperature";
+ public static final String TOP_P_FIELD = "top_p";
+ public static final String DO_SAMPLE_FIELD = "do_sample";
+ public static final String MAX_TOKENS_FIELD = "max_tokens";
+ public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
+
+ public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
+ public static final Double MAX_TEMPERATURE_TOP_P = 2.0;
+
+ private AzureAiStudioConstants() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEndpointType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEndpointType.java
new file mode 100644
index 0000000000000..ece63f4bbf0cd
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEndpointType.java
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import java.util.Locale;
+
+public enum AzureAiStudioEndpointType {
+ TOKEN,
+ REALTIME;
+
+ public static String NAME = "azure_ai_studio_endpoint_type";
+
+ public static AzureAiStudioEndpointType fromString(String name) {
+ return valueOf(name.trim().toUpperCase(Locale.ROOT));
+ }
+
+ @Override
+ public String toString() {
+ return name().toLowerCase(Locale.ROOT);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioModel.java
new file mode 100644
index 0000000000000..a5dd491d198ae
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioModel.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Base class for Azure AI Studio models. There are some common properties across the task types
+ * including:
+ * - target:
+ * - uri:
+ * - provider:
+ * - endpointType:
+ */
+public abstract class AzureAiStudioModel extends Model {
+ protected String target;
+ protected URI uri;
+ protected AzureAiStudioProvider provider;
+ protected AzureAiStudioEndpointType endpointType;
+ protected RateLimitSettings rateLimitSettings;
+
+ public AzureAiStudioModel(AzureAiStudioModel model, TaskSettings taskSettings, RateLimitSettings rateLimitSettings) {
+ super(model, taskSettings);
+ this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
+ setPropertiesFromServiceSettings((AzureAiStudioServiceSettings) model.getServiceSettings());
+ }
+
+ public AzureAiStudioModel(AzureAiStudioModel model, AzureAiStudioServiceSettings serviceSettings) {
+ super(model, serviceSettings);
+ setPropertiesFromServiceSettings(serviceSettings);
+ }
+
+ protected AzureAiStudioModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) {
+ super(modelConfigurations, modelSecrets);
+ setPropertiesFromServiceSettings((AzureAiStudioServiceSettings) modelConfigurations.getServiceSettings());
+ }
+
+ private void setPropertiesFromServiceSettings(AzureAiStudioServiceSettings serviceSettings) {
+ this.target = serviceSettings.target;
+ this.provider = serviceSettings.provider();
+ this.endpointType = serviceSettings.endpointType();
+ this.rateLimitSettings = serviceSettings.rateLimitSettings();
+ try {
+ this.uri = getEndpointUri();
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ protected abstract URI getEndpointUri() throws URISyntaxException;
+
+ public String target() {
+ return this.target;
+ }
+
+ public RateLimitSettings rateLimitSettings() {
+ return this.rateLimitSettings;
+ }
+
+ public AzureAiStudioProvider provider() {
+ return this.provider;
+ }
+
+ public AzureAiStudioEndpointType endpointType() {
+ return this.endpointType;
+ }
+
+ public URI uri() {
+ return this.uri;
+ }
+
+ // Needed for testing only
+ public void setURI(String newUri) {
+ try {
+ this.uri = new URI(newUri);
+ } catch (URISyntaxException e) {
+ // swallow any error
+ }
+ }
+
+ @Override
+ public DefaultSecretSettings getSecretSettings() {
+ return (DefaultSecretSettings) super.getSecretSettings();
+ }
+
+ public abstract ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings);
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProvider.java
new file mode 100644
index 0000000000000..6b3efca0888f3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProvider.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import java.util.Locale;
+
+/**
+ * The model providers whose deployments can be hosted on Azure AI Studio.
+ * Serialized in lowercase via {@link #toString()}; parsed case-insensitively via
+ * {@link #fromString(String)}.
+ */
+public enum AzureAiStudioProvider {
+    OPENAI,
+    MISTRAL,
+    META,
+    MICROSOFT_PHI,
+    COHERE,
+    DATABRICKS;
+
+    // final: this is a constant identifier and must not be reassignable
+    public static final String NAME = "azure_ai_studio_provider";
+
+    /** Parses a provider name, ignoring surrounding whitespace and case. */
+    public static AzureAiStudioProvider fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java
new file mode 100644
index 0000000000000..af064707536eb
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioProviderCapabilities.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.elasticsearch.inference.TaskType;
+
+import java.util.List;
+
+/**
+ * Static capability tables describing which Azure AI Studio providers support which task
+ * types, and which endpoint types each provider offers per task.
+ */
+public final class AzureAiStudioProviderCapabilities {
+
+    // these providers have embeddings inference
+    public static final List<AzureAiStudioProvider> embeddingProviders = List.of(
+        AzureAiStudioProvider.OPENAI,
+        AzureAiStudioProvider.COHERE
+    );
+
+    // these providers have chat completion inference (all providers at the moment)
+    public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());
+
+    // these providers allow token ("pay as you go") embeddings endpoints
+    public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
+        AzureAiStudioProvider.OPENAI,
+        AzureAiStudioProvider.COHERE
+    );
+
+    // these providers allow realtime embeddings endpoints (none at the moment)
+    public static final List<AzureAiStudioProvider> realtimeEmbeddingsProviders = List.of();
+
+    // these providers allow token ("pay as you go") chat completion endpoints
+    public static final List<AzureAiStudioProvider> tokenChatCompletionProviders = List.of(
+        AzureAiStudioProvider.OPENAI,
+        AzureAiStudioProvider.META,
+        AzureAiStudioProvider.COHERE
+    );
+
+    // these providers allow realtime chat completion endpoints
+    public static final List<AzureAiStudioProvider> realtimeChatCompletionProviders = List.of(
+        AzureAiStudioProvider.MISTRAL,
+        AzureAiStudioProvider.META,
+        AzureAiStudioProvider.MICROSOFT_PHI,
+        AzureAiStudioProvider.DATABRICKS
+    );
+
+    /** Returns whether {@code provider} supports {@code taskType} at all. */
+    public static boolean providerAllowsTaskType(AzureAiStudioProvider provider, TaskType taskType) {
+        switch (taskType) {
+            case COMPLETION -> {
+                return chatCompletionProviders.contains(provider);
+            }
+            case TEXT_EMBEDDING -> {
+                return embeddingProviders.contains(provider);
+            }
+            default -> {
+                return false;
+            }
+        }
+    }
+
+    /** Returns whether {@code provider} offers the given endpoint type for {@code taskType}. */
+    public static boolean providerAllowsEndpointTypeForTask(
+        AzureAiStudioProvider provider,
+        TaskType taskType,
+        AzureAiStudioEndpointType endpointType
+    ) {
+        switch (taskType) {
+            case COMPLETION -> {
+                return (endpointType == AzureAiStudioEndpointType.TOKEN)
+                    ? tokenChatCompletionProviders.contains(provider)
+                    : realtimeChatCompletionProviders.contains(provider);
+            }
+            case TEXT_EMBEDDING -> {
+                return (endpointType == AzureAiStudioEndpointType.TOKEN)
+                    ? tokenEmbeddingsProviders.contains(provider)
+                    : realtimeEmbeddingsProviders.contains(provider);
+            }
+            default -> {
+                return false;
+            }
+        }
+    }
+
+    // utility class — not instantiable
+    private AzureAiStudioProviderCapabilities() {}
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
new file mode 100644
index 0000000000000..c488eac422401
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
@@ -0,0 +1,358 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProviderCapabilities.providerAllowsEndpointTypeForTask;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProviderCapabilities.providerAllowsTaskType;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS;
+
+public class AzureAiStudioService extends SenderService {
+
+ private static final String NAME = "azureaistudio";
+
+ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
+ super(factory, serviceComponents);
+ }
+
+ @Override
+ protected void doInfer(
+ Model model,
+ List input,
+ Map taskSettings,
+ InputType inputType,
+ TimeValue timeout,
+ ActionListener listener
+ ) {
+ var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
+
+ if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
+ var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
+ action.execute(new DocumentsOnlyInput(input), timeout, listener);
+ } else {
+ listener.onFailure(createInvalidModelException(model));
+ }
+ }
+
+ @Override
+ protected void doInfer(
+ Model model,
+ String query,
+ List input,
+ Map taskSettings,
+ InputType inputType,
+ TimeValue timeout,
+ ActionListener listener
+ ) {
+ throw new UnsupportedOperationException("Azure AI Studio service does not support inference with query input");
+ }
+
+ @Override
+ protected void doChunkedInfer(
+ Model model,
+ String query,
+ List input,
+ Map taskSettings,
+ InputType inputType,
+ ChunkingOptions chunkingOptions,
+ TimeValue timeout,
+ ActionListener> listener
+ ) {
+ ActionListener inferListener = listener.delegateFailureAndWrap(
+ (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
+ );
+
+ doInfer(model, input, taskSettings, inputType, timeout, inferListener);
+ }
+
+ private static List translateToChunkedResults(
+ List inputs,
+ InferenceServiceResults inferenceResults
+ ) {
+ if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) {
+ return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults);
+ } else if (inferenceResults instanceof ErrorInferenceResults error) {
+ return List.of(new ErrorChunkedInferenceResults(error.getException()));
+ } else {
+ throw createInvalidChunkedResultException(inferenceResults.getWriteableName());
+ }
+ }
+
+ @Override
+ public void parseRequestConfig(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map config,
+ Set platformArchitectures,
+ ActionListener parsedModelListener
+ ) {
+ try {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+ AzureAiStudioModel model = createModel(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ serviceSettingsMap,
+ TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
+ ConfigurationParseContext.REQUEST
+ );
+
+ throwIfNotEmptyMap(config, NAME);
+ throwIfNotEmptyMap(serviceSettingsMap, NAME);
+ throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+ parsedModelListener.onResponse(model);
+ } catch (Exception e) {
+ parsedModelListener.onFailure(e);
+ }
+ }
+
+ @Override
+ public AzureAiStudioModel parsePersistedConfigWithSecrets(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map config,
+ Map secrets
+ ) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+ Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
+
+ return createModelFromPersistent(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ secretSettingsMap,
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ );
+ }
+
+ @Override
+ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) {
+ Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+ Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+ return createModelFromPersistent(
+ inferenceEntityId,
+ taskType,
+ serviceSettingsMap,
+ taskSettingsMap,
+ null,
+ parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+ );
+ }
+
+ @Override
+ public String name() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO;
+ }
+
+ private static AzureAiStudioModel createModel(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secretSettings,
+ String failureMessage,
+ ConfigurationParseContext context
+ ) {
+
+ if (taskType == TaskType.TEXT_EMBEDDING) {
+ var embeddingsModel = new AzureAiStudioEmbeddingsModel(
+ inferenceEntityId,
+ taskType,
+ NAME,
+ serviceSettings,
+ taskSettings,
+ secretSettings,
+ context
+ );
+ checkProviderAndEndpointTypeForTask(
+ TaskType.TEXT_EMBEDDING,
+ embeddingsModel.getServiceSettings().provider(),
+ embeddingsModel.getServiceSettings().endpointType()
+ );
+ return embeddingsModel;
+ }
+
+ if (taskType == TaskType.COMPLETION) {
+ var completionModel = new AzureAiStudioChatCompletionModel(
+ inferenceEntityId,
+ taskType,
+ NAME,
+ serviceSettings,
+ taskSettings,
+ secretSettings,
+ context
+ );
+ checkProviderAndEndpointTypeForTask(
+ TaskType.COMPLETION,
+ completionModel.getServiceSettings().provider(),
+ completionModel.getServiceSettings().endpointType()
+ );
+ return completionModel;
+ }
+
+ throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+ }
+
+ private AzureAiStudioModel createModelFromPersistent(
+ String inferenceEntityId,
+ TaskType taskType,
+ Map serviceSettings,
+ Map taskSettings,
+ Map secretSettings,
+ String failureMessage
+ ) {
+ return createModel(
+ inferenceEntityId,
+ taskType,
+ serviceSettings,
+ taskSettings,
+ secretSettings,
+ failureMessage,
+ ConfigurationParseContext.PERSISTENT
+ );
+ }
+
+ @Override
+ public void checkModelConfig(Model model, ActionListener listener) {
+ if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) {
+ ServiceUtils.getEmbeddingSize(
+ model,
+ this,
+ listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size)))
+ );
+ } else if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
+ listener.onResponse(updateChatCompletionModelConfig(chatCompletionModel));
+ } else {
+ listener.onResponse(model);
+ }
+ }
+
+ private AzureAiStudioEmbeddingsModel updateEmbeddingModelConfig(AzureAiStudioEmbeddingsModel embeddingsModel, int embeddingsSize) {
+ if (embeddingsModel.getServiceSettings().dimensionsSetByUser()
+ && embeddingsModel.getServiceSettings().dimensions() != null
+ && embeddingsModel.getServiceSettings().dimensions() != embeddingsSize) {
+ throw new ElasticsearchStatusException(
+ Strings.format(
+ "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+ + "Please recreate the [%s] configuration with the correct dimensions",
+ embeddingsSize,
+ embeddingsModel.getServiceSettings().dimensions(),
+ embeddingsModel.getConfigurations().getInferenceEntityId()
+ ),
+ RestStatus.BAD_REQUEST
+ );
+ }
+
+ var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
+ var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+ AzureAiStudioEmbeddingsServiceSettings serviceSettings = new AzureAiStudioEmbeddingsServiceSettings(
+ embeddingsModel.getServiceSettings().target(),
+ embeddingsModel.getServiceSettings().provider(),
+ embeddingsModel.getServiceSettings().endpointType(),
+ embeddingsSize,
+ embeddingsModel.getServiceSettings().dimensionsSetByUser(),
+ embeddingsModel.getServiceSettings().maxInputTokens(),
+ similarityToUse,
+ embeddingsModel.getServiceSettings().rateLimitSettings()
+ );
+
+ return new AzureAiStudioEmbeddingsModel(embeddingsModel, serviceSettings);
+ }
+
+ private AzureAiStudioChatCompletionModel updateChatCompletionModelConfig(AzureAiStudioChatCompletionModel chatCompletionModel) {
+ var modelMaxNewTokens = chatCompletionModel.getTaskSettings().maxNewTokens();
+ var maxNewTokensToUse = modelMaxNewTokens == null ? DEFAULT_MAX_NEW_TOKENS : modelMaxNewTokens;
+ var updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
+ chatCompletionModel.getTaskSettings().temperature(),
+ chatCompletionModel.getTaskSettings().topP(),
+ chatCompletionModel.getTaskSettings().doSample(),
+ maxNewTokensToUse
+ );
+ return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
+ }
+
+ private static void checkProviderAndEndpointTypeForTask(
+ TaskType taskType,
+ AzureAiStudioProvider provider,
+ AzureAiStudioEndpointType endpointType
+ ) {
+ if (providerAllowsTaskType(provider, taskType) == false) {
+ throw new ElasticsearchStatusException(
+ Strings.format("The [%s] task type for provider [%s] is not available", taskType, provider),
+ RestStatus.BAD_REQUEST
+ );
+ }
+
+ if (providerAllowsEndpointTypeForTask(provider, taskType, endpointType) == false) {
+ throw new ElasticsearchStatusException(
+ Strings.format(
+ "The [%s] endpoint type with [%s] task type for provider [%s] is not available",
+ endpointType,
+ taskType,
+ provider
+ ),
+ RestStatus.BAD_REQUEST
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java
new file mode 100644
index 0000000000000..10c57e19b6403
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceSettings.java
@@ -0,0 +1,130 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.EnumSet;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD;
+
+public abstract class AzureAiStudioServiceSettings extends FilteredXContentObject implements ServiceSettings {
+
+ protected final String target;
+ protected final AzureAiStudioProvider provider;
+ protected final AzureAiStudioEndpointType endpointType;
+ protected final RateLimitSettings rateLimitSettings;
+
+ protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240);
+
+ protected static BaseAzureAiStudioCommonFields fromMap(
+ Map map,
+ ValidationException validationException,
+ ConfigurationParseContext context
+ ) {
+ String target = extractRequiredString(map, TARGET_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+ AzureAiStudioEndpointType endpointType = extractRequiredEnum(
+ map,
+ ENDPOINT_TYPE_FIELD,
+ ModelConfigurations.SERVICE_SETTINGS,
+ AzureAiStudioEndpointType::fromString,
+ EnumSet.allOf(AzureAiStudioEndpointType.class),
+ validationException
+ );
+
+ AzureAiStudioProvider provider = extractRequiredEnum(
+ map,
+ PROVIDER_FIELD,
+ ModelConfigurations.SERVICE_SETTINGS,
+ AzureAiStudioProvider::fromString,
+ EnumSet.allOf(AzureAiStudioProvider.class),
+ validationException
+ );
+
+ return new BaseAzureAiStudioCommonFields(target, provider, endpointType, rateLimitSettings);
+ }
+
+ protected AzureAiStudioServiceSettings(StreamInput in) throws IOException {
+ this.target = in.readString();
+ this.provider = in.readEnum(AzureAiStudioProvider.class);
+ this.endpointType = in.readEnum(AzureAiStudioEndpointType.class);
+ this.rateLimitSettings = new RateLimitSettings(in);
+ }
+
+ protected AzureAiStudioServiceSettings(
+ String target,
+ AzureAiStudioProvider provider,
+ AzureAiStudioEndpointType endpointType,
+ @Nullable RateLimitSettings rateLimitSettings
+ ) {
+ this.target = target;
+ this.provider = provider;
+ this.endpointType = endpointType;
+ this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+ }
+
+ protected record BaseAzureAiStudioCommonFields(
+ String target,
+ AzureAiStudioProvider provider,
+ AzureAiStudioEndpointType endpointType,
+ RateLimitSettings rateLimitSettings
+ ) {}
+
+ public String target() {
+ return this.target;
+ }
+
+ public AzureAiStudioProvider provider() {
+ return this.provider;
+ }
+
+ public AzureAiStudioEndpointType endpointType() {
+ return this.endpointType;
+ }
+
+ public RateLimitSettings rateLimitSettings() {
+ return this.rateLimitSettings;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeString(target);
+ out.writeEnum(provider);
+ out.writeEnum(endpointType);
+ rateLimitSettings.writeTo(out);
+ }
+
+ protected void addXContentFields(XContentBuilder builder, Params params) throws IOException {
+ this.addExposedXContentFields(builder, params);
+ rateLimitSettings.toXContent(builder, params);
+ }
+
+ protected void addExposedXContentFields(XContentBuilder builder, Params params) throws IOException {
+ builder.field(TARGET_FIELD, this.target);
+ builder.field(PROVIDER_FIELD, this.provider);
+ builder.field(ENDPOINT_TYPE_FIELD, this.endpointType);
+ }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionModel.java
new file mode 100644
index 0000000000000..5afb3aaed61ff
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionModel.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.completion;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.COMPLETIONS_URI_PATH;
+
+/**
+ * An Azure AI Studio model configured for the chat completion task type.
+ */
+public class AzureAiStudioChatCompletionModel extends AzureAiStudioModel {
+
+    /**
+     * Returns {@code model} with any request-time task settings overrides applied; returns
+     * the model unchanged when no overrides were supplied.
+     */
+    public static AzureAiStudioChatCompletionModel of(AzureAiStudioModel model, Map<String, Object> taskSettings) {
+        var modelAsCompletionModel = (AzureAiStudioChatCompletionModel) model;
+
+        if (taskSettings == null || taskSettings.isEmpty()) {
+            return modelAsCompletionModel;
+        }
+
+        var requestTaskSettings = AzureAiStudioChatCompletionRequestTaskSettings.fromMap(taskSettings);
+        var taskSettingToUse = AzureAiStudioChatCompletionTaskSettings.of(modelAsCompletionModel.getTaskSettings(), requestTaskSettings);
+
+        return new AzureAiStudioChatCompletionModel(modelAsCompletionModel, taskSettingToUse);
+    }
+
+    public AzureAiStudioChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        AzureAiStudioChatCompletionServiceSettings serviceSettings,
+        AzureAiStudioChatCompletionTaskSettings taskSettings,
+        DefaultSecretSettings secrets
+    ) {
+        super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
+    }
+
+    public AzureAiStudioChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            AzureAiStudioChatCompletionServiceSettings.fromMap(serviceSettings, context),
+            AzureAiStudioChatCompletionTaskSettings.fromMap(taskSettings),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    public AzureAiStudioChatCompletionModel(AzureAiStudioChatCompletionModel model, AzureAiStudioChatCompletionTaskSettings taskSettings) {
+        super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
+    }
+
+    @Override
+    public AzureAiStudioChatCompletionServiceSettings getServiceSettings() {
+        return (AzureAiStudioChatCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AzureAiStudioChatCompletionTaskSettings getTaskSettings() {
+        return (AzureAiStudioChatCompletionTaskSettings) super.getTaskSettings();
+    }
+
+    // note: getSecretSettings() is inherited from AzureAiStudioModel unchanged; a
+    // pass-through override here would be redundant
+
+    @Override
+    protected URI getEndpointUri() throws URISyntaxException {
+        // OpenAI and realtime endpoints are called at the raw target URL; every other
+        // token endpoint appends the chat completions path
+        if (this.provider == AzureAiStudioProvider.OPENAI || this.endpointType == AzureAiStudioEndpointType.REALTIME) {
+            return new URI(this.target);
+        }
+
+        return new URI(this.target + COMPLETIONS_URI_PATH);
+    }
+
+    @Override
+    public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map<String, Object> taskSettings) {
+        return creator.create(this, taskSettings);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionRequestTaskSettings.java
new file mode 100644
index 0000000000000..2eef059e3fae1
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionRequestTaskSettings.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.completion;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DO_SAMPLE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MAX_NEW_TOKENS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TEMPERATURE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_P_FIELD;
+
+/**
+ * Task settings overrides supplied on an individual chat completion request. All fields are
+ * optional; {@code null} means "use the value from the persisted task settings".
+ */
+public record AzureAiStudioChatCompletionRequestTaskSettings(
+    @Nullable Double temperature,
+    @Nullable Double topP,
+    @Nullable Boolean doSample,
+    @Nullable Integer maxNewTokens
+) {
+
+    public static final AzureAiStudioChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioChatCompletionRequestTaskSettings(
+        null,
+        null,
+        null,
+        null
+    );
+
+    /**
+     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
+     * does not throw an error.
+     *
+     * @param map the settings received from a request
+     * @return a {@link AzureAiStudioChatCompletionRequestTaskSettings}
+     * @throws ValidationException if any supplied value is out of range or of the wrong type
+     */
+    public static AzureAiStudioChatCompletionRequestTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return AzureAiStudioChatCompletionRequestTaskSettings.EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        var temperature = extractOptionalDoubleInRange(
+            map,
+            TEMPERATURE_FIELD,
+            AzureAiStudioConstants.MIN_TEMPERATURE_TOP_P,
+            AzureAiStudioConstants.MAX_TEMPERATURE_TOP_P,
+            ModelConfigurations.TASK_SETTINGS,
+            validationException
+        );
+        var topP = extractOptionalDoubleInRange(
+            map,
+            TOP_P_FIELD,
+            AzureAiStudioConstants.MIN_TEMPERATURE_TOP_P,
+            AzureAiStudioConstants.MAX_TEMPERATURE_TOP_P,
+            ModelConfigurations.TASK_SETTINGS,
+            validationException
+        );
+        Boolean doSample = extractOptionalBoolean(map, DO_SAMPLE_FIELD, validationException);
+        Integer maxNewTokens = extractOptionalPositiveInteger(
+            map,
+            MAX_NEW_TOKENS_FIELD,
+            ModelConfigurations.TASK_SETTINGS,
+            validationException
+        );
+
+        // Report all problems at once rather than failing on the first
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureAiStudioChatCompletionRequestTaskSettings(temperature, topP, doSample, maxNewTokens);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettings.java
new file mode 100644
index 0000000000000..2f8422be5ed90
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionServiceSettings.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+public class AzureAiStudioChatCompletionServiceSettings extends AzureAiStudioServiceSettings {
+ public static final String NAME = "azure_ai_studio_chat_completion_service_settings";
+
+ public static AzureAiStudioChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) {
+ ValidationException validationException = new ValidationException();
+
+ var settings = completionSettingsFromMap(map, validationException, context);
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AzureAiStudioChatCompletionServiceSettings(settings);
+ }
+
+ private static AzureAiStudioCompletionCommonFields completionSettingsFromMap(
+ Map map,
+ ValidationException validationException,
+ ConfigurationParseContext context
+ ) {
+ var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
+ return new AzureAiStudioCompletionCommonFields(baseSettings);
+ }
+
+ private record AzureAiStudioCompletionCommonFields(BaseAzureAiStudioCommonFields baseCommonFields) {}
+
+ public AzureAiStudioChatCompletionServiceSettings(
+ String target,
+ AzureAiStudioProvider provider,
+ AzureAiStudioEndpointType endpointType,
+ @Nullable RateLimitSettings rateLimitSettings
+ ) {
+ super(target, provider, endpointType, rateLimitSettings);
+ }
+
+ public AzureAiStudioChatCompletionServiceSettings(StreamInput in) throws IOException {
+ super(in);
+ }
+
+ private AzureAiStudioChatCompletionServiceSettings(AzureAiStudioCompletionCommonFields fields) {
+ this(
+ fields.baseCommonFields.target(),
+ fields.baseCommonFields.provider(),
+ fields.baseCommonFields.endpointType(),
+ fields.baseCommonFields.rateLimitSettings()
+ );
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+
+ super.addXContentFields(builder, params);
+
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
+ super.addExposedXContentFields(builder, params);
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AzureAiStudioChatCompletionServiceSettings that = (AzureAiStudioChatCompletionServiceSettings) o;
+
+ return Objects.equals(target, that.target)
+ && Objects.equals(provider, that.provider)
+ && Objects.equals(endpointType, that.endpointType)
+ && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(target, provider, endpointType, rateLimitSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java
new file mode 100644
index 0000000000000..fc11d96269b68
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/completion/AzureAiStudioChatCompletionTaskSettings.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DO_SAMPLE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.MAX_NEW_TOKENS_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TEMPERATURE_FIELD;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TOP_P_FIELD;
+
+public class AzureAiStudioChatCompletionTaskSettings implements TaskSettings {
+ public static final String NAME = "azure_ai_studio_chat_completion_task_settings";
+ public static final Integer DEFAULT_MAX_NEW_TOKENS = 64;
+
+ public static AzureAiStudioChatCompletionTaskSettings fromMap(Map map) {
+ ValidationException validationException = new ValidationException();
+
+ var temperature = extractOptionalDoubleInRange(
+ map,
+ TEMPERATURE_FIELD,
+ AzureAiStudioConstants.MIN_TEMPERATURE_TOP_P,
+ AzureAiStudioConstants.MAX_TEMPERATURE_TOP_P,
+ ModelConfigurations.TASK_SETTINGS,
+ validationException
+ );
+ var topP = extractOptionalDoubleInRange(
+ map,
+ TOP_P_FIELD,
+ AzureAiStudioConstants.MIN_TEMPERATURE_TOP_P,
+ AzureAiStudioConstants.MAX_TEMPERATURE_TOP_P,
+ ModelConfigurations.TASK_SETTINGS,
+ validationException
+ );
+ var doSample = extractOptionalBoolean(map, DO_SAMPLE_FIELD, validationException);
+ var maxNewTokens = extractOptionalPositiveInteger(
+ map,
+ MAX_NEW_TOKENS_FIELD,
+ ModelConfigurations.TASK_SETTINGS,
+ validationException
+ );
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AzureAiStudioChatCompletionTaskSettings(temperature, topP, doSample, maxNewTokens);
+ }
+
+ /**
+ * Creates a new {@link AzureAiStudioChatCompletionTaskSettings} object by overriding the values in originalSettings with the ones
+ * passed in via requestSettings if the fields are not null.
+ * @param originalSettings the original {@link AzureAiStudioChatCompletionTaskSettings} from the inference entity configuration from storage
+ * @param requestSettings the {@link AzureAiStudioChatCompletionRequestTaskSettings} from the request
+ * @return a new {@link AzureAiStudioChatCompletionTaskSettings}
+ */
+ public static AzureAiStudioChatCompletionTaskSettings of(
+ AzureAiStudioChatCompletionTaskSettings originalSettings,
+ AzureAiStudioChatCompletionRequestTaskSettings requestSettings
+ ) {
+
+ var temperature = requestSettings.temperature() == null ? originalSettings.temperature() : requestSettings.temperature();
+ var topP = requestSettings.topP() == null ? originalSettings.topP() : requestSettings.topP();
+ var doSample = requestSettings.doSample() == null ? originalSettings.doSample() : requestSettings.doSample();
+ var maxNewTokens = requestSettings.maxNewTokens() == null ? originalSettings.maxNewTokens() : requestSettings.maxNewTokens();
+
+ return new AzureAiStudioChatCompletionTaskSettings(temperature, topP, doSample, maxNewTokens);
+ }
+
+ public AzureAiStudioChatCompletionTaskSettings(
+ @Nullable Double temperature,
+ @Nullable Double topP,
+ @Nullable Boolean doSample,
+ @Nullable Integer maxNewTokens
+ ) {
+
+ this.temperature = temperature;
+ this.topP = topP;
+ this.doSample = doSample;
+ this.maxNewTokens = maxNewTokens;
+ }
+
+ public AzureAiStudioChatCompletionTaskSettings(StreamInput in) throws IOException {
+ this.temperature = in.readOptionalDouble();
+ this.topP = in.readOptionalDouble();
+ this.doSample = in.readOptionalBoolean();
+ this.maxNewTokens = in.readOptionalInt();
+ }
+
+ private final Double temperature;
+ private final Double topP;
+ private final Boolean doSample;
+ private final Integer maxNewTokens;
+
+ public Double temperature() {
+ return temperature;
+ }
+
+ public Double topP() {
+ return topP;
+ }
+
+ public Boolean doSample() {
+ return doSample;
+ }
+
+ public Integer maxNewTokens() {
+ return maxNewTokens;
+ }
+
+ public boolean areAnyParametersAvailable() {
+ return temperature != null || topP != null || doSample != null || maxNewTokens != null; // "any": true when at least one optional setting is present
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO; // introduced with Azure AI Studio support; matches the service settings' version
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalDouble(temperature);
+ out.writeOptionalDouble(topP);
+ out.writeOptionalBoolean(doSample);
+ out.writeOptionalInt(maxNewTokens);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+
+ if (temperature != null) {
+ builder.field(TEMPERATURE_FIELD, temperature);
+ }
+ if (topP != null) {
+ builder.field(TOP_P_FIELD, topP);
+ }
+ if (doSample != null) {
+ builder.field(DO_SAMPLE_FIELD, doSample);
+ }
+ if (maxNewTokens != null) {
+ builder.field(MAX_NEW_TOKENS_FIELD, maxNewTokens);
+ }
+
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AzureAiStudioChatCompletionTaskSettings that = (AzureAiStudioChatCompletionTaskSettings) o;
+ return Objects.equals(temperature, that.temperature)
+ && Objects.equals(topP, that.topP)
+ && Objects.equals(doSample, that.doSample)
+ && Objects.equals(maxNewTokens, that.maxNewTokens);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(temperature, topP, doSample, maxNewTokens);
+ }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java
new file mode 100644
index 0000000000000..a999b9f0312e6
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsModel.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionVisitor;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.EMBEDDINGS_URI_PATH;
+
+public class AzureAiStudioEmbeddingsModel extends AzureAiStudioModel {
+
+ public static AzureAiStudioEmbeddingsModel of(AzureAiStudioEmbeddingsModel model, Map taskSettings) {
+ if (taskSettings == null || taskSettings.isEmpty()) {
+ return model;
+ }
+
+ var requestTaskSettings = AzureAiStudioEmbeddingsRequestTaskSettings.fromMap(taskSettings);
+ var taskSettingToUse = AzureAiStudioEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings);
+
+ return new AzureAiStudioEmbeddingsModel(model, taskSettingToUse);
+ }
+
+ public AzureAiStudioEmbeddingsModel(
+ String inferenceEntityId,
+ TaskType taskType,
+ String service,
+ AzureAiStudioEmbeddingsServiceSettings serviceSettings,
+ AzureAiStudioEmbeddingsTaskSettings taskSettings,
+ DefaultSecretSettings secrets
+ ) {
+ super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
+ }
+
+ public AzureAiStudioEmbeddingsModel(
+ String inferenceEntityId,
+ TaskType taskType,
+ String service,
+ Map serviceSettings,
+ Map taskSettings,
+ @Nullable Map secrets,
+ ConfigurationParseContext context
+ ) {
+ this(
+ inferenceEntityId,
+ taskType,
+ service,
+ AzureAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
+ AzureAiStudioEmbeddingsTaskSettings.fromMap(taskSettings),
+ DefaultSecretSettings.fromMap(secrets)
+ );
+ }
+
+ private AzureAiStudioEmbeddingsModel(AzureAiStudioEmbeddingsModel model, AzureAiStudioEmbeddingsTaskSettings taskSettings) {
+ super(model, taskSettings, model.getServiceSettings().rateLimitSettings());
+ }
+
+ public AzureAiStudioEmbeddingsModel(AzureAiStudioEmbeddingsModel model, AzureAiStudioEmbeddingsServiceSettings serviceSettings) {
+ super(model, serviceSettings);
+ }
+
+ @Override
+ public AzureAiStudioEmbeddingsServiceSettings getServiceSettings() {
+ return (AzureAiStudioEmbeddingsServiceSettings) super.getServiceSettings();
+ }
+
+ @Override
+ public AzureAiStudioEmbeddingsTaskSettings getTaskSettings() {
+ return (AzureAiStudioEmbeddingsTaskSettings) super.getTaskSettings();
+ }
+
+ @Override
+ protected URI getEndpointUri() throws URISyntaxException {
+ if (this.provider == AzureAiStudioProvider.OPENAI || this.endpointType == AzureAiStudioEndpointType.REALTIME) {
+ return new URI(this.target);
+ }
+
+ return new URI(this.target + EMBEDDINGS_URI_PATH);
+ }
+
+ @Override
+ public ExecutableAction accept(AzureAiStudioActionVisitor creator, Map taskSettings) {
+ return creator.create(this, taskSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java
new file mode 100644
index 0000000000000..8c9fd22a7cdf7
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettings;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD;
+
+/**
+ * This class handles extracting Azure OpenAI task settings from a request. The difference between this class and
+ * {@link AzureAiStudioEmbeddingsTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field
+ * is missing. This allows overriding persistent task settings.
+ * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse, if using an OpenAI model
+ */
+public record AzureAiStudioEmbeddingsRequestTaskSettings(@Nullable String user) {
+ public static final AzureAiStudioEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new AzureAiStudioEmbeddingsRequestTaskSettings(null);
+
+ /**
+ * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
+ * does not throw an error.
+ *
+ * @param map the settings received from a request
+ * @return a {@link AzureAiStudioEmbeddingsRequestTaskSettings}
+ */
+ public static AzureAiStudioEmbeddingsRequestTaskSettings fromMap(Map map) {
+ if (map.isEmpty()) {
+ return AzureAiStudioEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
+ }
+
+ ValidationException validationException = new ValidationException();
+
+ String user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AzureAiStudioEmbeddingsRequestTaskSettings(user);
+ }
+
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java
new file mode 100644
index 0000000000000..1a39cd67a70f3
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java
@@ -0,0 +1,231 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
+
+public class AzureAiStudioEmbeddingsServiceSettings extends AzureAiStudioServiceSettings {
+
+ public static final String NAME = "azure_ai_studio_embeddings_service_settings";
+ static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
+
+ public static AzureAiStudioEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) {
+ ValidationException validationException = new ValidationException();
+
+ var settings = embeddingSettingsFromMap(map, validationException, context);
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AzureAiStudioEmbeddingsServiceSettings(settings);
+ }
+
+ private static AzureAiStudioEmbeddingCommonFields embeddingSettingsFromMap(
+ Map map,
+ ValidationException validationException,
+ ConfigurationParseContext context
+ ) {
+ var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context);
+ SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+ Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
+
+ Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException);
+
+ switch (context) {
+ case REQUEST -> {
+ if (dimensionsSetByUser != null) {
+ validationException.addValidationError(
+ ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
+ );
+ }
+ dimensionsSetByUser = dims != null;
+ }
+ case PERSISTENT -> {
+ if (dimensionsSetByUser == null) {
+ validationException.addValidationError(
+ ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
+ );
+ }
+ }
+ }
+ return new AzureAiStudioEmbeddingCommonFields(baseSettings, dims, dimensionsSetByUser, maxTokens, similarity);
+ }
+
+ private record AzureAiStudioEmbeddingCommonFields(
+ BaseAzureAiStudioCommonFields baseCommonFields,
+ @Nullable Integer dimensions,
+ Boolean dimensionsSetByUser,
+ @Nullable Integer maxInputTokens,
+ SimilarityMeasure similarity
+ ) {}
+
+ public AzureAiStudioEmbeddingsServiceSettings(
+ String target,
+ AzureAiStudioProvider provider,
+ AzureAiStudioEndpointType endpointType,
+ @Nullable Integer dimensions,
+ Boolean dimensionsSetByUser,
+ @Nullable Integer maxInputTokens,
+ @Nullable SimilarityMeasure similarity,
+ RateLimitSettings rateLimitSettings
+ ) {
+ super(target, provider, endpointType, rateLimitSettings);
+ this.dimensions = dimensions;
+ this.dimensionsSetByUser = dimensionsSetByUser;
+ this.maxInputTokens = maxInputTokens;
+ this.similarity = similarity;
+ }
+
+ public AzureAiStudioEmbeddingsServiceSettings(StreamInput in) throws IOException {
+ super(in);
+ this.dimensions = in.readOptionalVInt();
+ this.dimensionsSetByUser = in.readBoolean();
+ this.maxInputTokens = in.readOptionalVInt();
+ this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
+ }
+
+ private AzureAiStudioEmbeddingsServiceSettings(AzureAiStudioEmbeddingCommonFields fields) {
+ this(
+ fields.baseCommonFields.target(),
+ fields.baseCommonFields.provider(),
+ fields.baseCommonFields.endpointType(),
+ fields.dimensions(),
+ fields.dimensionsSetByUser(),
+ fields.maxInputTokens(),
+ fields.similarity(),
+ fields.baseCommonFields.rateLimitSettings()
+ );
+ }
+
+ private final Integer dimensions;
+ private final Boolean dimensionsSetByUser;
+ private final Integer maxInputTokens;
+ private final SimilarityMeasure similarity;
+
+ @Override
+ public SimilarityMeasure similarity() {
+ return similarity;
+ }
+
+ public boolean dimensionsSetByUser() {
+ return this.dimensionsSetByUser;
+ }
+
+ public Integer dimensions() {
+ return dimensions;
+ }
+
+ public Integer maxInputTokens() {
+ return maxInputTokens;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO;
+ }
+
+ @Override
+ public DenseVectorFieldMapper.ElementType elementType() {
+ return DenseVectorFieldMapper.ElementType.FLOAT;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeOptionalVInt(dimensions);
+ out.writeBoolean(dimensionsSetByUser);
+ out.writeOptionalVInt(maxInputTokens);
+ out.writeOptionalEnum(similarity);
+ }
+
+ private void addXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+ if (dimensions != null) {
+ builder.field(DIMENSIONS, dimensions);
+ }
+ if (maxInputTokens != null) {
+ builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+ }
+ if (similarity != null) {
+ builder.field(SIMILARITY, similarity);
+ }
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+
+ super.addXContentFields(builder, params);
+ addXContentFragmentOfExposedFields(builder, params);
+ builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
+
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
+ super.addExposedXContentFields(builder, params);
+ addXContentFragmentOfExposedFields(builder, params);
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AzureAiStudioEmbeddingsServiceSettings that = (AzureAiStudioEmbeddingsServiceSettings) o;
+
+ return Objects.equals(target, that.target)
+ && Objects.equals(provider, that.provider)
+ && Objects.equals(endpointType, that.endpointType)
+ && Objects.equals(dimensions, that.dimensions)
+ && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
+ && Objects.equals(maxInputTokens, that.maxInputTokens)
+ && Objects.equals(similarity, that.similarity)
+ && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(target, provider, endpointType, dimensions, dimensionsSetByUser, maxInputTokens, similarity, rateLimitSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java
new file mode 100644
index 0000000000000..dc001993b366f
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsTaskSettings.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD;
+
+public class AzureAiStudioEmbeddingsTaskSettings implements TaskSettings {
+ public static final String NAME = "azure_ai_studio_embeddings_task_settings";
+
+ public static AzureAiStudioEmbeddingsTaskSettings fromMap(Map map) {
+ ValidationException validationException = new ValidationException();
+
+ String user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, validationException);
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new AzureAiStudioEmbeddingsTaskSettings(user);
+ }
+
+ /**
+ * Creates a new {@link AzureAiStudioEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones
+ * passed in via requestSettings if the fields are not null.
+ *
+ * @param originalSettings the original {@link AzureAiStudioEmbeddingsTaskSettings} from the inference entity configuration from storage
+ * @param requestSettings the {@link AzureAiStudioEmbeddingsRequestTaskSettings} from the request
+ * @return a new {@link AzureAiStudioEmbeddingsTaskSettings}
+ */
+ public static AzureAiStudioEmbeddingsTaskSettings of(
+ AzureAiStudioEmbeddingsTaskSettings originalSettings,
+ AzureAiStudioEmbeddingsRequestTaskSettings requestSettings
+ ) {
+ var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
+ return new AzureAiStudioEmbeddingsTaskSettings(userToUse);
+ }
+
+ public AzureAiStudioEmbeddingsTaskSettings(@Nullable String user) {
+ this.user = user;
+ }
+
+ public AzureAiStudioEmbeddingsTaskSettings(StreamInput in) throws IOException {
+ this.user = in.readOptionalString();
+ }
+
+ private final String user;
+
+ public String user() {
+ return this.user;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_AZURE_AI_STUDIO; // introduced with Azure AI Studio support; matches AzureAiStudioEmbeddingsServiceSettings
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeOptionalString(this.user);
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+ if (user != null) {
+ builder.field(USER_FIELD, user);
+ }
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ AzureAiStudioEmbeddingsTaskSettings that = (AzureAiStudioEmbeddingsTaskSettings) o;
+ return Objects.equals(user, that.user);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(user);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
index deb1cfb901602..11dbf673ab7bd 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
@@ -32,6 +32,7 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
@@ -130,6 +131,7 @@ private static CohereModel createModel(
context
);
case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
+ case COMPLETION -> new CohereCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java
new file mode 100644
index 0000000000000..761081d4d723c
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.completion;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionVisitor;
+import org.elasticsearch.xpack.inference.services.cohere.CohereModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.util.Map;
+
+public class CohereCompletionModel extends CohereModel {
+
+ public CohereCompletionModel(
+ String modelId,
+ TaskType taskType,
+ String service,
+ Map<String, Object> serviceSettings,
+ Map<String, Object> taskSettings,
+ @Nullable Map<String, Object> secrets
+ ) {
+ this(
+ modelId,
+ taskType,
+ service,
+ CohereCompletionServiceSettings.fromMap(serviceSettings),
+ EmptyTaskSettings.INSTANCE,
+ DefaultSecretSettings.fromMap(secrets)
+ );
+ }
+
+ // should only be used for testing
+ CohereCompletionModel(
+ String modelId,
+ TaskType taskType,
+ String service,
+ CohereCompletionServiceSettings serviceSettings,
+ TaskSettings taskSettings,
+ @Nullable DefaultSecretSettings secretSettings
+ ) {
+ super(
+ new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings),
+ new ModelSecrets(secretSettings),
+ secretSettings,
+ serviceSettings
+ );
+ }
+
+ @Override
+ public CohereCompletionServiceSettings getServiceSettings() {
+ return (CohereCompletionServiceSettings) super.getServiceSettings();
+ }
+
+ @Override
+ public TaskSettings getTaskSettings() {
+ return super.getTaskSettings();
+ }
+
+ @Override
+ public DefaultSecretSettings getSecretSettings() {
+ return (DefaultSecretSettings) super.getSecretSettings();
+ }
+
+ @Override
+ public ExecutableAction accept(CohereActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
+ return visitor.create(this, taskSettings);
+ }
+
+ @Override
+ public URI uri() {
+ return getServiceSettings().uri();
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java
new file mode 100644
index 0000000000000..2a22f6333f1a2
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.cohere.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+
+public class CohereCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings {
+
+ public static final String NAME = "cohere_completion_service_settings";
+
+ // Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications
+ // 10K requests per minute
+ private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
+
+ public static CohereCompletionServiceSettings fromMap(Map<String, Object> map) {
+ ValidationException validationException = new ValidationException();
+
+ String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+ RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+ String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+ if (validationException.validationErrors().isEmpty() == false) {
+ throw validationException;
+ }
+
+ return new CohereCompletionServiceSettings(uri, modelId, rateLimitSettings);
+ }
+
+ private final URI uri;
+
+ private final String modelId;
+
+ private final RateLimitSettings rateLimitSettings;
+
+ public CohereCompletionServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) {
+ this.uri = uri;
+ this.modelId = modelId;
+ this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+ }
+
+ public CohereCompletionServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) {
+ this(createOptionalUri(url), modelId, rateLimitSettings);
+ }
+
+ public CohereCompletionServiceSettings(StreamInput in) throws IOException {
+ uri = createOptionalUri(in.readOptionalString());
+ modelId = in.readOptionalString();
+ rateLimitSettings = new RateLimitSettings(in);
+ }
+
+ @Override
+ public RateLimitSettings rateLimitSettings() {
+ return rateLimitSettings;
+ }
+
+ public URI uri() {
+ return uri;
+ }
+
+ public String modelId() {
+ return modelId;
+ }
+
+ @Override
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+ builder.startObject();
+
+ toXContentFragmentOfExposedFields(builder, params);
+ rateLimitSettings.toXContent(builder, params);
+
+ builder.endObject();
+ return builder;
+ }
+
+ @Override
+ public String getWriteableName() {
+ return NAME;
+ }
+
+ @Override
+ public TransportVersion getMinimalSupportedVersion() {
+ return TransportVersions.ML_INFERENCE_COHERE_COMPLETION_ADDED;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ var uriToWrite = uri != null ? uri.toString() : null;
+ out.writeOptionalString(uriToWrite);
+ out.writeOptionalString(modelId);
+ rateLimitSettings.writeTo(out);
+ }
+
+ @Override
+ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+ if (uri != null) {
+ builder.field(URL, uri.toString());
+ }
+
+ if (modelId != null) {
+ builder.field(MODEL_ID, modelId);
+ }
+
+ return builder;
+ }
+
+ @Override
+ public boolean equals(Object object) {
+ if (this == object) return true;
+ if (object == null || getClass() != object.getClass()) return false;
+ CohereCompletionServiceSettings that = (CohereCompletionServiceSettings) object;
+ return Objects.equals(uri, that.uri)
+ && Objects.equals(modelId, that.modelId)
+ && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(uri, modelId, rateLimitSettings);
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java
index 854722d989340..ee7db662b4997 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java
@@ -41,7 +41,7 @@ protected static void validateParameters(Integer numAllocations, ValidationExcep
);
} else if (numAllocations < 1) {
validationException.addValidationError(
- ServiceUtils.mustBeAPositiveNumberErrorMessage(NUM_ALLOCATIONS, ModelConfigurations.SERVICE_SETTINGS, numAllocations)
+ ServiceUtils.mustBeAPositiveIntegerErrorMessage(NUM_ALLOCATIONS, ModelConfigurations.SERVICE_SETTINGS, numAllocations)
);
}
@@ -49,7 +49,7 @@ protected static void validateParameters(Integer numAllocations, ValidationExcep
validationException.addValidationError(ServiceUtils.missingSettingErrorMsg(NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS));
} else if (numThreads < 1) {
validationException.addValidationError(
- ServiceUtils.mustBeAPositiveNumberErrorMessage(NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, numThreads)
+ ServiceUtils.mustBeAPositiveIntegerErrorMessage(NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, numThreads)
);
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java
index 985168c7ccfd1..cfc375a525dd6 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java
@@ -19,7 +19,7 @@
import java.util.Objects;
import java.util.concurrent.TimeUnit;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
public class RateLimitSettings implements Writeable, ToXContentFragment {
@@ -32,7 +32,7 @@ public class RateLimitSettings implements Writeable, ToXContentFragment {
public static RateLimitSettings of(Map<String, Object> map, RateLimitSettings defaultValue, ValidationException validationException) {
Map<String, Object> settings = removeFromMapOrDefaultEmpty(map, FIELD_NAME);
- var requestsPerMinute = extractOptionalPositiveInteger(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);
+ var requestsPerMinute = extractOptionalPositiveLong(settings, REQUESTS_PER_MINUTE_FIELD, FIELD_NAME, validationException);
return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute);
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java
new file mode 100644
index 0000000000000..15d082f455130
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java
@@ -0,0 +1,229 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureaistudio;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockRequest;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public void testEmbeddingsRequestAction() throws IOException {
+ var senderFactory = new HttpRequestSender.Factory(
+ ServiceComponentsTests.createWithEmptySettings(threadPool),
+ clientManager,
+ mockClusterServiceEmpty()
+ );
+
+ var timeoutSettings = buildSettingsWithRetryFields(
+ TimeValue.timeValueMillis(1),
+ TimeValue.timeValueMinutes(1),
+ TimeValue.timeValueSeconds(0)
+ );
+
+ var serviceComponents = new ServiceComponents(
+ threadPool,
+ mock(ThrottlerManager.class),
+ timeoutSettings,
+ TruncatorTests.createTruncator()
+ );
+
+ try (var sender = senderFactory.createSender("test_service")) {
+ sender.start();
+
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingsTokenResponseJson));
+
+ var model = AzureAiStudioEmbeddingsModelTests.createModel(
+ "id",
+ "http://will-be-replaced.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey"
+ );
+ model.setURI(getUrl(webServer));
+
+ var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
+ var action = creator.create(model, Map.of());
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+ assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ assertThat(requestMap.size(), is(1));
+ assertThat(requestMap.get("input"), is(List.of("abc")));
+ }
+ }
+
+ public void testChatCompletionRequestAction() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ var timeoutSettings = buildSettingsWithRetryFields(
+ TimeValue.timeValueMillis(1),
+ TimeValue.timeValueMinutes(1),
+ TimeValue.timeValueSeconds(0)
+ );
+
+ var serviceComponents = new ServiceComponents(
+ threadPool,
+ mock(ThrottlerManager.class),
+ timeoutSettings,
+ TruncatorTests.createTruncator()
+ );
+
+ try (var sender = senderFactory.createSender("test_service")) {
+ sender.start();
+
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testCompletionTokenResponseJson));
+ var webserverUrl = getUrl(webServer);
+ var model = AzureAiStudioChatCompletionModelTests.createModel(
+ "id",
+ "http://will-be-replaced.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey"
+ );
+ model.setURI(webserverUrl);
+
+ var creator = new AzureAiStudioActionCreator(sender, serviceComponents);
+ var action = creator.create(model, Map.of());
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(
+ result.asMap(),
+ is(OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap(List.of("test input string")))
+ );
+ assertThat(webServer.requests(), hasSize(1));
+
+ MockRequest request = webServer.requests().get(0);
+
+ assertNull(request.getUri().getQuery());
+ assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("apikey"));
+
+ var requestMap = entityAsMap(request.getBody());
+ assertThat(requestMap.size(), is(1));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc"))));
+ }
+ }
+
+ private static String testEmbeddingsTokenResponseJson = """
+ {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "index": 0,
+ "embedding": [
+ 0.0123,
+ -0.0123
+ ]
+ }
+ ],
+ "model": "text-embedding-ada-002-v2",
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8
+ }
+ }
+ """;
+
+ private static String testCompletionTokenResponseJson = """
+ {
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": "test input string",
+ "role": "assistant",
+ "tool_calls": null
+ }
+ }
+ ],
+ "created": 1714006424,
+ "id": "f92b5b4d-0de3-4152-a3c6-5aae8a74555c",
+ "model": "",
+ "object": "chat.completion",
+ "usage": {
+ "completion_tokens": 35,
+ "prompt_tokens": 8,
+ "total_tokens": 43
+ }
+ }""";
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
index 73b627742ab03..8d63072b5d7aa 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java
@@ -24,6 +24,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
@@ -39,6 +40,7 @@
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.action.cohere.CohereCompletionActionTests.buildExpectedChatCompletionResultMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
@@ -148,4 +150,124 @@ public void testCreate_CohereEmbeddingsModel() throws IOException {
);
}
}
+
+ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var sender = senderFactory.createSender("test_service")) {
+ sender.start();
+
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
+ var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
+ var action = actionCreator.create(model, Map.of());
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
+ assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ assertThat(requestMap, is(Map.of("message", "abc", "model", "model")));
+ }
+ }
+
+ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var sender = senderFactory.createSender("test_service")) {
+ sender.start();
+
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null);
+ var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool));
+ var action = actionCreator.create(model, Map.of());
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
+ assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret"));
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ assertThat(requestMap, is(Map.of("message", "abc")));
+ }
+ }
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java
new file mode 100644
index 0000000000000..195f2bab1d6b5
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java
@@ -0,0 +1,353 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class CohereCompletionActionTests extends ESTestCase {
+
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+ sender.start();
+
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var action = createAction(getUrl(webServer), "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
+ assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+ assertThat(
+ webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
+ equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
+ );
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ assertThat(requestMap, is(Map.of("message", "abc", "model", "model")));
+ }
+ }
+
+ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) {
+ sender.start();
+
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var action = createAction(getUrl(webServer), "secret", null, sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var result = listener.actionGet(TIMEOUT);
+
+ assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
+ assertThat(webServer.requests(), hasSize(1));
+ assertNull(webServer.requests().get(0).getUri().getQuery());
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+ assertThat(
+ webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
+ equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
+ );
+
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+ assertThat(requestMap, is(Map.of("message", "abc")));
+ }
+ }
+
+ public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
+ try (var sender = mock(Sender.class)) {
+ var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("a^b", "api key", "model", sender));
+ assertThat(thrownException.getMessage(), is("unable to parse url [a^b]"));
+ }
+ }
+
+ public void testExecute_ThrowsElasticsearchException() {
+ var sender = mock(Sender.class);
+ doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
+
+ var action = createAction(getUrl(webServer), "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is("failed"));
+ }
+
+ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
+ var sender = mock(Sender.class);
+
+ doAnswer(invocation -> {
+ @SuppressWarnings("unchecked")
+ ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[2];
+ listener.onFailure(new IllegalStateException("failed"));
+
+ return Void.TYPE;
+ }).when(sender).send(any(), any(), any(), any());
+
+ var action = createAction(getUrl(webServer), "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is(format("Failed to send Cohere completion request to [%s]", getUrl(webServer))));
+ }
+
+ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
+ var sender = mock(Sender.class);
+
+ doAnswer(invocation -> {
+ @SuppressWarnings("unchecked")
+ ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[2];
+ listener.onFailure(new IllegalStateException("failed"));
+
+ return Void.TYPE;
+ }).when(sender).send(any(), any(), any(), any());
+
+ var action = createAction(null, "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is("Failed to send Cohere completion request"));
+ }
+
+ public void testExecute_ThrowsException() {
+ var sender = mock(Sender.class);
+ doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
+
+ var action = createAction(getUrl(webServer), "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is(format("Failed to send Cohere completion request to [%s]", getUrl(webServer))));
+ }
+
+ public void testExecute_ThrowsExceptionWithNullUrl() {
+ var sender = mock(Sender.class);
+ doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
+
+ var action = createAction(null, "secret", "model", sender);
+
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is("Failed to send Cohere completion request"));
+ }
+
+ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+ try (var sender = senderFactory.createSender("test_service")) {
+ sender.start();
+
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+ var action = createAction(getUrl(webServer), "secret", "model", sender);
+
+ PlainActionFuture listener = new PlainActionFuture<>();
+ action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+ var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+
+ assertThat(thrownException.getMessage(), is("Cohere completion only accepts 1 input"));
+ assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
+ }
+ }
+
+ public static Map<String, Object> buildExpectedChatCompletionResultMap(List<String> results) {
+ return Map.of(
+ ChatCompletionResults.COMPLETION,
+ results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList()
+ );
+ }
+
+ private CohereCompletionAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) {
+ var model = CohereCompletionModelTests.createModel(url, apiKey, modelName);
+
+ return new CohereCompletionAction(sender, model, threadPool);
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java
new file mode 100644
index 0000000000000..3b086f4d3b900
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestEntityTests.java
@@ -0,0 +1,227 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureaistudio;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class AzureAiStudioChatCompletionRequestEntityTests extends ESTestCase {
+
+ public void testToXContent_WhenTokenEndpoint_NoParameters() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, null, null);
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenTokenEndpoint_WithTemperatureParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, 1.0, null, null, null);
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), 1.0, null, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenTokenEndpoint_WithTopPParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, 2.0, null, null);
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, 2.0, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenTokenEndpoint_WithDoSampleParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, true, null);
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, true, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenTokenEndpoint_WithMaxNewTokensParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(List.of("abc"), AzureAiStudioEndpointType.TOKEN, null, null, null, 512);
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedTokenEndpointRequest(List.of("abc"), null, null, null, 512);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenRealtimeEndpoint_NoParameters() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(
+ List.of("abc"),
+ AzureAiStudioEndpointType.REALTIME,
+ null,
+ null,
+ null,
+ null
+ );
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenRealtimeEndpoint_WithTemperatureParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(
+ List.of("abc"),
+ AzureAiStudioEndpointType.REALTIME,
+ 1.0,
+ null,
+ null,
+ null
+ );
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), 1.0, null, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenRealtimeEndpoint_WithTopPParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(
+ List.of("abc"),
+ AzureAiStudioEndpointType.REALTIME,
+ null,
+ 2.0,
+ null,
+ null
+ );
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, 2.0, null, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenRealtimeEndpoint_WithDoSampleParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(
+ List.of("abc"),
+ AzureAiStudioEndpointType.REALTIME,
+ null,
+ null,
+ true,
+ null
+ );
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, true, null);
+ assertThat(request, is(expectedRequest));
+ }
+
+ public void testToXContent_WhenRealtimeEndpoint_WithMaxNewTokensParam() throws IOException {
+ var entity = new AzureAiStudioChatCompletionRequestEntity(
+ List.of("abc"),
+ AzureAiStudioEndpointType.REALTIME,
+ null,
+ null,
+ null,
+ 512
+ );
+ var request = getXContentAsString(entity);
+ var expectedRequest = getExpectedRealtimeEndpointRequest(List.of("abc"), null, null, null, 512);
+ assertThat(request, is(expectedRequest));
+ }
+
+ private String getXContentAsString(AzureAiStudioChatCompletionRequestEntity entity) throws IOException {
+ XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+ entity.toXContent(builder, null);
+ return Strings.toString(builder);
+ }
+
+ private String getExpectedTokenEndpointRequest(
+ List<String> inputs,
+ @Nullable Double temperature,
+ @Nullable Double topP,
+ @Nullable Boolean doSample,
+ @Nullable Integer maxNewTokens
+ ) {
+ String expected = "{";
+
+ expected = addMessageInputs("messages", expected, inputs);
+ expected = addParameters(expected, temperature, topP, doSample, maxNewTokens);
+
+ expected += "}";
+ return expected;
+ }
+
+ private String getExpectedRealtimeEndpointRequest(
+ List<String> inputs,
+ @Nullable Double temperature,
+ @Nullable Double topP,
+ @Nullable Boolean doSample,
+ @Nullable Integer maxNewTokens
+ ) {
+ String expected = "{\"input_data\":{";
+
+ expected = addMessageInputs("input_string", expected, inputs);
+ expected = addParameters(expected, temperature, topP, doSample, maxNewTokens);
+
+ expected += "}}";
+ return expected;
+ }
+
+ private String addMessageInputs(String fieldName, String expected, List<String> inputs) {
+ StringBuilder messages = new StringBuilder(Strings.format("\"%s\":[", fieldName));
+ var hasOne = false;
+ for (String input : inputs) {
+ if (hasOne) {
+ messages.append(",");
+ }
+ messages.append(getMessageString(input));
+ hasOne = true;
+ }
+ messages.append("]");
+
+ return expected + messages;
+ }
+
+ private String getMessageString(String input) {
+ return Strings.format("{\"content\":\"%s\",\"role\":\"user\"}", input);
+ }
+
+ private String addParameters(String expected, Double temperature, Double topP, Boolean doSample, Integer maxNewTokens) {
+ if (temperature == null && topP == null && doSample == null && maxNewTokens == null) {
+ return expected;
+ }
+
+ StringBuilder parameters = new StringBuilder(",\"parameters\":{");
+
+ var hasOne = false;
+ if (temperature != null) {
+ parameters.append(Strings.format("\"temperature\":%.1f", temperature));
+ hasOne = true;
+ }
+
+ if (topP != null) {
+ if (hasOne) {
+ parameters.append(",");
+ }
+ parameters.append(Strings.format("\"top_p\":%.1f", topP));
+ hasOne = true;
+ }
+
+ if (doSample != null) {
+ if (hasOne) {
+ parameters.append(",");
+ }
+ parameters.append(Strings.format("\"do_sample\":%s", doSample.equals(Boolean.TRUE)));
+ hasOne = true;
+ }
+
+ if (maxNewTokens != null) {
+ if (hasOne) {
+ parameters.append(",");
+ }
+ parameters.append(Strings.format("\"max_new_tokens\":%d", maxNewTokens));
+ }
+
+ parameters.append("}");
+
+ return expected + parameters;
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java
new file mode 100644
index 0000000000000..f3ddf7f9299d9
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioChatCompletionRequestTests.java
@@ -0,0 +1,465 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureaistudio;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class AzureAiStudioChatCompletionRequestTests extends ESTestCase {
+
+ public void testCreateRequest_WithOpenAiProviderTokenEndpoint_NoParams() throws IOException {
+ var request = createRequest(
+ "http://openaitarget.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(1));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ }
+
+ public void testCreateRequest_WithOpenAiProviderTokenEndpoint_WithTemperatureParam() throws IOException {
+ var request = createRequest(
+ "http://openaitarget.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ 1.0,
+ null,
+ null,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(1.0, null, null, null)));
+ }
+
+ public void testCreateRequest_WithOpenAiProviderTokenEndpoint_WithTopPParam() throws IOException {
+ var request = createRequest(
+ "http://openaitarget.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ 2.0,
+ null,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, 2.0, null, null)));
+ }
+
+ public void testCreateRequest_WithOpenAiProviderTokenEndpoint_WithDoSampleParam() throws IOException {
+ var request = createRequest(
+ "http://openaitarget.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ null,
+ true,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, null, true, null)));
+ }
+
+ public void testCreateRequest_WithOpenAiProviderTokenEndpoint_WithMaxNewTokensParam() throws IOException {
+ var request = createRequest(
+ "http://openaitarget.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ null,
+ null,
+ 512,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, null, null, 512)));
+ }
+
+ public void testCreateRequest_WithCohereProviderTokenEndpoint_NoParams() throws IOException {
+ var request = createRequest(
+ "http://coheretarget.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/chat/completions");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(1));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ }
+
+ public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTemperatureParam() throws IOException {
+ var request = createRequest(
+ "http://coheretarget.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ 1.0,
+ null,
+ null,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/chat/completions");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(1.0, null, null, null)));
+ }
+
+ public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopPParam() throws IOException {
+ var request = createRequest(
+ "http://coheretarget.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ 2.0,
+ null,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/chat/completions");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, 2.0, null, null)));
+ }
+
+ public void testCreateRequest_WithCohereProviderTokenEndpoint_WithDoSampleParam() throws IOException {
+ var request = createRequest(
+ "http://coheretarget.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ null,
+ true,
+ null,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/chat/completions");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, null, true, null)));
+ }
+
+ public void testCreateRequest_WithCohereProviderTokenEndpoint_WithMaxNewTokensParam() throws IOException {
+ var request = createRequest(
+ "http://coheretarget.local",
+ AzureAiStudioProvider.COHERE,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey",
+ null,
+ null,
+ null,
+ 512,
+ "abcd"
+ );
+ var httpRequest = request.createHttpRequest();
+
+ var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/chat/completions");
+ validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, AzureAiStudioEndpointType.TOKEN, "apikey");
+
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
+ assertThat(requestMap, aMapWithSize(2));
+ assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+ assertThat(requestMap.get("parameters"), is(getParameterMap(null, null, null, 512)));
+ }
+
+    // Verifies a Mistral realtime request posts only input_data.input_string when
+    // no generation parameters are configured on the model.
+    public void testCreateRequest_WithMistralProviderRealtimeEndpoint_NoParams() throws IOException {
+        var request = createRequest(
+            "http://mistral.local/score",
+            AzureAiStudioProvider.MISTRAL,
+            AzureAiStudioEndpointType.REALTIME,
+            "apikey",
+            "abcd"
+        );
+        var httpRequest = request.createHttpRequest();
+
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://mistral.local/score");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.MISTRAL, AzureAiStudioEndpointType.REALTIME, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+
+        // Parameterized cast (camelCase local) instead of a raw Map; cast is still unchecked.
+        @SuppressWarnings("unchecked")
+        var inputData = (Map<String, Object>) requestMap.get("input_data");
+        assertThat(inputData, aMapWithSize(1));
+        assertThat(inputData.get("input_string"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+    }
+
+    // Verifies a Mistral realtime request serializes the temperature setting into
+    // input_data.parameters next to the input_string chat payload.
+    public void testCreateRequest_WithMistralProviderRealtimeEndpoint_WithTemperatureParam() throws IOException {
+        var request = createRequest(
+            "http://mistral.local/score",
+            AzureAiStudioProvider.MISTRAL,
+            AzureAiStudioEndpointType.REALTIME,
+            "apikey",
+            1.0,
+            null,
+            null,
+            null,
+            "abcd"
+        );
+        var httpRequest = request.createHttpRequest();
+
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://mistral.local/score");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.MISTRAL, AzureAiStudioEndpointType.REALTIME, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+
+        // Parameterized cast (camelCase local) instead of a raw Map; cast is still unchecked.
+        @SuppressWarnings("unchecked")
+        var inputData = (Map<String, Object>) requestMap.get("input_data");
+        assertThat(inputData, aMapWithSize(2));
+        assertThat(inputData.get("input_string"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+        assertThat(inputData.get("parameters"), is(getParameterMap(1.0, null, null, null)));
+    }
+
+    // Verifies a Mistral realtime request serializes the top_p setting into
+    // input_data.parameters next to the input_string chat payload.
+    public void testCreateRequest_WithMistralProviderRealtimeEndpoint_WithTopPParam() throws IOException {
+        var request = createRequest(
+            "http://mistral.local/score",
+            AzureAiStudioProvider.MISTRAL,
+            AzureAiStudioEndpointType.REALTIME,
+            "apikey",
+            null,
+            2.0,
+            null,
+            null,
+            "abcd"
+        );
+        var httpRequest = request.createHttpRequest();
+
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://mistral.local/score");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.MISTRAL, AzureAiStudioEndpointType.REALTIME, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+
+        // Parameterized cast (camelCase local) instead of a raw Map; cast is still unchecked.
+        @SuppressWarnings("unchecked")
+        var inputData = (Map<String, Object>) requestMap.get("input_data");
+        assertThat(inputData, aMapWithSize(2));
+        assertThat(inputData.get("input_string"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+        assertThat(inputData.get("parameters"), is(getParameterMap(null, 2.0, null, null)));
+    }
+
+    // Verifies a Mistral realtime request serializes the do_sample setting into
+    // input_data.parameters next to the input_string chat payload.
+    public void testCreateRequest_WithMistralProviderRealtimeEndpoint_WithDoSampleParam() throws IOException {
+        var request = createRequest(
+            "http://mistral.local/score",
+            AzureAiStudioProvider.MISTRAL,
+            AzureAiStudioEndpointType.REALTIME,
+            "apikey",
+            null,
+            null,
+            true,
+            null,
+            "abcd"
+        );
+        var httpRequest = request.createHttpRequest();
+
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://mistral.local/score");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.MISTRAL, AzureAiStudioEndpointType.REALTIME, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+
+        // Parameterized cast (camelCase local) instead of a raw Map; cast is still unchecked.
+        @SuppressWarnings("unchecked")
+        var inputData = (Map<String, Object>) requestMap.get("input_data");
+        assertThat(inputData, aMapWithSize(2));
+        assertThat(inputData.get("input_string"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+        assertThat(inputData.get("parameters"), is(getParameterMap(null, null, true, null)));
+    }
+
+    // Verifies a Mistral realtime request serializes the max_new_tokens setting into
+    // input_data.parameters next to the input_string chat payload.
+    public void testCreateRequest_WithMistralProviderRealtimeEndpoint_WithMaxNewTokensParam() throws IOException {
+        var request = createRequest(
+            "http://mistral.local/score",
+            AzureAiStudioProvider.MISTRAL,
+            AzureAiStudioEndpointType.REALTIME,
+            "apikey",
+            null,
+            null,
+            null,
+            512,
+            "abcd"
+        );
+        var httpRequest = request.createHttpRequest();
+
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://mistral.local/score");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.MISTRAL, AzureAiStudioEndpointType.REALTIME, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+
+        // Parameterized cast (camelCase local) instead of a raw Map; cast is still unchecked.
+        @SuppressWarnings("unchecked")
+        var inputData = (Map<String, Object>) requestMap.get("input_data");
+        assertThat(inputData, aMapWithSize(2));
+        assertThat(inputData.get("input_string"), is(List.of(Map.of("role", "user", "content", "abcd"))));
+        assertThat(inputData.get("parameters"), is(getParameterMap(null, null, null, 512)));
+    }
+
+    /**
+     * Asserts the request is a POST to {@code expectedUrl} with a JSON content type
+     * and returns it cast to {@link HttpPost} for further assertions.
+     */
+    private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) throws IOException {
+        assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) request.httpRequestBase();
+        assertThat(httpPost.getURI().toString(), is(expectedUrl));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        return httpPost;
+    }
+
+    /**
+     * Asserts the auth header matches the provider/endpoint combination:
+     * token + OpenAI uses the api-key header, other token providers send the raw
+     * key in Authorization, and realtime endpoints use a Bearer token.
+     */
+    private void validateRequestApiKey(
+        HttpPost httpPost,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey
+    ) {
+        if (endpointType == AzureAiStudioEndpointType.TOKEN) {
+            if (provider == AzureAiStudioProvider.OPENAI) {
+                assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
+            } else {
+                assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey));
+            }
+        } else {
+            assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + apiKey));
+        }
+    }
+
+    /**
+     * Builds the expected "parameters" object for a request body, including only
+     * the generation settings that are non-null. Parameterized as
+     * {@code Map<String, Object>} instead of a raw Map.
+     */
+    private Map<String, Object> getParameterMap(
+        @Nullable Double temperature,
+        @Nullable Double topP,
+        @Nullable Boolean doSample,
+        @Nullable Integer maxNewTokens
+    ) {
+        var map = new HashMap<String, Object>();
+        if (temperature != null) {
+            map.put("temperature", temperature);
+        }
+        if (topP != null) {
+            map.put("top_p", topP);
+        }
+        if (doSample != null) {
+            map.put("do_sample", doSample);
+        }
+        if (maxNewTokens != null) {
+            map.put("max_new_tokens", maxNewTokens);
+        }
+        return map;
+    }
+
+    // Convenience overload: builds a chat completion request with no generation
+    // parameters (temperature/top_p/do_sample/max_new_tokens all null).
+    public static AzureAiStudioChatCompletionRequest createRequest(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        String input
+    ) {
+        return createRequest(target, provider, endpointType, apiKey, null, null, null, null, input);
+    }
+
+    // Builds a chat completion request against a model configured with the given
+    // target/provider/endpoint and optional generation parameters.
+    public static AzureAiStudioChatCompletionRequest createRequest(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        @Nullable Double temperature,
+        @Nullable Double topP,
+        @Nullable Boolean doSample,
+        @Nullable Integer maxNewTokens,
+        String input
+    ) {
+        var model = AzureAiStudioChatCompletionModelTests.createModel(
+            "id",
+            target,
+            provider,
+            endpointType,
+            apiKey,
+            temperature,
+            topP,
+            doSample,
+            maxNewTokens,
+            null
+        );
+        return new AzureAiStudioChatCompletionRequest(model, List.of(input));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestEntityTests.java
new file mode 100644
index 0000000000000..b2df7f7c27564
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestEntityTests.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureaistudio;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+/**
+ * Tests XContent serialization of {@code AzureAiStudioEmbeddingsRequestEntity}:
+ * the optional "user" field is written only when non-null, and "dimensions" is
+ * written only when non-null AND explicitly set by the user.
+ */
+public class AzureAiStudioEmbeddingsRequestEntityTests extends ESTestCase {
+    public void testXContent_WritesUserWhenDefined() throws IOException {
+        var entity = new AzureAiStudioEmbeddingsRequestEntity(List.of("abc"), "testuser", null, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"],"user":"testuser"}"""));
+    }
+
+    public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException {
+        var entity = new AzureAiStudioEmbeddingsRequestEntity(List.of("abc"), null, null, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    // dimensions=100 but dimensionsSetByUser=false -> field omitted
+    public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException {
+        var entity = new AzureAiStudioEmbeddingsRequestEntity(List.of("abc"), null, 100, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException {
+        var entity = new AzureAiStudioEmbeddingsRequestEntity(List.of("abc"), null, null, true);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException {
+        var entity = new AzureAiStudioEmbeddingsRequestEntity(List.of("abc"), null, 100, true);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"],"dimensions":100}"""));
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestTests.java
new file mode 100644
index 0000000000000..524d813a4da1f
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureaistudio/AzureAiStudioEmbeddingsRequestTests.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureaistudio;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests HTTP request construction for {@code AzureAiStudioEmbeddingsRequest}:
+ * URL routing per provider (OpenAI targets the raw URL, Cohere appends
+ * /v1/embeddings), auth headers, body serialization, and truncation behavior.
+ */
+public class AzureAiStudioEmbeddingsRequestTests extends ESTestCase {
+
+    public void testCreateRequest_WithOpenAiProvider_NoAdditionalParams() throws IOException {
+        var request = createRequest(
+            "http://openaitarget.local",
+            AzureAiStudioProvider.OPENAI,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            null
+        );
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap.get("input"), is(List.of("abcd")));
+    }
+
+    public void testCreateRequest_WithOpenAiProvider_WithUserParam() throws IOException {
+        var request = createRequest(
+            "http://openaitarget.local",
+            AzureAiStudioProvider.OPENAI,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            "userid"
+        );
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://openaitarget.local");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.OPENAI, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("input"), is(List.of("abcd")));
+        assertThat(requestMap.get("user"), is("userid"));
+    }
+
+    public void testCreateRequest_WithCohereProvider_NoAdditionalParams() throws IOException {
+        var request = createRequest(
+            "http://coheretarget.local",
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            null
+        );
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/embeddings");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap.get("input"), is(List.of("abcd")));
+    }
+
+    public void testCreateRequest_WithCohereProvider_WithUserParam() throws IOException {
+        var request = createRequest(
+            "http://coheretarget.local",
+            AzureAiStudioProvider.COHERE,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            "userid"
+        );
+        var httpRequest = request.createHttpRequest();
+        var httpPost = validateRequestUrlAndContentType(httpRequest, "http://coheretarget.local/v1/embeddings");
+        validateRequestApiKey(httpPost, AzureAiStudioProvider.COHERE, "apikey");
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("input"), is(List.of("abcd")));
+        assertThat(requestMap.get("user"), is("userid"));
+    }
+
+    // "abcd" (4 chars) truncates to "ab" — the truncator halves the input.
+    public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
+        var request = createRequest(
+            "http://openaitarget.local",
+            AzureAiStudioProvider.OPENAI,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            null
+        );
+        var truncatedRequest = request.truncate();
+
+        var httpRequest = truncatedRequest.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap.get("input"), is(List.of("ab")));
+    }
+
+    // Checks both states: flag is false before truncate() and true after.
+    public void testIsTruncated_ReturnsTrue() {
+        var request = createRequest(
+            "http://openaitarget.local",
+            AzureAiStudioProvider.OPENAI,
+            AzureAiStudioEndpointType.TOKEN,
+            "apikey",
+            "abcd",
+            null
+        );
+        assertFalse(request.getTruncationInfo()[0]);
+
+        var truncatedRequest = request.truncate();
+        assertTrue(truncatedRequest.getTruncationInfo()[0]);
+    }
+
+    // Asserts the request is a JSON POST to expectedUrl and returns the HttpPost.
+    private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) throws IOException {
+        assertThat(request.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) request.httpRequestBase();
+        assertThat(httpPost.getURI().toString(), is(expectedUrl));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        return httpPost;
+    }
+
+    // OpenAI sends the key via the api-key header; other providers via Authorization.
+    private void validateRequestApiKey(HttpPost httpPost, AzureAiStudioProvider provider, String apiKey) {
+        if (provider == AzureAiStudioProvider.OPENAI) {
+            assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
+        } else {
+            assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(apiKey));
+        }
+    }
+
+    // Builds an embeddings request over a single non-truncated input string.
+    public static AzureAiStudioEmbeddingsRequest createRequest(
+        String target,
+        AzureAiStudioProvider provider,
+        AzureAiStudioEndpointType endpointType,
+        String apiKey,
+        String input,
+        @Nullable String user
+    ) {
+        var model = AzureAiStudioEmbeddingsModelTests.createModel(
+            "id",
+            target,
+            provider,
+            endpointType,
+            apiKey,
+            null,
+            false,
+            null,
+            null,
+            user,
+            null
+        );
+        return new AzureAiStudioEmbeddingsRequest(
+            TruncatorTests.createTruncator(),
+            new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
+            model
+        );
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java
new file mode 100644
index 0000000000000..dbe6a9438d884
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestEntityTests.java
@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.cohere.completion.CohereCompletionRequestEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+/**
+ * Tests XContent serialization of {@code CohereCompletionRequestEntity}: the
+ * first input is written as "message", "model" is optional, and null inputs
+ * are rejected at construction time.
+ */
+public class CohereCompletionRequestEntityTests extends ESTestCase {
+
+    public void testXContent_WritesAllFields() throws IOException {
+        var entity = new CohereCompletionRequestEntity(List.of("some input"), "model");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"message":"some input","model":"model"}"""));
+    }
+
+    public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException {
+        var entity = new CohereCompletionRequestEntity(List.of("some input"), null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"message":"some input"}"""));
+    }
+
+    public void testXContent_ThrowsIfInputIsNull() {
+        expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, null));
+    }
+
+    public void testXContent_ThrowsIfMessageInInputIsNull() {
+        expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null));
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java
new file mode 100644
index 0000000000000..d6d0d5c00eaf4
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereCompletionRequestTests.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.cohere.completion.CohereCompletionRequest;
+import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+
+/**
+ * Tests HTTP request construction for {@code CohereCompletionRequest}: headers
+ * (JSON content type, Bearer auth, Elastic request-source), body contents with
+ * and without a model id, and the no-op truncation contract.
+ */
+public class CohereCompletionRequestTests extends ESTestCase {
+
+    public void testCreateRequest_UrlDefined() throws IOException {
+        var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null));
+
+        var httpRequest = request.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getURI().toString(), is("url"));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, is(Map.of("message", "abc")));
+    }
+
+    public void testCreateRequest_ModelDefined() throws IOException {
+        var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"));
+
+        var httpRequest = request.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getURI().toString(), is("url"));
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
+        assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, is(Map.of("message", "abc", "model", "model")));
+    }
+
+    // Completion requests are never truncated: truncate() is an identity operation.
+    public void testTruncate_ReturnsSameInstance() {
+        var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"));
+        var truncatedRequest = request.truncate();
+
+        assertThat(truncatedRequest, sameInstance(request));
+    }
+
+    public void testTruncationInfo_ReturnsNull() {
+        var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"));
+
+        assertNull(request.getTruncationInfo());
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRequestTests.java
new file mode 100644
index 0000000000000..444fee7cac3c7
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRequestTests.java
@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.cohere;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
+
+import java.net.URI;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests that {@code CohereRequest.decorateWithAuthHeader} adds the JSON content
+ * type, a Bearer Authorization header built from the account's API key, and the
+ * Elastic request-source header.
+ */
+public class CohereRequestTests extends ESTestCase {
+
+    public void testDecorateWithAuthHeader() {
+        var request = new HttpPost("http://www.abc.com");
+
+        CohereRequest.decorateWithAuthHeader(
+            request,
+            new CohereAccount(URI.create("http://www.abc.com"), new SecureString(new char[] { 'a', 'b', 'c' }))
+        );
+
+        assertThat(request.getFirstHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc"));
+        assertThat(request.getFirstHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE));
+    }
+
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java
new file mode 100644
index 0000000000000..fd133a26f5532
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests {@code AzureAndOpenAiErrorResponseEntity.fromResponse}: it extracts
+ * error.message from a JSON error body and returns null when the body has no
+ * error object or is not parseable JSON.
+ */
+public class AzureAndOpenAiErrorResponseEntityTests extends ESTestCase {
+
+    // Wraps the raw JSON bytes in an HttpResult with a mocked HttpResponse.
+    private static HttpResult getMockResult(String jsonString) {
+        var response = mock(HttpResponse.class);
+        return new HttpResult(response, Strings.toUTF8Bytes(jsonString));
+    }
+
+    public void testErrorResponse_ExtractsError() {
+        var result = getMockResult("""
+            {"error":{"message":"test_error_message"}}""");
+
+        var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result);
+        assertNotNull(error);
+        assertThat(error.getErrorMessage(), is("test_error_message"));
+    }
+
+    public void testErrorResponse_ReturnsNullIfNoError() {
+        var result = getMockResult("""
+            {"noerror":true}""");
+
+        var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result);
+        assertNull(error);
+    }
+
+    public void testErrorResponse_ReturnsNullIfNotJson() {
+        var result = getMockResult("not a json string");
+
+        var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result);
+        assertNull(error);
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java
new file mode 100644
index 0000000000000..4c9fb143c3a5c
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java
@@ -0,0 +1,245 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response;
+
+import org.apache.http.Header;
+import org.apache.http.HeaderElement;
+import org.apache.http.HttpResponse;
+import org.apache.http.StatusLine;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ContentTooLargeException;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.request.RequestTests;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.core.Is.is;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AzureAndOpenAiExternalResponseHandlerTests extends ESTestCase {
+
+ public void testCheckForFailureStatusCode() {
+ var statusLine = mock(StatusLine.class);
+
+ var httpResponse = mock(HttpResponse.class);
+ when(httpResponse.getStatusLine()).thenReturn(statusLine);
+ var header = mock(Header.class);
+ when(header.getElements()).thenReturn(new HeaderElement[] {});
+ when(httpResponse.getFirstHeader(anyString())).thenReturn(header);
+
+ var mockRequest = RequestTests.mockRequest("id");
+ var httpResult = new HttpResult(httpResponse, new byte[] {});
+ var handler = new AzureAndOpenAiExternalResponseHandler(
+ "",
+ (request, result) -> null,
+ AzureAndOpenAiErrorResponseEntity::fromResponse
+ );
+
+ // 200 ok
+ when(statusLine.getStatusCode()).thenReturn(200);
+ handler.checkForFailureStatusCode(mockRequest, httpResult);
+ // 503
+ when(statusLine.getStatusCode()).thenReturn(503);
+ var retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertTrue(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received a server busy error status code for request from inference entity id [id] status [503]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 501
+ when(statusLine.getStatusCode()).thenReturn(501);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received a server error status code for request from inference entity id [id] status [501]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 500
+ when(statusLine.getStatusCode()).thenReturn(500);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertTrue(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received a server error status code for request from inference entity id [id] status [500]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 429
+ when(statusLine.getStatusCode()).thenReturn(429);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertTrue(retryException.shouldRetry());
+ assertThat(retryException.getCause().getMessage(), containsString("Received a rate limit status code."));
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS));
+ // 413
+ when(statusLine.getStatusCode()).thenReturn(413);
+ retryException = expectThrows(ContentTooLargeException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertTrue(retryException.shouldRetry());
+ assertThat(retryException.getCause().getMessage(), containsString("Received a content too large status code"));
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.REQUEST_ENTITY_TOO_LARGE));
+ // 400 content too large
+ retryException = expectThrows(
+ ContentTooLargeException.class,
+ () -> handler.checkForFailureStatusCode(mockRequest, createContentTooLargeResult(400))
+ );
+ assertTrue(retryException.shouldRetry());
+ assertThat(retryException.getCause().getMessage(), containsString("Received a content too large status code"));
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 400 generic bad request should not be marked as a content too large
+ when(statusLine.getStatusCode()).thenReturn(400);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 400 is not flagged as a content too large when the error message is different
+ when(statusLine.getStatusCode()).thenReturn(400);
+ retryException = expectThrows(
+ RetryException.class,
+ () -> handler.checkForFailureStatusCode(mockRequest, createResult(400, "blah"))
+ );
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST));
+ // 401
+ when(statusLine.getStatusCode()).thenReturn(401);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received an authentication error status code for request from inference entity id [id] status [401]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.UNAUTHORIZED));
+ // 300
+ when(statusLine.getStatusCode()).thenReturn(300);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Unhandled redirection for request from inference entity id [id] status [300]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES));
+ // 402
+ when(statusLine.getStatusCode()).thenReturn(402);
+ retryException = expectThrows(RetryException.class, () -> handler.checkForFailureStatusCode(mockRequest, httpResult));
+ assertFalse(retryException.shouldRetry());
+ assertThat(
+ retryException.getCause().getMessage(),
+ containsString("Received an unsuccessful status code for request from inference entity id [id] status [402]")
+ );
+ assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.PAYMENT_REQUIRED));
+ }
+
+ public void testBuildRateLimitErrorMessage() {
+ int statusCode = 429;
+ var statusLine = mock(StatusLine.class);
+ when(statusLine.getStatusCode()).thenReturn(statusCode);
+ var response = mock(HttpResponse.class);
+ when(response.getStatusLine()).thenReturn(statusLine);
+ var httpResult = new HttpResult(response, new byte[] {});
+
+ {
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT, "3000")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS, "99800")
+ );
+
+ var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult);
+ assertThat(
+ error,
+ containsString("Token limit [10000], remaining tokens [99800]. Request limit [3000], remaining requests [2999]")
+ );
+ }
+
+ {
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null);
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null);
+ var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult);
+ assertThat(
+ error,
+ containsString("Token limit [unknown], remaining tokens [unknown]. Request limit [3000], remaining requests [2999]")
+ );
+ }
+
+ {
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null);
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null);
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null);
+ var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult);
+ assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]"));
+ }
+
+ {
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null);
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(
+ new BasicHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000")
+ );
+ when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null);
+ var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult);
+ assertThat(
+ error,
+ containsString("Token limit [10000], remaining tokens [unknown]. Request limit [unknown], remaining requests [2999]")
+ );
+ }
+ }
+
+ private static HttpResult createContentTooLargeResult(int statusCode) {
+ return createResult(
+ statusCode,
+ "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;"
+ + "0 for the completion). Please reduce your prompt; or completion length."
+ );
+ }
+
+ private static HttpResult createResult(int statusCode, String message) {
+ var statusLine = mock(StatusLine.class);
+ when(statusLine.getStatusCode()).thenReturn(statusCode);
+ var httpResponse = mock(HttpResponse.class);
+ when(httpResponse.getStatusLine()).thenReturn(statusLine);
+
+ String responseJson = Strings.format("""
+ {
+ "error": {
+ "message": "%s",
+ "type": "content_too_large",
+ "param": null,
+ "code": null
+ }
+ }
+ """, message);
+
+ return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java
new file mode 100644
index 0000000000000..7d5aafa181b19
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioChatCompletionResponseEntityTests.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureaistudio;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
+import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class AzureAiStudioChatCompletionResponseEntityTests extends ESTestCase {
+
+ public void testCompletionResponse_FromTokenEndpoint() throws IOException {
+ var entity = new AzureAiStudioChatCompletionResponseEntity();
+ var model = AzureAiStudioChatCompletionModelTests.createModel(
+ "id",
+ "http://testopenai.local",
+ AzureAiStudioProvider.OPENAI,
+ AzureAiStudioEndpointType.TOKEN,
+ "apikey"
+ );
+ var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input"));
+ var result = (ChatCompletionResults) entity.apply(
+ request,
+ new HttpResult(mock(HttpResponse.class), testTokenResponseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(result.getResults().size(), equalTo(1));
+ assertThat(result.getResults().get(0).content(), is("test input string"));
+ }
+
+ public void testCompletionResponse_FromRealtimeEndpoint() throws IOException {
+ var entity = new AzureAiStudioChatCompletionResponseEntity();
+ var model = AzureAiStudioChatCompletionModelTests.createModel(
+ "id",
+ "http://testmistral.local",
+ AzureAiStudioProvider.MISTRAL,
+ AzureAiStudioEndpointType.REALTIME,
+ "apikey"
+ );
+ var request = new AzureAiStudioChatCompletionRequest(model, List.of("test input"));
+ var result = (ChatCompletionResults) entity.apply(
+ request,
+ new HttpResult(mock(HttpResponse.class), testRealtimeResponseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(result.getResults().size(), equalTo(1));
+ assertThat(result.getResults().get(0).content(), is("test realtime response"));
+ }
+
+ private static String testRealtimeResponseJson = """
+ {
+ "output": "test realtime response"
+ }
+ """;
+
+ private static String testTokenResponseJson = """
+ {
+ "choices": [
+ {
+ "finish_reason": "stop",
+ "index": 0,
+ "message": {
+ "content": "test input string",
+ "role": "assistant",
+ "tool_calls": null
+ }
+ }
+ ],
+ "created": 1714006424,
+ "id": "f92b5b4d-0de3-4152-a3c6-5aae8a74555c",
+ "model": "",
+ "object": "chat.completion",
+ "usage": {
+ "completion_tokens": 35,
+ "prompt_tokens": 8,
+ "total_tokens": 43
+ }
+ }""";
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java
new file mode 100644
index 0000000000000..fd31743616e6e
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureaistudio/AzureAiStudioEmbeddingsResponseEntityTests.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureaistudio;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Note - the underlying AzureAiStudioEmbeddingsResponseEntity uses the same
+ * response entity parser as OpenAI. This test just performs a smoke
+ * test of the wrapper
+ */
+public class AzureAiStudioEmbeddingsResponseEntityTests extends ESTestCase {
+ public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+ String responseJson = """
+ {
+ "object": "list",
+ "data": [
+ {
+ "object": "embedding",
+ "index": 0,
+ "embedding": [
+ 0.014539449,
+ -0.015288644
+ ]
+ }
+ ],
+ "model": "text-embedding-ada-002-v2",
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8
+ }
+ }
+ """;
+
+ var entity = new AzureAiStudioEmbeddingsResponseEntity();
+
+ var parsedResults = (TextEmbeddingResults) entity.apply(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(parsedResults.embeddings(), is(List.of(new TextEmbeddingResults.Embedding(List.of(0.014539449F, -0.015288644F)))));
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntityTests.java
new file mode 100644
index 0000000000000..70e1656195c3c
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereCompletionResponseEntityTests.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.cohere;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class CohereCompletionResponseEntityTests extends ESTestCase {
+
+ public void testFromResponse_CreatesResponseEntityForText() throws IOException {
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "some input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ ChatCompletionResults chatCompletionResults = CohereCompletionResponseEntity.fromResponse(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ );
+
+ assertThat(chatCompletionResults.getResults().size(), is(1));
+ assertThat(chatCompletionResults.getResults().get(0).content(), is("result"));
+ }
+
+ public void testFromResponse_FailsWhenTextIsNotPresent() {
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "not_text": "result",
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "some input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ var thrownException = expectThrows(
+ IllegalStateException.class,
+ () -> CohereCompletionResponseEntity.fromResponse(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ )
+ );
+
+ assertThat(thrownException.getMessage(), is("Failed to find required field [text] in Cohere chat response"));
+ }
+
+ public void testFromResponse_FailsWhenTextIsNotAString() {
+ String responseJson = """
+ {
+ "response_id": "some id",
+ "text": {
+ "text": "result"
+ },
+ "generation_id": "some id",
+ "chat_history": [
+ {
+ "role": "USER",
+ "message": "some input"
+ },
+ {
+ "role": "CHATBOT",
+ "message": "result"
+ }
+ ],
+ "finish_reason": "COMPLETE",
+ "meta": {
+ "api_version": {
+ "version": "1"
+ },
+ "billed_units": {
+ "input_tokens": 4,
+ "output_tokens": 191
+ },
+ "tokens": {
+ "input_tokens": 70,
+ "output_tokens": 191
+ }
+ }
+ }
+ """;
+
+ var thrownException = expectThrows(
+ ParsingException.class,
+ () -> CohereCompletionResponseEntity.fromResponse(
+ mock(Request.class),
+ new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+ )
+ );
+
+ assertThat(
+ thrownException.getMessage(),
+ is("Failed to parse object: expecting token of type [VALUE_STRING] but found [START_OBJECT]")
+ );
+ }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
index bf9fdbe7235b6..6f05ab79629e6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
@@ -11,6 +11,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
@@ -32,6 +33,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveLong;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
@@ -193,6 +195,95 @@ public void testRemoveAsTypeMissingReturnsNull() {
assertThat(map.entrySet(), hasSize(3));
}
+ public void testRemoveAsOneOfTypes_Validation_WithCorrectTypes() {
+ Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 1.0));
+ ValidationException validationException = new ValidationException();
+
+ Integer i = (Integer) ServiceUtils.removeAsOneOfTypes(map, "a", List.of(String.class, Integer.class), validationException);
+ assertEquals(Integer.valueOf(5), i);
+ assertNull(map.get("a")); // field has been removed
+
+ String str = (String) ServiceUtils.removeAsOneOfTypes(map, "b", List.of(Integer.class, String.class), validationException);
+ assertEquals("a string", str);
+ assertNull(map.get("b"));
+
+ Boolean b = (Boolean) ServiceUtils.removeAsOneOfTypes(map, "c", List.of(String.class, Boolean.class), validationException);
+ assertEquals(Boolean.TRUE, b);
+ assertNull(map.get("c"));
+
+ Double d = (Double) ServiceUtils.removeAsOneOfTypes(map, "d", List.of(Booleans.class, Double.class), validationException);
+ assertEquals(Double.valueOf(1.0), d);
+ assertNull(map.get("d"));
+
+ assertThat(map.entrySet(), empty());
+ }
+
+ public void testRemoveAsOneOfTypes_Validation_WithIncorrectType() {
+ Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 5.0, "e", 5));
+
+ var validationException = new ValidationException();
+ Object result = ServiceUtils.removeAsOneOfTypes(map, "a", List.of(String.class, Boolean.class), validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors(), hasSize(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ containsString("field [a] is not of one of the expected types. The value [5] cannot be converted to one of [String, Boolean]")
+ );
+ assertNull(map.get("a"));
+
+ validationException = new ValidationException();
+ result = ServiceUtils.removeAsOneOfTypes(map, "b", List.of(Boolean.class, Integer.class), validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors(), hasSize(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ containsString(
+ "field [b] is not of one of the expected types. The value [a string] cannot be converted to one of [Boolean, Integer]"
+ )
+ );
+ assertNull(map.get("b"));
+
+ validationException = new ValidationException();
+ result = ServiceUtils.removeAsOneOfTypes(map, "c", List.of(String.class, Integer.class), validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors(), hasSize(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ containsString(
+ "field [c] is not of one of the expected types. The value [true] cannot be converted to one of [String, Integer]"
+ )
+ );
+ assertNull(map.get("c"));
+
+ validationException = new ValidationException();
+ result = ServiceUtils.removeAsOneOfTypes(map, "d", List.of(String.class, Boolean.class), validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors(), hasSize(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ containsString("field [d] is not of one of the expected types. The value [5.0] cannot be converted to one of [String, Boolean]")
+ );
+ assertNull(map.get("d"));
+
+ validationException = new ValidationException();
+ result = ServiceUtils.removeAsOneOfTypes(map, "e", List.of(String.class, Boolean.class), validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors(), hasSize(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ containsString("field [e] is not of one of the expected types. The value [5] cannot be converted to one of [String, Boolean]")
+ );
+ assertNull(map.get("e"));
+
+ assertThat(map.entrySet(), empty());
+ }
+
+ public void testRemoveAsOneOfTypesMissingReturnsNull() {
+ Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE));
+ assertNull(ServiceUtils.removeAsOneOfTypes(map, "missing", List.of(Integer.class), new ValidationException()));
+ assertThat(map.entrySet(), hasSize(3));
+ }
+
public void testConvertToUri_CreatesUri() {
var validation = new ValidationException();
var uri = convertToUri("www.elastic.co", "name", "scope", validation);
@@ -347,6 +438,22 @@ public void testExtractOptionalPositiveInt() {
assertThat(validation.validationErrors(), hasSize(1));
}
+ public void testExtractOptionalPositiveLong_IntegerValue() {
+ var validation = new ValidationException();
+ validation.addValidationError("previous error");
+ Map map = modifiableMap(Map.of("abc", 3));
+ assertEquals(Long.valueOf(3), extractOptionalPositiveLong(map, "abc", "scope", validation));
+ assertThat(validation.validationErrors(), hasSize(1));
+ }
+
+ public void testExtractOptionalPositiveLong() {
+ var validation = new ValidationException();
+ validation.addValidationError("previous error");
+ Map map = modifiableMap(Map.of("abc", 4_000_000_000L));
+ assertEquals(Long.valueOf(4_000_000_000L), extractOptionalPositiveLong(map, "abc", "scope", validation));
+ assertThat(validation.validationErrors(), hasSize(1));
+ }
+
public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() {
var validation = new ValidationException();
Map map = modifiableMap(Map.of("key", "value"));
@@ -470,6 +577,127 @@ public void testExtractOptionalTimeValue_ReturnsNullAndAddsException_WhenTimeVal
);
}
+ public void testExtractOptionalDouble_ExtractsAsDoubleInRange() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", 1.01));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "key", 0.0, 2.0, "test_scope", validationException);
+ assertEquals(Double.valueOf(1.01), result);
+ assertTrue(map.isEmpty());
+ assertThat(validationException.validationErrors().size(), is(0));
+ }
+
+ public void testExtractOptionalDouble_InRange_ReturnsNullWhenKeyNotPresent() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", 1.01));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "other_key", 0.0, 2.0, "test_scope", validationException);
+ assertNull(result);
+ assertThat(map.size(), is(1));
+ assertThat(map.get("key"), is(1.01));
+ }
+
+ public void testExtractOptionalDouble_InRange_HasErrorWhenBelowMinValue() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", -2.0));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "key", 0.0, 2.0, "test_scope", validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors().size(), is(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ is("[test_scope] Invalid value [-2.0]. [key] must be a greater than or equal to [0.0]")
+ );
+ }
+
+ public void testExtractOptionalDouble_InRange_HasErrorWhenAboveMaxValue() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", 12.0));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "key", 0.0, 2.0, "test_scope", validationException);
+ assertNull(result);
+ assertThat(validationException.validationErrors().size(), is(1));
+ assertThat(
+ validationException.validationErrors().get(0),
+ is("[test_scope] Invalid value [12.0]. [key] must be a less than or equal to [2.0]")
+ );
+ }
+
+ public void testExtractOptionalDouble_InRange_DoesNotCheckMinWhenNull() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", -2.0));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "key", null, 2.0, "test_scope", validationException);
+ assertEquals(Double.valueOf(-2.0), result);
+ assertTrue(map.isEmpty());
+ assertThat(validationException.validationErrors().size(), is(0));
+ }
+
+ public void testExtractOptionalDouble_InRange_DoesNotCheckMaxWhenNull() {
+ var validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", 12.0));
+ var result = ServiceUtils.extractOptionalDoubleInRange(map, "key", 0.0, null, "test_scope", validationException);
+ assertEquals(Double.valueOf(12.0), result);
+ assertTrue(map.isEmpty());
+ assertThat(validationException.validationErrors().size(), is(0));
+ }
+
+ public void testExtractOptionalFloat_ExtractsAFloat() {
+ Map map = modifiableMap(Map.of("key", 1.0f));
+ var result = ServiceUtils.extractOptionalFloat(map, "key");
+ assertThat(result, is(1.0f));
+ assertTrue(map.isEmpty());
+ }
+
+ public void testExtractOptionalFloat_ReturnsNullWhenKeyNotPresent() {
+ Map map = modifiableMap(Map.of("key", 1.0f));
+ var result = ServiceUtils.extractOptionalFloat(map, "other_key");
+ assertNull(result);
+ assertThat(map.size(), is(1));
+ assertThat(map.get("key"), is(1.0f));
+ }
+
+ public void testExtractRequiredEnum_ExtractsAEnum() {
+ ValidationException validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", "ingest"));
+ var result = ServiceUtils.extractRequiredEnum(
+ map,
+ "key",
+ "testscope",
+ InputType::fromString,
+ EnumSet.allOf(InputType.class),
+ validationException
+ );
+ assertThat(result, is(InputType.INGEST));
+ }
+
+ public void testExtractRequiredEnum_ReturnsNullWhenEnumValueIsNotPresent() {
+ ValidationException validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", "invalid"));
+ var result = ServiceUtils.extractRequiredEnum(
+ map,
+ "key",
+ "testscope",
+ InputType::fromString,
+ EnumSet.allOf(InputType.class),
+ validationException
+ );
+ assertNull(result);
+ assertThat(validationException.validationErrors().size(), is(1));
+ assertThat(validationException.validationErrors().get(0), containsString("Invalid value [invalid] received. [key] must be one of"));
+ }
+
+ public void testExtractRequiredEnum_HasValidationErrorOnMissingSetting() {
+ ValidationException validationException = new ValidationException();
+ Map map = modifiableMap(Map.of("key", "ingest"));
+ var result = ServiceUtils.extractRequiredEnum(
+ map,
+ "missing_key",
+ "testscope",
+ InputType::fromString,
+ EnumSet.allOf(InputType.class),
+ validationException
+ );
+ assertNull(result);
+ assertThat(validationException.validationErrors().size(), is(1));
+ assertThat(validationException.validationErrors().get(0), is("[testscope] does not contain the required setting [missing_key]"));
+ }
+
public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() {
var service = mock(InferenceService.class);
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
new file mode 100644
index 0000000000000..51593c8d052d9
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -0,0 +1,1177 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureaistudio;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
+import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettingsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
+import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.API_KEY_FIELD;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+public class AzureAiStudioServiceTests extends ESTestCase {
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+ private final MockWebServer webServer = new MockWebServer();
+ private ThreadPool threadPool;
+ private HttpClientManager clientManager;
+
+ @Before
+ public void init() throws Exception {
+ webServer.start();
+ threadPool = createThreadPool(inferenceUtilityPool());
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+ }
+
+ @After
+ public void shutdown() throws IOException {
+ clientManager.close();
+ terminate(threadPool);
+ webServer.close();
+ }
+
+ public void testParseRequestConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
+ try (var service = createService()) {
+ ActionListener modelVerificationListener = ActionListener.wrap(model -> {
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+
+ var embeddingsModel = (AzureAiStudioEmbeddingsModel) model;
+ assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local"));
+ assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI));
+ assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+ }, exception -> fail("Unexpected exception: " + exception));
+
+ service.parseRequestConfig(
+ "id",
+ TaskType.TEXT_EMBEDDING,
+ getRequestConfigMap(
+ getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null),
+ getEmbeddingsTaskSettingsMap("user"),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of(),
+ modelVerificationListener
+ );
+ }
+ }
+
+ public void testParseRequestConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException {
+ try (var service = createService()) {
+ ActionListener modelVerificationListener = ActionListener.wrap(model -> {
+ assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class));
+
+ var completionModel = (AzureAiStudioChatCompletionModel) model;
+ assertThat(completionModel.getServiceSettings().target(), is("http://target.local"));
+ assertThat(completionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI));
+ assertThat(completionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+ assertThat(completionModel.getSecretSettings().apiKey().toString(), is("secret"));
+ assertNull(completionModel.getTaskSettings().temperature());
+ assertTrue(completionModel.getTaskSettings().doSample());
+ }, exception -> fail("Unexpected exception: " + exception));
+
+ service.parseRequestConfig(
+ "id",
+ TaskType.COMPLETION,
+ getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(null, null, true, null),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of(),
+ modelVerificationListener
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
+ try (var service = createService()) {
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(exception.getMessage(), is("The [azureaistudio] service does not support task type [sparse_embedding]"));
+ }
+ );
+
+ service.parseRequestConfig(
+ "id",
+ TaskType.SPARSE_EMBEDDING,
+ getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(null, null, true, null),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of(),
+ modelVerificationListener
+ );
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
+ try (var service = createService()) {
+ var config = getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(null, null, true, null),
+ getSecretSettingsMap("secret")
+ );
+ config.put("extra_key", "value");
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null);
+ serviceSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(serviceSettings, getEmbeddingsTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenDimsSetByUserExistsInEmbeddingServiceSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var config = getRequestConfigMap(
+ getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, null, null),
+ getEmbeddingsTaskSettingsMap("user"),
+ getSecretSettingsMap("secret")
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ValidationException.class));
+ assertThat(
+ exception.getMessage(),
+ containsString("[service_settings] does not allow the setting [dimensions_set_by_user]")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var taskSettings = getEmbeddingsTaskSettingsMap("user");
+ taskSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(
+ getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null),
+ taskSettings,
+ getSecretSettingsMap("secret")
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var secretSettings = getSecretSettingsMap("secret");
+ secretSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(
+ getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", null, null, null, null),
+ getEmbeddingsTaskSettingsMap("user"),
+ secretSettings
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getChatCompletionServiceSettingsMap("http://target.local", "openai", "token");
+ serviceSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(
+ serviceSettings,
+ getChatCompletionTaskSettingsMap(null, 2.0, null, null),
+ getSecretSettingsMap("secret")
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var taskSettings = getChatCompletionTaskSettingsMap(null, 2.0, null, null);
+ taskSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ taskSettings,
+ getSecretSettingsMap("secret")
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var secretSettings = getSecretSettingsMap("secret");
+ secretSettings.put("extra_key", "value");
+
+ var config = getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(null, 2.0, null, null),
+ secretSettings
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("Model configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null);
+
+ var config = getRequestConfigMap(serviceSettings, getEmbeddingsTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [databricks] is not available"));
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForEmbeddingsProvider() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "realtime", null, null, null, null);
+
+ var config = getRequestConfigMap(serviceSettings, getEmbeddingsTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("The [realtime] endpoint type with [text_embedding] task type for provider [openai] is not available")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForChatCompletionProvider() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getChatCompletionServiceSettingsMap("http://target.local", "openai", "realtime");
+
+ var config = getRequestConfigMap(
+ serviceSettings,
+ getChatCompletionTaskSettingsMap(null, null, null, null),
+ getSecretSettingsMap("secret")
+ );
+
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(
+ exception.getMessage(),
+ is("The [realtime] endpoint type with [completion] task type for provider [openai] is not available")
+ );
+ }
+ );
+
+ service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener);
+ }
+ }
+
+ public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
+ try (var service = createService()) {
+ var config = getPersistedConfigMap(
+ getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null),
+ getEmbeddingsTaskSettingsMap("user"),
+ getSecretSettingsMap("secret")
+ );
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+
+ var embeddingsModel = (AzureAiStudioEmbeddingsModel) model;
+ assertThat(embeddingsModel.getServiceSettings().target(), is("http://target.local"));
+ assertThat(embeddingsModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI));
+ assertThat(embeddingsModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+ assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024));
+ assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(true));
+ assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+ assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+ }
+ }
+
+ public void testParsePersistedConfig_CreatesAnAzureAiStudioChatCompletionModel() throws IOException {
+ try (var service = createService()) {
+ var config = getPersistedConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512),
+ getSecretSettingsMap("secret")
+ );
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.COMPLETION, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioChatCompletionModel.class));
+
+ var chatCompletionModel = (AzureAiStudioChatCompletionModel) model;
+ assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local"));
+ assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.OPENAI));
+ assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+ assertThat(chatCompletionModel.getTaskSettings().temperature(), is(1.0));
+ assertThat(chatCompletionModel.getTaskSettings().topP(), is(2.0));
+ assertThat(chatCompletionModel.getTaskSettings().doSample(), is(true));
+ assertThat(chatCompletionModel.getTaskSettings().maxNewTokens(), is(512));
+ }
+ }
+
+ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException {
+ try (var service = createService()) {
+ ActionListener modelVerificationListener = ActionListener.wrap(
+ model -> fail("Expected exception, but got model: " + model),
+ exception -> {
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+ assertThat(exception.getMessage(), is("The [azureaistudio] service does not support task type [sparse_embedding]"));
+ }
+ );
+
+ service.parseRequestConfig(
+ "id",
+ TaskType.SPARSE_EMBEDDING,
+ getRequestConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(null, null, true, null),
+ getSecretSettingsMap("secret")
+ ),
+ Set.of(),
+ modelVerificationListener
+ );
+ }
+ }
+
+ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
+ try (var service = createService()) {
+ var config = getPersistedConfigMap(
+ getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+ getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512),
+ getSecretSettingsMap("secret")
+ );
+
+ var thrownException = expectThrows(
+ ElasticsearchStatusException.class,
+ () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets())
+ );
+
+ assertThat(
+ thrownException.getMessage(),
+ is("Failed to parse stored model [id] for [azureaistudio] service, please delete and add the service again")
+ );
+ }
+ }
+
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null);
+ var taskSettings = getEmbeddingsTaskSettingsMap("user");
+ var secretSettings = getSecretSettingsMap("secret");
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+ config.config().put("extra_key", "value");
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+ }
+ }
+
+ public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null);
+ serviceSettings.put("extra_key", "value");
+
+ var taskSettings = getEmbeddingsTaskSettingsMap("user");
+ var secretSettings = getSecretSettingsMap("secret");
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+ }
+ }
+
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null);
+ var taskSettings = getEmbeddingsTaskSettingsMap("user");
+ taskSettings.put("extra_key", "value");
+
+ var secretSettings = getSecretSettingsMap("secret");
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+ }
+ }
+
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException {
+ try (var service = createService()) {
+ var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null);
+ var taskSettings = getEmbeddingsTaskSettingsMap("user");
+ var secretSettings = getSecretSettingsMap("secret");
+ secretSettings.put("extra_key", "value");
+
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets());
+
+ assertThat(model, instanceOf(AzureAiStudioEmbeddingsModel.class));
+ }
+ }
+
+    // Unknown keys in the persisted chat-completion service settings must be ignored, not rejected.
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var taskSettings = getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512);
+            var secretSettings = getSecretSettingsMap("secret");
+
+            var serviceSettings = getChatCompletionServiceSettingsMap("http://target.local", "openai", "token");
+            serviceSettings.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+            var parsedModel = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.COMPLETION,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(parsedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
+        }
+    }
+
+    // Unknown keys in the persisted chat-completion task settings must be ignored, not rejected.
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getChatCompletionServiceSettingsMap("http://target.local", "openai", "token");
+            var secretSettings = getSecretSettingsMap("secret");
+
+            var taskSettings = getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512);
+            taskSettings.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+            var parsedModel = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.COMPLETION,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(parsedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
+        }
+    }
+
+    // Unknown keys in the persisted chat-completion secret settings must be ignored, not rejected.
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException {
+        try (var service = createService()) {
+            var serviceSettings = getChatCompletionServiceSettingsMap("http://target.local", "openai", "token");
+            var taskSettings = getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512);
+
+            var secretSettings = getSecretSettingsMap("secret");
+            secretSettings.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
+            var parsedModel = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.COMPLETION,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(parsedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
+        }
+    }
+
+    // parsePersistedConfig (no secrets variant) should still yield a fully populated embeddings model.
+    public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
+        try (var service = createService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getEmbeddingsServiceSettingsMap("http://target.local", "openai", "token", 1024, true, 512, null),
+                getEmbeddingsTaskSettingsMap("user"),
+                Map.of()
+            );
+
+            var parsedModel = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(parsedModel, instanceOf(AzureAiStudioEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureAiStudioEmbeddingsModel) parsedModel;
+            var settings = embeddingsModel.getServiceSettings();
+            assertThat(settings.target(), is("http://target.local"));
+            assertThat(settings.provider(), is(AzureAiStudioProvider.OPENAI));
+            assertThat(settings.endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+            assertThat(settings.dimensions(), is(1024));
+            assertThat(settings.dimensionsSetByUser(), is(true));
+            assertThat(settings.maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+        }
+    }
+
+    // parsePersistedConfig (no secrets variant) should still yield a fully populated chat-completion model.
+    public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() throws IOException {
+        try (var service = createService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getChatCompletionServiceSettingsMap("http://target.local", "openai", "token"),
+                getChatCompletionTaskSettingsMap(1.0, 2.0, true, 512),
+                Map.of()
+            );
+
+            var parsedModel = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config());
+
+            assertThat(parsedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
+
+            var completionModel = (AzureAiStudioChatCompletionModel) parsedModel;
+            var settings = completionModel.getServiceSettings();
+            assertThat(settings.target(), is("http://target.local"));
+            assertThat(settings.provider(), is(AzureAiStudioProvider.OPENAI));
+            assertThat(settings.endpointType(), is(AzureAiStudioEndpointType.TOKEN));
+
+            var taskSettings = completionModel.getTaskSettings();
+            assertThat(taskSettings.temperature(), is(1.0));
+            assertThat(taskSettings.topP(), is(2.0));
+            assertThat(taskSettings.doSample(), is(true));
+            assertThat(taskSettings.maxNewTokens(), is(512));
+        }
+    }
+
+    // checkModelConfig for an embeddings model probes the service with a test input and
+    // returns an updated model carrying the discovered embedding size (2 here, from
+    // testEmbeddingResultJson) and DOT_PRODUCT similarity.
+    public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson));
+
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey",
+                null,
+                false,
+                null,
+                null,
+                null,
+                null
+            );
+
+            // fixed: raw PlainActionFuture — restore the type parameter; checkModelConfig reports a Model
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureAiStudioEmbeddingsModelTests.createModel(
+                        "id",
+                        getUrl(webServer),
+                        AzureAiStudioProvider.OPENAI,
+                        AzureAiStudioEndpointType.TOKEN,
+                        "apikey",
+                        2,
+                        false,
+                        null,
+                        SimilarityMeasure.DOT_PRODUCT,
+                        null,
+                        null
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            // the probe request should contain only the canned test input
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"))));
+        }
+    }
+
+    // When the user fixed dimensions (3) but the service returns a different embedding
+    // size (2, from testEmbeddingResultJson), checkModelConfig must fail with a clear error.
+    public void testCheckModelConfig_ForEmbeddingsModel_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson));
+
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey",
+                3,
+                true,
+                null,
+                null,
+                null,
+                null
+            );
+
+            // fixed: raw PlainActionFuture — restore the type parameter; checkModelConfig reports a Model
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "The retrieved embeddings size [2] does not match the size specified in the settings [3]. "
+                        + "Please recreate the [id] configuration with the correct dimensions"
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            // user-set dimensions are forwarded in the probe request
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "dimensions", 3)));
+        }
+    }
+
+    // checkModelConfig for a chat-completion model should succeed and fill in the
+    // default max-new-tokens value on the returned model.
+    public void testCheckModelConfig_WorksForChatCompletionsModel() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson));
+
+            var model = AzureAiStudioChatCompletionModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey",
+                null,
+                null,
+                null,
+                null,
+                null
+            );
+
+            // fixed: raw PlainActionFuture — restore the type parameter; checkModelConfig reports a Model
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureAiStudioChatCompletionModelTests.createModel(
+                        "id",
+                        getUrl(webServer),
+                        AzureAiStudioProvider.OPENAI,
+                        AzureAiStudioEndpointType.TOKEN,
+                        "apikey",
+                        null,
+                        null,
+                        null,
+                        AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS,
+                        null
+                    )
+                )
+            );
+        }
+    }
+
+ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOException {
+ var sender = mock(Sender.class);
+
+ var factory = mock(HttpRequestSender.Factory.class);
+ when(factory.createSender(anyString())).thenReturn(sender);
+
+ var mockModel = getInvalidModel("model_id", "service_name");
+
+ try (var service = new AzureAiStudioService(factory, createWithEmptySettings(threadPool))) {
+ PlainActionFuture listener = new PlainActionFuture<>();
+ service.infer(
+ mockModel,
+ null,
+ List.of(""),
+ new HashMap<>(),
+ InputType.INGEST,
+ InferenceAction.Request.DEFAULT_TIMEOUT,
+ listener
+ );
+
+ var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+ assertThat(
+ thrownException.getMessage(),
+ is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
+ );
+
+ verify(factory, times(1)).createSender(anyString());
+ verify(sender, times(1)).start();
+ }
+
+ verify(sender, times(1)).close();
+ verifyNoMoreInteractions(factory);
+ verifyNoMoreInteractions(sender);
+ }
+
+    // chunkedInfer on an embeddings model should call through to infer and convert the
+    // float response into chunked text-embedding results, preserving the input text.
+    public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey",
+                null,
+                false,
+                null,
+                null,
+                "user",
+                null
+            );
+            // fixed: garbled generic "PlainActionFuture>" — chunkedInfer reports a list of chunked results
+            PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+            service.chunkedInfer(
+                model,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                new ChunkingOptions(null, null),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT).get(0);
+            assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class));
+
+            // the single chunk pairs the input text with the (widened-to-double) embedding values
+            assertThat(
+                asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result),
+                Matchers.is(
+                    Map.of(
+                        ChunkedTextEmbeddingResults.FIELD_NAME,
+                        List.of(
+                            Map.of(
+                                ChunkedNlpInferenceResults.TEXT,
+                                "abc",
+                                ChunkedNlpInferenceResults.INFERENCE,
+                                List.of((double) 0.0123f, (double) -0.0123f)
+                            )
+                        )
+                    )
+                )
+            );
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+            assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), Matchers.is(2));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("user"), Matchers.is("user"));
+        }
+    }
+
+    // Azure AI Studio does not support query-style inference; infer must reject a
+    // non-null query synchronously with UnsupportedOperationException.
+    public void testInfer_ThrowsWhenQueryIsPresent() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson));
+
+            var model = AzureAiStudioChatCompletionModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey"
+            );
+
+            // fixed: raw PlainActionFuture — restore the type parameter; infer reports InferenceServiceResults
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            UnsupportedOperationException exception = expectThrows(
+                UnsupportedOperationException.class,
+                () -> service.infer(
+                    model,
+                    "should throw",
+                    List.of("abc"),
+                    new HashMap<>(),
+                    InputType.INGEST,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+
+            assertThat(exception.getMessage(), is("Azure AI Studio service does not support inference with query input"));
+        }
+    }
+
+    // Happy-path infer with a chat-completion model: the canned response should be
+    // parsed into a single ChatCompletionResults entry.
+    public void testInfer_WithChatCompletionModel() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson));
+
+            var model = AzureAiStudioChatCompletionModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey"
+            );
+
+            // fixed: raw PlainActionFuture — restore the type parameter; infer reports InferenceServiceResults
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                null,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(result, CoreMatchers.instanceOf(ChatCompletionResults.class));
+
+            var completionResults = (ChatCompletionResults) result;
+            assertThat(completionResults.getResults().size(), is(1));
+            assertThat(completionResults.getResults().get(0).content(), is("test completion content"));
+        }
+    }
+
+    // A 401 response must surface as an ElasticsearchException whose message includes
+    // both the authentication-failure note and the provider's error message.
+    public void testInfer_UnauthorisedResponse() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "error": {
+                        "message": "Incorrect API key provided:",
+                        "type": "invalid_request_error",
+                        "param": null,
+                        "code": "invalid_api_key"
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
+
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                "id",
+                getUrl(webServer),
+                AzureAiStudioProvider.OPENAI,
+                AzureAiStudioEndpointType.TOKEN,
+                "apikey",
+                null,
+                false,
+                null,
+                null,
+                "user",
+                null
+            );
+            // fixed: raw PlainActionFuture — restore the type parameter; infer reports InferenceServiceResults
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                null,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(error.getMessage(), containsString("Received an authentication error status code for request"));
+            assertThat(error.getMessage(), containsString("Error message: [Incorrect API key provided:]"));
+            assertThat(webServer.requests(), hasSize(1));
+        }
+    }
+
+ // ----------------------------------------------------------------
+
+    // Builds a service backed by a mocked sender factory; callers are responsible for closing it.
+    private AzureAiStudioService createService() {
+        var mockedSenderFactory = mock(HttpRequestSender.Factory.class);
+        return new AzureAiStudioService(mockedSenderFactory, createWithEmptySettings(threadPool));
+    }
+
+    /**
+     * Builds a request-style config map: service settings and secret settings are merged
+     * under {@code service_settings}, task settings go under {@code task_settings}.
+     * Fixed: raw {@code Map} types (lost in the patch) restored to {@code Map<String, Object>}.
+     */
+    private Map<String, Object> getRequestConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> secretSettings
+    ) {
+        var builtServiceSettings = new HashMap<String, Object>();
+        builtServiceSettings.putAll(serviceSettings);
+        builtServiceSettings.putAll(secretSettings);
+
+        return new HashMap<>(
+            Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
+        );
+    }
+
+    // Pair of persisted config and secrets maps, as handed to parsePersistedConfig*.
+    // Fixed: raw Map components restored to Map<String, Object>.
+    // NOTE(review): "Peristed" is a typo for "Persisted"; name kept because the
+    // getPersistedConfigMap helper below returns this type.
+    private record PeristedConfigRecord(Map<String, Object> config, Map<String, Object> secrets) {}
+
+ private PeristedConfigRecord getPersistedConfigMap(
+ Map serviceSettings,
+ Map taskSettings,
+ Map