Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new gateway type: 'chat' #398

Merged
merged 3 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/applications/gateway-authentication/gateways.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ gateways:
headers:
- key: langstream-client-session-id
value-from-parameters: sessionId
- id: chat-no-auth
type: chat
chat-options:
headers:
- value-from-parameters: session-id
questions-topic: input-topic
answers-topic: output-topic

- id: produce-input-auth-google
type: produce
Expand Down
2 changes: 1 addition & 1 deletion examples/instances/kafka-docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ instance:
type: "kafka"
configuration:
admin:
bootstrap.servers: localhost:9092
bootstrap.servers: localhost:39092
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties;
import ai.langstream.apigateway.runner.TopicConnectionsRuntimeProviderBean;
import ai.langstream.apigateway.websocket.handlers.ChatHandler;
import ai.langstream.apigateway.websocket.handlers.ConsumeHandler;
import ai.langstream.apigateway.websocket.handlers.ProduceHandler;
import jakarta.annotation.PreDestroy;
Expand All @@ -42,6 +43,7 @@ public class WebSocketConfig implements WebSocketConfigurer {

public static final String CONSUME_PATH = "/v1/consume/{tenant}/{application}/{gateway}";
public static final String PRODUCE_PATH = "/v1/produce/{tenant}/{application}/{gateway}";
public static final String CHAT_PATH = "/v1/chat/{tenant}/{application}/{gateway}";

private final ApplicationStore applicationStore;
private final TopicConnectionsRuntimeProviderBean topicConnectionsRuntimeRegistryProvider;
Expand All @@ -61,6 +63,12 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
.addHandler(
new ProduceHandler(applicationStore, topicConnectionsRuntimeRegistry),
PRODUCE_PATH)
.addHandler(
new ChatHandler(
applicationStore,
consumeThreadPool,
topicConnectionsRuntimeRegistry),
CHAT_PATH)
.setAllowedOrigins("*")
.addInterceptors(
new HttpSessionHandshakeInterceptor(),
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.apigateway.websocket.handlers;

import static ai.langstream.apigateway.websocket.WebSocketConfig.CHAT_PATH;

import ai.langstream.api.model.Gateway;
import ai.langstream.api.runner.code.Header;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.runner.topics.TopicConnectionsRuntimeRegistry;
import ai.langstream.api.storage.ApplicationStore;
import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

@Slf4j
public class ChatHandler extends AbstractHandler {

private final ExecutorService executor;

public ChatHandler(
ApplicationStore applicationStore,
ExecutorService executor,
TopicConnectionsRuntimeRegistry topicConnectionsRuntimeRegistry) {
super(applicationStore, topicConnectionsRuntimeRegistry);
this.executor = executor;
}

@Override
public String path() {
return CHAT_PATH;
}

@Override
Gateway.GatewayType gatewayType() {
return Gateway.GatewayType.chat;
}

@Override
String tenantFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
return parsedPath.get("tenant");
}

@Override
String applicationIdFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
return parsedPath.get("application");
}

@Override
String gatewayFromPath(Map<String, String> parsedPath, Map<String, String> queryString) {
return parsedPath.get("gateway");
}

@Override
protected List<String> getAllRequiredParameters(Gateway gateway) {
List<String> parameters = gateway.getParameters();
if (parameters == null) {
parameters = new ArrayList<>();
}
if (gateway.getChatOptions() != null && gateway.getChatOptions().getHeaders() != null) {
for (Gateway.KeyValueComparison header : gateway.getChatOptions().getHeaders()) {
parameters.add(header.key());
}
}
return parameters;
}

@Override
public void onBeforeHandshakeCompleted(
AuthenticatedGatewayRequestContext context, Map<String, Object> attributes)
throws Exception {

setupReader(context);
setupProducer(context);

sendClientConnectedEvent(context);
}

private void setupProducer(AuthenticatedGatewayRequestContext context) {
final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions();

List<Gateway.KeyValueComparison> headerConfig = new ArrayList<>();
final List<Gateway.KeyValueComparison> gwHeaders = chatOptions.getHeaders();
if (gwHeaders != null) {
for (Gateway.KeyValueComparison gwHeader : gwHeaders) {
headerConfig.add(gwHeader);
}
}
final List<Header> commonHeaders =
getProducerCommonHeaders(
headerConfig, context.userParameters(), context.principalValues());
setupProducer(
context.attributes(),
chatOptions.getQuestionsTopic(),
context.application().getInstance().streamingCluster(),
commonHeaders,
context.tenant(),
context.applicationId(),
context.gateway().getId());
}

private void setupReader(AuthenticatedGatewayRequestContext context) throws Exception {
final Gateway.ChatOptions chatOptions = context.gateway().getChatOptions();

List<Gateway.KeyValueComparison> headerFilters = new ArrayList<>();
final List<Gateway.KeyValueComparison> gwHeaders = chatOptions.getHeaders();
if (gwHeaders != null) {
for (Gateway.KeyValueComparison gwHeader : gwHeaders) {
headerFilters.add(gwHeader);
}
}
final List<Function<Record, Boolean>> messageFilters =
createMessageFilters(
headerFilters, context.userParameters(), context.principalValues());

setupReader(
context.attributes(),
chatOptions.getAnswersTopic(),
context.application().getInstance().streamingCluster(),
messageFilters,
context.options());
}

@Override
public void onOpen(
WebSocketSession webSocketSession, AuthenticatedGatewayRequestContext context) {
startReadingMessages(webSocketSession, context, executor);
}

@Override
public void onMessage(
WebSocketSession webSocketSession,
AuthenticatedGatewayRequestContext context,
TextMessage message)
throws Exception {
produceMessage(webSocketSession, message);
}

@Override
public void onClose(
WebSocketSession webSocketSession,
AuthenticatedGatewayRequestContext context,
CloseStatus status) {}

@Override
void validateOptions(Map<String, String> options) {
for (Map.Entry<String, String> option : options.entrySet()) {
switch (option.getKey()) {
case "position":
if (!StringUtils.hasText(option.getValue())) {
throw new IllegalArgumentException("'position' cannot be blank");
}
break;
default:
throw new IllegalArgumentException("Unknown option " + option.getKey());
}
}
}
}
Loading
Loading