Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ public <T> CompletableFuture<HttpResponse<T>> send(HttpRequest request, BodyHand
* @return a {@link CompletableFuture} of the HTTP response
*/
public <T> CompletableFuture<HttpResponse<T>> send(HttpRequest request, BodyHandler<T> handler, Executor executor) {
return new InterceptorChain(delegate, interceptors).proceed(request, handler,
requireNonNullElse(executor, executor()));
return new InterceptorChain(delegate, interceptors).proceed(request, handler, requireNonNullElse(executor, executor()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
public abstract class WatsonxParameters {
protected final String projectId;
protected final String spaceId;
protected final String transactionId;

public WatsonxParameters(Builder<?> builder) {
projectId = builder.projectId;
spaceId = builder.spaceId;
transactionId = builder.transactionId;
}

/**
Expand All @@ -34,6 +36,15 @@ public String getSpaceId() {
return spaceId;
}

/**
* Returns the transaction id.
*
* @return transaction id value
*/
public String getTransactionId() {
return transactionId;
}

/**
* Abstract builder class for constructing {@link WatsonxParameters} instances.
*
Expand All @@ -43,6 +54,7 @@ public String getSpaceId() {
public static abstract class Builder<T extends Builder<T>> {
private String projectId;
private String spaceId;
private String transactionId;

/**
* Sets the project id.
Expand All @@ -63,6 +75,16 @@ public T spaceId(String spaceId) {
this.spaceId = spaceId;
return (T) this;
}

/**
* Sets the transaction id for request tracking.
*
* @param transactionId the transaction id.
*/
public T transactionId(String transactionId) {
this.transactionId = transactionId;
return (T) this;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.ibm.watsonx.ai.core.http.interceptors.RetryInterceptor;
import com.ibm.watsonx.ai.deployment.DeploymentService;
import com.ibm.watsonx.ai.embedding.EmbeddingService;
import com.ibm.watsonx.ai.foundationmodel.FoundationModel;
import com.ibm.watsonx.ai.foundationmodel.FoundationModelService;
import com.ibm.watsonx.ai.rerank.RerankService;
import com.ibm.watsonx.ai.textextraction.TextExtractionService;
Expand Down Expand Up @@ -49,6 +48,7 @@ public abstract class WatsonxService {
public static final String ML_API_PATH = "/ml/v1";
public static final String ML_API_TEXT_PATH = ML_API_PATH.concat("/text");
public static final String API_VERSION = "2025-04-23";
public static final String TRANSACTION_ID_HEADER = "X-Global-Transaction-Id";

protected final URI url;
protected final String version;
Expand All @@ -57,7 +57,6 @@ public abstract class WatsonxService {
protected final SyncHttpClient syncHttpClient;
protected final HttpClient httpClient;
protected final AsyncHttpClient asyncHttpClient;
protected final AuthenticationProvider authenticationProvider;

protected WatsonxService(Builder<?> builder) {
url = requireNonNull(builder.url, "The url must be provided");
Expand All @@ -76,12 +75,9 @@ protected WatsonxService(Builder<?> builder) {
asyncHttpClientBuilder.interceptor(retryInterceptor);

if (nonNull(builder.authenticationProvider)) {
authenticationProvider = builder.authenticationProvider;
var bearerInterceptor = new BearerInterceptor(authenticationProvider);
var bearerInterceptor = new BearerInterceptor(builder.authenticationProvider);
syncHttpClientBuilder.interceptor(bearerInterceptor);
asyncHttpClientBuilder.interceptor(bearerInterceptor);
} else {
authenticationProvider = null;
}

if (logRequests || logResponses) {
Expand Down Expand Up @@ -190,6 +186,15 @@ public T authenticationProvider(AuthenticationProvider authenticationProvider) {
this.authenticationProvider = authenticationProvider;
return (T) this;
}

/**
* Returns the authentication provider.
*
* @return the configured {@link AuthenticationProvider}, or {@code null} if none has been set.
*/
public AuthenticationProvider getAuthenticationProvider() {
return authenticationProvider;
}
}

/**
Expand Down Expand Up @@ -257,17 +262,6 @@ protected ModelService(Builder<?> builder) {
);
}

/**
* Retrieves model details.
*
* @return Details of the the model.
*/
public FoundationModel getModelDetails() {
return foundationModelService.getModelDetails(modelId)
.orElseThrow(() -> new RuntimeException("The model with id \"%s\" doesn't exist".formatted(modelId)));
}


@SuppressWarnings("unchecked")
protected static abstract class Builder<T extends Builder<T>> extends ProjectService.Builder<T> {
private String modelId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;
import static java.util.Optional.ofNullable;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpRequest;
Expand Down Expand Up @@ -52,7 +53,7 @@ public final class ChatService extends ModelService implements ChatProvider {

protected ChatService(Builder builder) {
super(builder);
requireNonNull(super.authenticationProvider, "authenticationProvider cannot be null");
requireNonNull(builder.getAuthenticationProvider(), "authenticationProvider cannot be null");
}

/**
Expand All @@ -73,8 +74,8 @@ public ChatResponse chat(List<ChatMessage> messages, List<Tool> tools, ChatParam
parameters = requireNonNullElse(parameters, ChatParameters.builder().build());

var modelId = requireNonNullElse(parameters.getModelId(), this.modelId);
var projectId = nonNull(parameters.getProjectId()) ? parameters.getProjectId() : this.projectId;
var spaceId = nonNull(parameters.getSpaceId()) ? parameters.getSpaceId() : this.spaceId;
var projectId = ofNullable(parameters.getProjectId()).orElse(this.projectId);
var spaceId = ofNullable(parameters.getSpaceId()).orElse(this.spaceId);
var timeout = requireNonNullElse(parameters.getTimeLimit(), this.timeout.toMillis());

var chatRequest = ChatRequest.builder()
Expand All @@ -91,12 +92,14 @@ public ChatResponse chat(List<ChatMessage> messages, List<Tool> tools, ChatParam
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.POST(BodyPublishers.ofString(toJson(chatRequest)))
.timeout(Duration.ofMillis(timeout))
.build();
.timeout(Duration.ofMillis(timeout));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

try {

var httpReponse = syncHttpClient.send(httpRequest, BodyHandlers.ofString());
var httpReponse = syncHttpClient.send(httpRequest.build(), BodyHandlers.ofString());
return fromJson(httpReponse.body(), ChatResponse.class);

} catch (IOException | InterruptedException e) {
Expand All @@ -123,8 +126,8 @@ public CompletableFuture<Void> chatStreaming(List<ChatMessage> messages, List<To
parameters = requireNonNullElse(parameters, ChatParameters.builder().build());

var modelId = requireNonNullElse(parameters.getModelId(), this.modelId);
var projectId = nonNull(parameters.getProjectId()) ? parameters.getProjectId() : this.projectId;
var spaceId = nonNull(parameters.getSpaceId()) ? parameters.getSpaceId() : this.spaceId;
var projectId = ofNullable(parameters.getProjectId()).orElse(this.projectId);
var spaceId = ofNullable(parameters.getSpaceId()).orElse(this.spaceId);
var timeout = requireNonNullElse(parameters.getTimeLimit(), this.timeout.toMillis());

var chatRequest = ChatRequest.builder()
Expand All @@ -141,12 +144,14 @@ public CompletableFuture<Void> chatStreaming(List<ChatMessage> messages, List<To
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.POST(BodyPublishers.ofString(toJson(chatRequest)))
.timeout(Duration.ofMillis(timeout))
.build();
.timeout(Duration.ofMillis(timeout));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

var subscriber = subscriber(chatRequest.getToolChoiceOption(), handler);
return asyncHttpClient
.send(httpRequest, responseInfo -> logResponses
.send(httpRequest.build(), responseInfo -> logResponses
? BodySubscribers.fromLineSubscriber(new SseEventLogger(subscriber, responseInfo.statusCode(), responseInfo.headers()))
: BodySubscribers.fromLineSubscriber(subscriber)
).thenApply(response -> null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;
import static java.util.Optional.ofNullable;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpRequest;
Expand Down Expand Up @@ -105,11 +106,14 @@ public DeploymentResource findById(FindByIdParameters parameters) {
.newBuilder(URI.create(url.toString() + "/ml/v4/deployments/%s%s".formatted(deploymentId, queryParameters.toString())))
.header("Content-Type", "application/json")
.timeout(Duration.ofMillis(timeout.toMillis()))
.GET().build();
.GET();

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

try {

var httpReponse = syncHttpClient.send(httpRequest, BodyHandlers.ofString());
var httpReponse = syncHttpClient.send(httpRequest.build(), BodyHandlers.ofString());
return fromJson(httpReponse.body(), DeploymentResource.class);

} catch (IOException | InterruptedException e) {
Expand All @@ -129,19 +133,21 @@ public TextGenerationResponse generate(String input, Moderation moderation, Text
logIgnoredParameters(parameters.getModelId(), parameters.getProjectId(), parameters.getSpaceId());

var textGenerationRequest =
new TextGenerationRequest(null, null, null, input, parameters, moderation);
new TextGenerationRequest(null, null, null, input, parameters.toSanitized(), moderation);

var httpRequest = HttpRequest
.newBuilder(URI.create(url.toString() + "%s/deployments/%s/text/generation?version=%s".formatted(ML_API_PATH, deployment, version)))
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.timeout(Duration.ofMillis(timeout))
.POST(BodyPublishers.ofString(toJson(textGenerationRequest)))
.build();
.POST(BodyPublishers.ofString(toJson(textGenerationRequest)));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

try {

var httpReponse = syncHttpClient.send(httpRequest, BodyHandlers.ofString());
var httpReponse = syncHttpClient.send(httpRequest.build(), BodyHandlers.ofString());
return fromJson(httpReponse.body(), TextGenerationResponse.class);

} catch (IOException | InterruptedException e) {
Expand All @@ -162,20 +168,22 @@ public CompletableFuture<Void> generateStreaming(String input, TextGenerationPar
logIgnoredParameters(parameters.getModelId(), parameters.getProjectId(), parameters.getSpaceId());

var textGenerationRequest =
new TextGenerationRequest(null, null, null, input, parameters, null);
new TextGenerationRequest(null, null, null, input, parameters.toSanitized(), null);

var httpRequest = HttpRequest
.newBuilder(
URI.create(url.toString() + "%s/deployments/%s/text/generation_stream?version=%s".formatted(ML_API_PATH, deployment, version)))
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.timeout(Duration.ofMillis(timeout))
.POST(BodyPublishers.ofString(toJson(textGenerationRequest)))
.build();
.POST(BodyPublishers.ofString(toJson(textGenerationRequest)));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

var subscriber = subscriber(handler);
return asyncHttpClient
.send(httpRequest,
.send(httpRequest.build(),
responseInfo -> logResponses
? BodySubscribers.fromLineSubscriber(new SseEventLogger(subscriber, responseInfo.statusCode(), responseInfo.headers()))
: BodySubscribers.fromLineSubscriber(subscriber)
Expand Down Expand Up @@ -204,12 +212,14 @@ public ChatResponse chat(List<ChatMessage> messages, List<Tool> tools, ChatParam
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.POST(BodyPublishers.ofString(toJson(chatRequest)))
.timeout(Duration.ofMillis(timeout))
.build();
.timeout(Duration.ofMillis(timeout));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

try {

var httpReponse = syncHttpClient.send(httpRequest, BodyHandlers.ofString());
var httpReponse = syncHttpClient.send(httpRequest.build(), BodyHandlers.ofString());
return fromJson(httpReponse.body(), ChatResponse.class);

} catch (IOException | InterruptedException e) {
Expand Down Expand Up @@ -241,12 +251,14 @@ public CompletableFuture<Void> chatStreaming(List<ChatMessage> messages, List<To
.header("Content-Type", "application/json")
.header("Accept", "text/event-stream")
.POST(BodyPublishers.ofString(toJson(chatRequest)))
.timeout(Duration.ofMillis(timeout))
.build();
.timeout(Duration.ofMillis(timeout));

if (nonNull(parameters.getTransactionId()))
httpRequest.header(TRANSACTION_ID_HEADER, parameters.getTransactionId());

var subscriber = subscriber(chatRequest.getToolChoiceOption(), handler);
return asyncHttpClient
.send(httpRequest, responseInfo -> logResponses
.send(httpRequest.build(), responseInfo -> logResponses
? BodySubscribers.fromLineSubscriber(new SseEventLogger(subscriber, responseInfo.statusCode(), responseInfo.headers()))
: BodySubscribers.fromLineSubscriber(subscriber)
).thenApply(response -> null);
Expand All @@ -260,11 +272,13 @@ public ForecastResponse forecast(InputSchema inputSchema, ForecastData data, Tim

Parameters requestParameters = null;
Map<String, List<Object>> futureData = null;
String transactionId = null;

if (nonNull(parameters)) {
logIgnoredParameters(parameters.getModelId(), parameters.getProjectId(), parameters.getSpaceId());
requestParameters = parameters.toParameters();
futureData = nonNull(parameters.getFutureData()) ? parameters.getFutureData().asMap() : null;
futureData = ofNullable(parameters.getFutureData()).map(p -> p.asMap()).orElse(null);
transactionId = parameters.getTransactionId();
}

var forecastRequest = new ForecastRequest(null, null, null, data.asMap(), inputSchema, futureData, requestParameters);
Expand All @@ -274,12 +288,14 @@ public ForecastResponse forecast(InputSchema inputSchema, ForecastData data, Tim
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.timeout(timeout)
.POST(BodyPublishers.ofString(toJson(forecastRequest)))
.build();
.POST(BodyPublishers.ofString(toJson(forecastRequest)));

if (nonNull(transactionId))
httpRequest.header(TRANSACTION_ID_HEADER, transactionId);

try {

var httpReponse = syncHttpClient.send(httpRequest, BodyHandlers.ofString());
var httpReponse = syncHttpClient.send(httpRequest.build(), BodyHandlers.ofString());
return fromJson(httpReponse.body(), ForecastResponse.class);

} catch (IOException | InterruptedException e) {
Expand Down
Loading