Fix VertexAI parameters
nicoloboschi committed Sep 27, 2023
1 parent b4c83d1 commit 8f636df
Showing 20 changed files with 114 additions and 84 deletions.
@@ -34,5 +34,6 @@ pipeline:
completion-field: "value.answer"
# we are also logging the prompt we sent to the LLM
log-field: "value.prompt"
max-tokens: 20
prompt:
- "{{% value.question}}"
@@ -231,10 +231,11 @@ public CompletableFuture<String> getTextCompletions(

// this is the default behavior, as it is async
// it works even if the streamingChunksConsumer is null
final String model = (String) options.get("model");
if (completionsOptions.isStream()) {
CompletableFuture<?> finished = new CompletableFuture<>();
Flux<com.azure.ai.openai.models.Completions> flux =
client.getCompletionsStream((String) options.get("model"), completionsOptions);
client.getCompletionsStream(model, completionsOptions);

TextCompletionsConsumer textCompletionsConsumer =
new TextCompletionsConsumer(
@@ -253,8 +254,7 @@ public CompletableFuture<String> getTextCompletions(
return finished.thenApply(___ -> textCompletionsConsumer.totalAnswer.toString());
} else {
com.azure.ai.openai.models.Completions completions =
client.getCompletions((String) options.get("model"), completionsOptions)
.block();
client.getCompletions(model, completionsOptions).block();
final String text = completions.getChoices().get(0).getText();
return CompletableFuture.completedFuture(text);
}
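The only change in this hunk is hoisting the repeated `options.get("model")` cast into a local `model` variable used by both branches. For readers unfamiliar with the surrounding stream-versus-block pattern, here is a minimal sketch assuming only reactor-core on the classpath; the types and names are stand-ins, not LangStream's actual classes:

```java
import java.util.concurrent.CompletableFuture;
import reactor.core.publisher.Flux;

public class StreamOrBlockSketch {
    // Streaming: accumulate chunks and complete a future when the flux ends.
    // Blocking: simply wait for the last (complete) element.
    static CompletableFuture<String> complete(Flux<String> chunks, boolean stream) {
        if (stream) {
            StringBuilder total = new StringBuilder();
            CompletableFuture<String> finished = new CompletableFuture<>();
            chunks.subscribe(
                    total::append,                    // consume each chunk as it arrives
                    finished::completeExceptionally,  // propagate errors
                    () -> finished.complete(total.toString()));
            return finished;
        } else {
            return CompletableFuture.completedFuture(chunks.blockLast());
        }
    }
}
```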
@@ -16,6 +16,7 @@
package ai.langstream.ai.agents.services.impl;

import ai.langstream.ai.agents.services.ServiceProviderProvider;
import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.oss.streaming.ai.completions.ChatChoice;
import com.datastax.oss.streaming.ai.completions.ChatCompletions;
import com.datastax.oss.streaming.ai.completions.ChatMessage;
@@ -315,19 +316,33 @@ private void appendRequestParameters(
Map<String, Object> additionalConfiguration, CompletionRequest request) {
request.parameters = new HashMap<>();

        if (additionalConfiguration.containsKey("temperature")) {
            request.parameters.put(
                    "temperature", additionalConfiguration.get("temperature"));
        }
        if (additionalConfiguration.containsKey("max-tokens")) {
            request.parameters.put(
                    "maxOutputTokens", additionalConfiguration.get("max-tokens"));
        }
        if (additionalConfiguration.containsKey("topP")) {
            request.parameters.put("topP", additionalConfiguration.get("topP"));
        }
        if (additionalConfiguration.containsKey("topK")) {
            request.parameters.put("topK", additionalConfiguration.get("topK"));
        }
        appendDoubleValue("temperature", "temperature", additionalConfiguration, request);
        appendIntValue("max-tokens", "maxOutputTokens", additionalConfiguration, request);
        appendDoubleValue("topP", "topP", additionalConfiguration, request);
        appendIntValue("topK", "topK", additionalConfiguration, request);
    }

    private void appendDoubleValue(
            String key,
            String toKey,
            Map<String, Object> additionalConfiguration,
            CompletionRequest request) {
        final Double typedValue =
                ConfigurationUtils.getDouble(key, null, additionalConfiguration);
        if (typedValue != null) {
            request.parameters.put(toKey, typedValue);
        }
    }

    private void appendIntValue(
            String key,
            String toKey,
            Map<String, Object> additionalConfiguration,
            CompletionRequest request) {
        final Integer typedValue =
                ConfigurationUtils.getInteger(key, null, additionalConfiguration);
        if (typedValue != null) {
            request.parameters.put(toKey, typedValue);
        }
    }
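This hunk is the core of the fix: the old code forwarded raw objects from the configuration map, so a value parsed from YAML or an env template (for example the string "0.7") could reach Vertex AI with the wrong JSON type. A hedged sketch of the typed coercion that `ConfigurationUtils.getDouble`/`getInteger` presumably perform — the real implementation may differ:

```java
import java.util.Map;

final class CoercionSketch {
    // Assumed behavior: accept a Number or a numeric String, return the
    // default when the key is absent, so unset parameters stay unset.
    static Double getDouble(String key, Double def, Map<String, Object> cfg) {
        Object v = cfg.get(key);
        if (v == null) return def;
        if (v instanceof Number n) return n.doubleValue();
        return Double.parseDouble(v.toString()); // e.g. "0.7" from YAML/env
    }

    static Integer getInteger(String key, Integer def, Map<String, Object> cfg) {
        Object v = cfg.get(key);
        if (v == null) return def;
        if (v instanceof Number n) return n.intValue();
        return Integer.parseInt(v.toString());
    }
}
```

Passing `null` as the default keeps absent parameters out of the request entirely instead of sending a provider-specific default.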

@@ -17,7 +17,6 @@

import static com.datastax.oss.streaming.ai.util.TransformFunctionUtil.convertToMap;

import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.datastax.oss.streaming.ai.completions.ChatChoice;
import com.datastax.oss.streaming.ai.completions.ChatCompletions;
import com.datastax.oss.streaming.ai.completions.ChatMessage;
@@ -121,21 +120,8 @@ public CompletableFuture<?> processAsync(TransformContext transformContext) {
.execute(jsonRecord)))
.collect(Collectors.toList());

ChatCompletionsOptions chatCompletionsOptions =
new ChatCompletionsOptions(List.of())
.setMaxTokens(config.getMaxTokens())
.setTemperature(config.getTemperature())
.setTopP(config.getTopP())
.setLogitBias(config.getLogitBias())
.setStream(config.isStream())
.setUser(config.getUser())
.setStop(config.getStop())
.setPresencePenalty(config.getPresencePenalty())
.setFrequencyPenalty(config.getFrequencyPenalty());
Map<String, Object> options = convertToMap(chatCompletionsOptions);
options.put("model", config.getModel());
Map<String, Object> options = convertToMap(config);
options.put("min-chunks-per-message", config.getMinChunksPerMessage());
options.remove("messages");

CompletableFuture<ChatCompletions> chatCompletionsHandle =
completionsService.getChatCompletions(
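The removed block built an Azure-specific `ChatCompletionsOptions` even when the configured service was Vertex AI, which is how provider-neutral settings were getting lost; converting the agent's own config object preserves keys such as `max-tokens`. A sketch of what a `convertToMap` helper typically does, assuming Jackson — the actual `TransformFunctionUtil.convertToMap` may differ:

```java
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;

final class ConvertToMapSketch {
    private static final ObjectMapper MAPPER = new ObjectMapper();

    static Map<String, Object> convertToMap(Object config) {
        // Serializes the bean's properties into a String -> Object map, so no
        // provider-specific options class is needed before calling the service.
        return MAPPER.convertValue(config, new TypeReference<Map<String, Object>>() {});
    }
}
```

The follow-up `options.remove("messages")` then strips the prompt templates, which are passed to the service separately.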
@@ -17,7 +17,6 @@

import static com.datastax.oss.streaming.ai.util.TransformFunctionUtil.convertToMap;

import com.azure.ai.openai.models.CompletionsOptions;
import com.datastax.oss.streaming.ai.completions.Chunk;
import com.datastax.oss.streaming.ai.completions.CompletionsService;
import com.datastax.oss.streaming.ai.model.JsonRecord;
@@ -87,21 +86,8 @@ public CompletableFuture<?> processAsync(TransformContext transformContext) {
.map(p -> messageTemplates.get(p).execute(jsonRecord))
.collect(Collectors.toList());

CompletionsOptions completionsOptions =
new CompletionsOptions(List.of())
.setMaxTokens(config.getMaxTokens())
.setTemperature(config.getTemperature())
.setTopP(config.getTopP())
.setLogitBias(config.getLogitBias())
.setStream(config.isStream())
.setUser(config.getUser())
.setStop(config.getStop())
.setPresencePenalty(config.getPresencePenalty())
.setFrequencyPenalty(config.getFrequencyPenalty());
Map<String, Object> options = convertToMap(completionsOptions);
options.put("model", config.getModel());
final Map<String, Object> options = convertToMap(config);
options.put("min-chunks-per-message", config.getMinChunksPerMessage());
options.remove("messages");

CompletableFuture<String> chatCompletionsHandle =
completionsService.getTextCompletions(
@@ -30,7 +30,7 @@ EmbeddingsService getEmbeddingsService(Map<String, Object> additionalConfigurati

void close();

public static class NoopServiceProvider implements ServiceProvider {
class NoopServiceProvider implements ServiceProvider {
@Override
public CompletionsService getCompletionsService(
Map<String, Object> additionalConfiguration) {
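This modifier change is cosmetic: interface members are implicitly public, and a class nested in an interface is implicitly static, so the keywords were redundant. A minimal illustration:

```java
interface ServiceProviderSketch {
    // No modifiers needed: this nested class is already public and static.
    class Noop implements ServiceProviderSketch {}
}
```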
@@ -403,8 +403,6 @@ private void generateAIProvidersConfiguration(
}
} else {
for (Resource resource : applicationInstance.getResources().values()) {
Map<String, Object> configurationCopy =
clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
final String configKey =
switch (resource.type()) {
case SERVICE_VERTEX -> "vertex";
@@ -413,6 +411,8 @@
default -> null;
};
if (configKey != null) {
Map<String, Object> configurationCopy =
clusterRuntime.getResourceImplementation(resource, pluginsRegistry);
configuration.put(configKey, configurationCopy);
}
}
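The reordering defers the resource-configuration copy until the type is known to map to a provider key, so unrelated resources no longer trigger it. A sketch of the resulting shape, with hypothetical string types standing in for the real resource enum:

```java
import java.util.HashMap;
import java.util.Map;

final class ProviderConfigSketch {
    static void add(Map<String, Object> configuration,
                    String resourceType, Map<String, Object> resourceConfig) {
        final String configKey =
                switch (resourceType) {
                    case "vertex-configuration" -> "vertex";
                    case "open-ai-configuration" -> "open-ai";
                    default -> null;
                };
        if (configKey != null) {
            // Copy only when the type is one we actually emit, as in the commit.
            configuration.put(configKey, new HashMap<>(resourceConfig));
        }
    }
}
```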
@@ -15,6 +15,8 @@
*/
package ai.langstream.tests;

import static ai.langstream.tests.TextCompletionsIT.getAppEnvForAIServiceProvider;

import ai.langstream.tests.util.BaseEndToEndTest;
import ai.langstream.tests.util.ConsumeGatewayMessage;
import java.util.List;
@@ -36,13 +38,10 @@ public class ChatCompletionsIT extends BaseEndToEndTest {

@BeforeAll
public static void checkCredentials() {
appEnv =
appEnv = getAppEnvForAIServiceProvider();
appEnv.putAll(
getAppEnvMapFromSystem(
List.of(
"OPEN_AI_ACCESS_KEY",
"OPEN_AI_URL",
"OPEN_AI_CHAT_COMPLETIONS_MODEL",
"OPEN_AI_PROVIDER"));
List.of("CHAT_COMPLETIONS_MODEL", "CHAT_COMPLETIONS_SERVICE")));
}

@Test
@@ -36,24 +36,25 @@ public class TextCompletionsIT extends BaseEndToEndTest {

@BeforeAll
public static void checkCredentials() {
appEnv = getAppEnvForAIServiceProvider();
appEnv.putAll(
getAppEnvMapFromSystem(
List.of("TEXT_COMPLETIONS_MODEL", "TEXT_COMPLETIONS_SERVICE")));
}

public static Map<String, String> getAppEnvForAIServiceProvider() {
try {
appEnv =
getAppEnvMapFromSystem(
List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
return getAppEnvMapFromSystem(
List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
} catch (Throwable t) {
// no openai - try vertex
appEnv =
getAppEnvMapFromSystem(
List.of(
"VERTEX_AI_URL",
"VERTEX_AI_TOKEN",
"VERTEX_AI_REGION",
"VERTEX_AI_PROJECT"));
return getAppEnvMapFromSystem(
List.of(
"VERTEX_AI_URL",
"VERTEX_AI_TOKEN",
"VERTEX_AI_REGION",
"VERTEX_AI_PROJECT"));
}

appEnv.putAll(
getAppEnvMapFromSystem(
List.of("TEXT_COMPLETIONS_MODEL", "TEXT_COMPLETIONS_SERVICE")));
}
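The extracted `getAppEnvForAIServiceProvider` is what lets these tests run against either provider: it tries the OpenAI credentials first and falls back to Vertex AI if any are missing. A self-contained sketch, assuming `getAppEnvMapFromSystem` fails fast on a missing variable — the helper below is an assumption, not the project's code:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class AiServiceEnvSketch {
    static Map<String, String> getAppEnvForAIServiceProvider() {
        try {
            return getAppEnvMapFromSystem(
                    List.of("OPEN_AI_ACCESS_KEY", "OPEN_AI_URL", "OPEN_AI_PROVIDER"));
        } catch (Throwable t) {
            // no OpenAI credentials in the environment - try Vertex
            return getAppEnvMapFromSystem(
                    List.of("VERTEX_AI_URL", "VERTEX_AI_TOKEN",
                            "VERTEX_AI_REGION", "VERTEX_AI_PROJECT"));
        }
    }

    // Assumed behavior: throw when any required variable is missing.
    static Map<String, String> getAppEnvMapFromSystem(List<String> names) {
        return names.stream().collect(Collectors.toMap(n -> n, n -> {
            String v = System.getenv(n);
            if (v == null) throw new IllegalStateException("missing env: " + n);
            return v;
        }));
    }
}
```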

@Test
Expand All @@ -80,6 +81,6 @@ public void test() throws Exception {
.formatted(sessionId)
.split(" "));
log.info("Output: {}", message);
Assertions.assertTrue(message.getAnswerFromChatCompletionsValue().contains("Bounjour"));
Assertions.assertTrue(message.getAnswerFromChatCompletionsValue().contains("Bonjour"));
}
}
@@ -15,6 +15,8 @@
*/
package ai.langstream.tests;

import static ai.langstream.tests.TextCompletionsIT.getAppEnvForAIServiceProvider;

import ai.langstream.tests.util.BaseEndToEndTest;
import ai.langstream.tests.util.ConsumeGatewayMessage;
import java.util.List;
@@ -36,20 +38,20 @@ public class WebCrawlerToVectorIT extends BaseEndToEndTest {

@BeforeAll
public static void checkCredentials() {
appEnv =
appEnv = getAppEnvForAIServiceProvider();
appEnv.putAll(
getAppEnvMapFromSystem(
List.of("CHAT_COMPLETIONS_MODEL", "CHAT_COMPLETIONS_SERVICE")));
appEnv.putAll(getAppEnvMapFromSystem(List.of("EMBEDDINGS_MODEL", "EMBEDDINGS_SERVICE")));

appEnv.putAll(
getAppEnvMapFromSystem(
List.of(
"OPEN_AI_ACCESS_KEY",
"OPEN_AI_URL",
"OPEN_AI_EMBEDDINGS_MODEL",
"OPEN_AI_CHAT_COMPLETIONS_MODEL",
"OPEN_AI_PROVIDER",
"ASTRA_TOKEN",
"ASTRA_CLIENT_ID",
"ASTRA_SECRET",
"ASTRA_SECURE_BUNDLE",
"ASTRA_ENVIRONMENT",
"ASTRA_DATABASE"));
"ASTRA_DATABASE")));
}

@Test
@@ -526,6 +526,9 @@ private static KubeCluster getKubeCluster() {
public void setupSingleTest() {
// cleanup previous runs
cleanupAllEndToEndTestsNamespaces();
codeStorageProvider.cleanup();
streamingClusterProvider.cleanup();

namespace = "ls-test-" + UUID.randomUUID().toString().substring(0, 8);

client.resource(
@@ -1132,7 +1135,6 @@ private static void deployLocalApplicationAndAwaitReady(
.pollInterval(5, TimeUnit.SECONDS)
.untilAsserted(
() -> {
log.info("waiting new executors to be ready");
final List<Pod> pods =
client.pods()
.inNamespace(tenantNamespace)
@@ -1144,6 +1146,10 @@
"langstream-runtime"))
.list()
.getItems();
log.info(
"waiting new executors to be ready, found {}, expected {}",
pods.size(),
expectedNumExecutors);
if (pods.size() != expectedNumExecutors) {
fail("too many pods: " + pods.size());
}
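The logging change moves the found/expected pod counts inside the polled assertion, so every retry reports progress instead of a single static message. A sketch of that wait loop, assuming the Awaitility library (suggested by the `pollInterval`/`untilAsserted` calls) and a fake pod supplier so the snippet stands alone:

```java
import static org.awaitility.Awaitility.await;

import java.time.Duration;
import java.util.List;
import java.util.function.Supplier;

final class AwaitExecutorsSketch {
    static void awaitExecutors(Supplier<List<String>> listPods, int expected) {
        await()
                .atMost(Duration.ofMinutes(2))
                .pollInterval(Duration.ofSeconds(5))
                .untilAsserted(() -> {
                    List<String> pods = listPods.get();
                    // Log inside the poll so each attempt reports progress.
                    System.out.printf("waiting for executors: found %d, expected %d%n",
                            pods.size(), expected);
                    if (pods.size() != expected) {
                        throw new AssertionError("unexpected pod count: " + pods.size());
                    }
                });
    }
}
```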
@@ -18,8 +18,17 @@
configuration:
resources:
- type: "open-ai-configuration"
id: "open-ai"
name: "OpenAI Azure configuration"
configuration:
url: "{{ secrets.open-ai.url }}"
access-key: "{{ secrets.open-ai.access-key }}"
provider: "{{ secrets.open-ai.provider }}"
- type: "vertex-configuration"
name: "Google Vertex AI configuration"
id: "vertex"
configuration:
url: "{{ secrets.vertex-ai.url }}"
token: "{{ secrets.vertex-ai.token }}"
region: "{{ secrets.vertex-ai.region }}"
project: "{{ secrets.vertex-ai.project }}"
@@ -31,12 +31,14 @@ pipeline:
type: "ai-chat-completions"
output: "ls-test-history-topic"
configuration:
model: "{{{secrets.open-ai.chat-completions-model}}}"
ai-service: "{{{secrets.chat-completions.service}}}"
model: "{{{secrets.chat-completions.model}}}"
completion-field: "value.answer"
log-field: "value.prompt"
stream-to-topic: "ls-test-output-topic"
stream-response-completion-field: "value"
min-chunks-per-message: 20
max-tokens: 20
messages:
- role: user
content: "You are an helpful assistant. Below you can fine a question from the user. Please try to help them the best way you can.\n\n{{% value.question}}"
@@ -18,6 +18,7 @@
configuration:
resources:
- type: "open-ai-configuration"
id: "open-ai"
name: "OpenAI Azure configuration"
configuration:
url: "{{ secrets.open-ai.url }}"
@@ -38,5 +38,6 @@ pipeline:
stream-to-topic: "ls-test-output-topic"
stream-response-completion-field: "value"
min-chunks-per-message: 20
max-tokens: 20
prompt:
- "{{% value.question}}"
@@ -50,7 +50,8 @@ pipeline:
type: "ai-chat-completions"

configuration:
model: "{{{secrets.open-ai.chat-completions-model}}}" # This needs to be set to the model deployment name, not the base name
ai-service: "{{{secrets.chat-completions.service}}}"
model: "{{{secrets.chat-completions.model}}}" # This needs to be set to the model deployment name, not the base name
# on the ls-test-log-topic we add a field with the answer
completion-field: "value.answer"
# we are also logging the prompt we sent to the LLM
@@ -18,6 +18,7 @@
configuration:
resources:
- type: "open-ai-configuration"
id: "open-ai"
name: "OpenAI Azure configuration"
configuration:
url: "{{ secrets.open-ai.url }}"
@@ -29,7 +30,14 @@ configuration:
service: "astra"
clientId: "{{{ secrets.astra.clientId }}}"
secret: "{{{ secrets.astra.secret }}}"
secureBundle: "{{{ secrets.astra.secureBundle }}}"
database: "{{{ secrets.astra.database }}}"
token: "{{{ secrets.astra.token }}}"
environment: "{{{ secrets.astra.environment }}}"
environment: "{{{ secrets.astra.environment }}}"
- type: "vertex-configuration"
name: "Google Vertex AI configuration"
id: "vertex"
configuration:
url: "{{ secrets.vertex-ai.url }}"
token: "{{ secrets.vertex-ai.token }}"
region: "{{ secrets.vertex-ai.region }}"
project: "{{ secrets.vertex-ai.project }}"