diff --git a/bin/langstream b/bin/langstream index d6cd917ff..ea779fad5 100755 --- a/bin/langstream +++ b/bin/langstream @@ -37,7 +37,9 @@ if [ ! -d langstream-cli/target/cli ]; then fi popd > /dev/null LANGSTREAM_CLI_CONFIG=${LANGSTREAM_CLI_CONFIG:-"conf/cli.yaml"} -echo "Using development CLI config file $(realpath $LANGSTREAM_CLI_CONFIG). To use the global config file, set LANGSTREAM_CLI_CONFIG=\$HOME/.langstream/config" +if [ "$LANGSTREAM_CLI_CONFIG" == "conf/cli.yaml" ]; then + echo "Using development CLI config file $(realpath $LANGSTREAM_CLI_CONFIG). To use the global config file, set LANGSTREAM_CLI_CONFIG=\$HOME/.langstream/config" +fi "$ROOT_DIR/langstream-cli/target/cli/bin/langstream" --conf "$LANGSTREAM_CLI_CONFIG" "$@" diff --git a/docker/build.sh b/docker/build.sh index 8b0f6b757..4a50d5532 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -27,11 +27,13 @@ docker_platforms() { fi } +common_flags="-DskipTests -PskipPython -Dlicense.skip -Dspotless.skip" + build_docker_image() { module=$1 - ./mvnw clean install -am -DskipTests -pl $module -T 1C -PskipPython - ./mvnw package -DskipTests -Pdocker -pl $module -Ddocker.platforms="$(docker_platforms)" -PskipPython + ./mvnw install -am -pl $module -T 1C $common_flags + ./mvnw package -Pdocker -pl $module -Ddocker.platforms="$(docker_platforms)" $common_flags docker images | head -n 2 } @@ -47,9 +49,9 @@ elif [ "$only_image" == "api-gateway" ]; then build_docker_image langstream-api-gateway else # Build all artifacts - ./mvnw install -DskipTests -T 1C -Ddocker.platforms="$(docker_platforms)" -PskipPython + ./mvnw install -T 1C -Ddocker.platforms="$(docker_platforms)" $common_flags # Build docker images - ./mvnw package -DskipTests -Pdocker -Ddocker.platforms="$(docker_platforms)" -PskipPython + ./mvnw package -Pdocker -Ddocker.platforms="$(docker_platforms)" $common_flags docker images | head -n 6 fi diff --git a/examples/applications/gateway-authentication/gateways.yaml b/examples/applications/gateway-authentication/gateways.yaml index 53b62d997..5fc5729e8 100644 --- a/examples/applications/gateway-authentication/gateways.yaml +++ b/examples/applications/gateway-authentication/gateways.yaml @@ -21,20 +21,20 @@ gateways: topic: input-topic parameters: - sessionId - produceOptions: + produce-options: headers: - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: consume-output-no-auth type: consume topic: output-topic parameters: - sessionId - consumeOptions: + consume-options: filters: headers: - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: produce-input-auth-google type: produce @@ -43,14 +43,15 @@ gateways: - sessionId authentication: provider: google + allow-test-mode: true configuration: clientId: "{{ secrets.google.client-id }}" - produceOptions: + produce-options: headers: - key: langstream-client-user-id - valueFromAuthentication: subject + value-from-authentication: subject - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: consume-output-auth-google type: consume @@ -58,16 +59,17 @@ gateways: parameters: - sessionId authentication: + allow-test-mode: true provider: google configuration: clientId: "{{ secrets.google.client-id }}" - consumeOptions: + consume-options: filters: headers: - key: langstream-client-user-id - valueFromAuthentication: subject + value-from-authentication: subject - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: produce-input-auth-github type: produce @@ -78,12 +80,12 @@ gateways: provider: github configuration: clientId: "{{ secrets.github.client-id }}" - produceOptions: + produce-options: headers: - key: langstream-client-user-id - valueFromAuthentication: login + value-from-authentication: login - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: consume-output-auth-github type: consume @@ -94,10 +96,10 @@ gateways: provider: github configuration: clientId: "{{ secrets.github.client-id }}" - consumeOptions: + consume-options: filters: headers: - key: langstream-client-user-id - valueFromAuthentication: login + value-from-authentication: login - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId diff --git a/examples/applications/openai-completions/gateways.yaml b/examples/applications/openai-completions/gateways.yaml index 5de4c3e5c..eda8d4ea8 100644 --- a/examples/applications/openai-completions/gateways.yaml +++ b/examples/applications/openai-completions/gateways.yaml @@ -21,21 +21,21 @@ gateways: topic: input-topic parameters: - sessionId - produceOptions: + produce-options: headers: - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: consume-output type: consume topic: output-topic parameters: - sessionId - consumeOptions: + consume-options: filters: headers: - key: langstream-client-session-id - valueFromParameters: sessionId + value-from-parameters: sessionId - id: consume-history type: consume @@ -48,10 +48,10 @@ gateways: provider: google configuration: clientId: "{{ secrets.google.client-id }}" - produceOptions: + produce-options: headers: - key: langstream-client-user-id - valueFromAuthentication: subject + value-from-authentication: subject - id: consume-output-auth type: consume @@ -60,9 +60,9 @@ gateways: provider: google configuration: clientId: "{{ secrets.google.client-id }}" - consumeOptions: + consume-options: filters: headers: - key: langstream-client-user-id - valueFromAuthentication: subject + value-from-authentication: subject diff --git a/langstream-api-gateway-auth/langstream-github-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/github/GitHubAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-github-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/github/GitHubAuthenticationProvider.java index 137f6f47f..92221d916 100644 --- a/langstream-api-gateway-auth/langstream-github-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/github/GitHubAuthenticationProvider.java +++ b/langstream-api-gateway-auth/langstream-github-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/github/GitHubAuthenticationProvider.java @@ -81,7 +81,7 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { log.info("X-OAuth-Client-Id: {}", responseClientId); log.info("Required: X-OAuth-Client-Id: {}", clientId); - Map result = new ObjectMapper().readValue(body, Map.class); + Map result = mapper.readValue(body, Map.class); if (log.isDebugEnabled()) { response.headers().map().forEach((k, v) -> log.debug("Header {}: {}", k, v)); } diff --git a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/pom.xml b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/pom.xml new file mode 100644 index 000000000..f356e9dfe --- /dev/null +++ b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/pom.xml @@ -0,0 +1,47 @@ + + + + + langstream-api-gateway-auth + ai.langstream + 0.0.16-SNAPSHOT + + 4.0.0 + + langstream-http-api-gateway-auth + + + + ai.langstream + langstream-api + ${project.version} + + + org.slf4j + slf4j-api + provided + + + com.fasterxml.jackson.core + jackson-databind + + + \ No newline at end of file diff --git a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java new file mode 100644 index 000000000..cb484c89c --- /dev/null +++ b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProvider.java @@ -0,0 +1,92 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.auth.impl.jwt.admin; + +import ai.langstream.api.gateway.GatewayAuthenticationProvider; +import ai.langstream.api.gateway.GatewayAuthenticationResult; +import ai.langstream.api.gateway.GatewayRequestContext; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Map; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class HttpAuthenticationProvider implements GatewayAuthenticationProvider { + + private static final ObjectMapper mapper = new ObjectMapper(); + private HttpAuthenticationProviderConfiguration httpConfiguration; + private HttpClient httpClient; + + @Override + public String type() { + return "http"; + } + + @Override + @SneakyThrows + public void initialize(Map configuration) { + httpConfiguration = + mapper.convertValue(configuration, HttpAuthenticationProviderConfiguration.class); + httpClient = + HttpClient.newBuilder() + .connectTimeout(Duration.ofSeconds(30)) + .followRedirects(HttpClient.Redirect.ALWAYS) + .build(); + } + + @Override + public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { + + final Map placeholders = Map.of("tenant", context.tenant()); + final String uri = resolvePlaceholders(placeholders, httpConfiguration.getPathTemplate()); + final String url = httpConfiguration.getBaseUrl() + uri; + + log.info("Authenticating admin with url: {}", url); + + final HttpRequest.Builder builder = HttpRequest.newBuilder().uri(URI.create(url)); + + httpConfiguration.getHeaders().forEach(builder::header); + builder.header("Authorization", "Bearer " + context.credentials()); + final HttpRequest request = builder.GET().build(); + + final HttpResponse response; + try { + response = httpClient.send(request, HttpResponse.BodyHandlers.discarding()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (Throwable e) { + return GatewayAuthenticationResult.authenticationFailed(e.getMessage()); + } + if (httpConfiguration.getAcceptedStatuses().contains(response.statusCode())) { + return GatewayAuthenticationResult.authenticationSuccessful(Map.of()); + } + return GatewayAuthenticationResult.authenticationFailed( + "Http authentication failed: " + response.statusCode()); + } + + private static String resolvePlaceholders(Map placeholders, String url) { + for (Map.Entry entry : placeholders.entrySet()) { + url = url.replace("{" + entry.getKey() + "}", entry.getValue()); + } + return url; + } +} diff --git a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProviderConfiguration.java b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProviderConfiguration.java new file mode 100644 index 000000000..01e896816 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/HttpAuthenticationProviderConfiguration.java @@ -0,0 +1,41 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.auth.impl.jwt.admin; + +import com.fasterxml.jackson.annotation.JsonAlias; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class HttpAuthenticationProviderConfiguration { + + @JsonAlias({"base-url", "baseurl"}) + private String baseUrl; + + @JsonAlias({"path-template", "pathtemplate"}) + private String pathTemplate; + + private Map headers = new HashMap<>(); + + @JsonAlias({"accepted-statuses", "acceptedstatuses"}) + private List acceptedStatuses = List.of(200, 201); +} diff --git a/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider new file mode 100644 index 000000000..2a472f630 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-http-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider @@ -0,0 +1 @@ +ai.langstream.apigateway.auth.impl.jwt.admin.HttpAuthenticationProvider \ No newline at end of file diff --git a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/pom.xml b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/pom.xml new file mode 100644 index 000000000..0d2bad396 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/pom.xml @@ -0,0 +1,43 @@ + + + + + langstream-api-gateway-auth + ai.langstream + 0.0.16-SNAPSHOT + + 4.0.0 + + langstream-jwt-api-gateway-auth + + + + ai.langstream + langstream-api + ${project.version} + + + ai.langstream + langstream-auth-jwt + ${project.version} + + + \ No newline at end of file diff --git a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java new file mode 100644 index 000000000..098414c13 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProvider.java @@ -0,0 +1,76 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.auth.impl.jwt.admin; + +import ai.langstream.api.gateway.GatewayAuthenticationProvider; +import ai.langstream.api.gateway.GatewayAuthenticationResult; +import ai.langstream.api.gateway.GatewayRequestContext; +import ai.langstream.auth.jwt.AuthenticationProviderToken; +import ai.langstream.auth.jwt.JwtProperties; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.List; +import java.util.Map; +import lombok.SneakyThrows; + +public class JwtAuthenticationProvider implements GatewayAuthenticationProvider { + + private static final ObjectMapper mapper = new ObjectMapper(); + private AuthenticationProviderToken authenticationProviderToken; + private List adminRoles; + + @Override + public String type() { + return "jwt"; + } + + @Override + @SneakyThrows + public void initialize(Map configuration) { + final JwtAuthenticationProviderConfiguration tokenProperties = + mapper.convertValue(configuration, JwtAuthenticationProviderConfiguration.class); + + if (tokenProperties.adminRoles() != null) { + this.adminRoles = tokenProperties.adminRoles(); + } else { + this.adminRoles = List.of(); + } + + final JwtProperties jwtProperties = + new JwtProperties( + tokenProperties.secretKey(), + tokenProperties.publicKey(), + tokenProperties.authClaim(), + tokenProperties.publicAlg(), + tokenProperties.audienceClaim(), + tokenProperties.audience(), + tokenProperties.jwksHostsAllowlist()); + this.authenticationProviderToken = new AuthenticationProviderToken(jwtProperties); + } + + @Override + public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { + String role; + try { + role = authenticationProviderToken.authenticate(context.credentials()); + } catch (AuthenticationProviderToken.AuthenticationException ex) { + return GatewayAuthenticationResult.authenticationFailed(ex.getMessage()); + } + if (!adminRoles.contains(role)) { + return GatewayAuthenticationResult.authenticationFailed("Not an admin."); + } + return GatewayAuthenticationResult.authenticationSuccessful(Map.of()); + } +} diff --git a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProviderConfiguration.java b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProviderConfiguration.java new file mode 100644 index 000000000..08375fad2 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/java/ai/langstream/apigateway/auth/impl/jwt/admin/JwtAuthenticationProviderConfiguration.java @@ -0,0 +1,28 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.auth.impl.jwt.admin; + +import java.util.List; + +public record JwtAuthenticationProviderConfiguration( + String secretKey, + String publicKey, + String authClaim, + String publicAlg, + String audienceClaim, + String audience, + List adminRoles, + String jwksHostsAllowlist) {} diff --git a/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider new file mode 100644 index 000000000..121dce9e6 --- /dev/null +++ b/langstream-api-gateway-auth/langstream-jwt-api-gateway-auth/src/main/resources/META-INF/services/ai.langstream.api.gateway.GatewayAuthenticationProvider @@ -0,0 +1 @@ +ai.langstream.apigateway.auth.impl.jwt.admin.JwtAuthenticationProvider \ No newline at end of file diff --git a/langstream-api-gateway-auth/pom.xml b/langstream-api-gateway-auth/pom.xml index 67b135777..d31e9864f 100644 --- a/langstream-api-gateway-auth/pom.xml +++ b/langstream-api-gateway-auth/pom.xml @@ -30,5 +30,7 @@ langstream-google-api-gateway-auth langstream-github-api-gateway-auth + langstream-jwt-api-gateway-auth + langstream-http-api-gateway-auth diff --git a/langstream-api-gateway/pom.xml b/langstream-api-gateway/pom.xml index d970b7b39..77329e93d 100644 --- a/langstream-api-gateway/pom.xml +++ b/langstream-api-gateway/pom.xml @@ -104,6 +104,11 @@ kafka test + + com.github.tomakehurst + wiremock + test + ai.langstream langstream-core @@ -124,6 +129,20 @@ runtime + + ai.langstream + langstream-jwt-api-gateway-auth + ${project.version} + runtime + + + + ai.langstream + langstream-http-api-gateway-auth + ${project.version} + runtime + + diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/LangStreamApiGateway.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/LangStreamApiGateway.java index d8ba089cc..c3c47e2dd 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/LangStreamApiGateway.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/LangStreamApiGateway.java @@ -15,6 +15,7 @@ */ package ai.langstream.apigateway; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.config.StorageProperties; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,7 +25,7 @@ import org.springframework.core.env.Environment; @SpringBootApplication -@EnableConfigurationProperties({StorageProperties.class}) +@EnableConfigurationProperties({StorageProperties.class, GatewayTestAuthenticationProperties.class}) public class LangStreamApiGateway { static { diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/config/GatewayTestAuthenticationProperties.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/config/GatewayTestAuthenticationProperties.java new file mode 100644 index 000000000..112e21453 --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/config/GatewayTestAuthenticationProperties.java @@ -0,0 +1,34 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.config; + +import java.util.HashMap; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(prefix = "application.gateways.auth.test") +@Data +@NoArgsConstructor +@AllArgsConstructor +public class GatewayTestAuthenticationProperties { + + private String type; + + private Map configuration = new HashMap<>(); +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java index 9c86d3c9f..e8e4594e7 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/AuthenticationInterceptor.java @@ -19,14 +19,16 @@ import ai.langstream.api.gateway.GatewayAuthenticationProviderRegistry; import ai.langstream.api.gateway.GatewayAuthenticationResult; import ai.langstream.api.gateway.GatewayRequestContext; -import ai.langstream.api.model.Application; import ai.langstream.api.model.Gateway; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.websocket.handlers.AbstractHandler; +import ai.langstream.apigateway.websocket.impl.AuthenticatedGatewayRequestContextImpl; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.codec.digest.DigestUtils; import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; @@ -40,6 +42,24 @@ @Slf4j public class AuthenticationInterceptor implements HandshakeInterceptor { + private final GatewayAuthenticationProvider authTestProvider; + + public AuthenticationInterceptor( + GatewayTestAuthenticationProperties testAuthenticationProperties) { + if (testAuthenticationProperties.getType() != null) { + authTestProvider = + GatewayAuthenticationProviderRegistry.loadProvider( + testAuthenticationProperties.getType(), + testAuthenticationProperties.getConfiguration()); + log.info( + "Loaded test authentication provider {}", + authTestProvider.getClass().getName()); + } else { + authTestProvider = null; + log.info("No test authentication provider configured"); + } + } + @Override public boolean beforeHandshake( ServerHttpRequest request, @@ -102,85 +122,80 @@ public AuthFailedException(String message) { private Map authenticate(GatewayRequestContext gatewayRequestContext) throws AuthFailedException { + final Gateway.Authentication authentication = gatewayRequestContext.gateway().authentication(); - final Map principalValues; - if (authentication != null) { - final String provider = authentication.provider(); - final GatewayAuthenticationProvider authenticationProvider = - GatewayAuthenticationProviderRegistry.loadProvider( - provider, authentication.configuration()); - final GatewayAuthenticationResult result = - authenticationProvider.authenticate(gatewayRequestContext); - if (!result.authenticated()) { - throw new AuthFailedException(result.reason()); + if (authentication == null) { + return Map.of(); + } + + final GatewayAuthenticationResult result; + if (gatewayRequestContext.isTestMode()) { + if (!authentication.isAllowTestMode()) { + throw new AuthFailedException( + "Gateway " + + gatewayRequestContext.gateway().id() + + " of tenant " + + gatewayRequestContext.tenant() + + " does not allow test mode."); + } + if (authTestProvider == null) { + throw new AuthFailedException("No test auth provider specified"); } - principalValues = result.principalValues(); + result = authTestProvider.authenticate(gatewayRequestContext); } else { - principalValues = Map.of(); + final String provider = authentication.getProvider(); + final GatewayAuthenticationProvider authProvider = + GatewayAuthenticationProviderRegistry.loadProvider( + provider, authentication.getConfiguration()); + result = authProvider.authenticate(gatewayRequestContext); } - if (principalValues == null) { - return Map.of(); + if (result == null) { + throw new AuthFailedException("Authentication provider returned null"); + } + if (!result.authenticated()) { + throw new AuthFailedException(result.reason()); + } + return getPrincipalValues(result, gatewayRequestContext); + } + + private Map getPrincipalValues( + GatewayAuthenticationResult result, GatewayRequestContext context) { + if (!context.isTestMode()) { + final Map values = result.principalValues(); + if (values == null) { + return Map.of(); + } + return values; + } else { + final Map values = new HashMap<>(); + final String principalSubject = DigestUtils.sha256Hex(context.credentials()); + final int principalNumericId = principalSubject.hashCode(); + final String principalEmail = "%s@locahost".formatted(principalSubject); + + // google + values.putIfAbsent("subject", principalSubject); + values.putIfAbsent("email", principalEmail); + values.putIfAbsent("name", principalSubject); + + // github + values.putIfAbsent("login", principalSubject); + values.putIfAbsent("id", principalNumericId + ""); + return values; } - return principalValues; } private AuthenticatedGatewayRequestContext getAuthenticatedGatewayRequestContext( GatewayRequestContext gatewayRequestContext, Map principalValues, Map attributes) { - return new AuthenticatedGatewayRequestContext() { - @Override - public Map principalValues() { - return principalValues; - } - - @Override - public String tenant() { - return gatewayRequestContext.tenant(); - } - - @Override - public Map attributes() { - return attributes; - } - - @Override - public String applicationId() { - return gatewayRequestContext.applicationId(); - } - - @Override - public Application application() { - return gatewayRequestContext.application(); - } - @Override - public Gateway gateway() { - return gatewayRequestContext.gateway(); - } - - @Override - public String credentials() { - return gatewayRequestContext.credentials(); - } - - @Override - public Map userParameters() { - return gatewayRequestContext.userParameters(); - } - - @Override - public Map options() { - return gatewayRequestContext.options(); - } - - @Override - public Map httpHeaders() { - return gatewayRequestContext.httpHeaders(); - } - }; + return AuthenticatedGatewayRequestContextImpl.builder() + .gatewayRequestContext(gatewayRequestContext) + .attributes(attributes) + .principalValues(principalValues) + .build(); } @Override diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java index 1d41a1548..3f9b780e5 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/WebSocketConfig.java @@ -16,6 +16,7 @@ package ai.langstream.apigateway.websocket; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.websocket.handlers.ConsumeHandler; import ai.langstream.apigateway.websocket.handlers.ProduceHandler; import jakarta.annotation.PreDestroy; @@ -41,6 +42,7 @@ public class WebSocketConfig implements WebSocketConfigurer { public static final String PRODUCE_PATH = "/v1/produce/{tenant}/{application}/{gateway}"; private final ApplicationStore applicationStore; + private final GatewayTestAuthenticationProperties adminAuthenticationProperties; private final ExecutorService consumeThreadPool = Executors.newCachedThreadPool(); @Override @@ -49,7 +51,8 @@ public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { .addHandler(new ProduceHandler(applicationStore), PRODUCE_PATH) .setAllowedOrigins("*") .addInterceptors( - new HttpSessionHandshakeInterceptor(), new AuthenticationInterceptor()); + new HttpSessionHandshakeInterceptor(), + new AuthenticationInterceptor(adminAuthenticationProperties)); } @Bean diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java index ddcc88629..abfefa7fb 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/AbstractHandler.java @@ -30,6 +30,7 @@ import ai.langstream.api.runner.topics.TopicProducer; import ai.langstream.api.storage.ApplicationStore; import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import ai.langstream.apigateway.websocket.impl.GatewayRequestContextImpl; import ai.langstream.impl.common.ApplicationPlaceholderResolver; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; @@ -197,6 +198,7 @@ public GatewayRequestContext validateRequest( Map userParameters = new HashMap<>(); final String credentials = queryString.remove("credentials"); + final String testCredentials = queryString.remove("test-credentials"); for (Map.Entry entry : queryString.entrySet()) { if (entry.getKey().startsWith("option:")) { @@ -238,48 +240,21 @@ public GatewayRequestContext validateRequest( } validateOptions(options); - return new GatewayRequestContext() { - - @Override - public String tenant() { - return tenant; - } - - @Override - public String applicationId() { - return applicationId; - } - - @Override - public Application application() { - return application; - } - - @Override - public Gateway gateway() { - return gateway; - } - - @Override - public String credentials() { - return credentials; - } - - @Override - public Map userParameters() { - return userParameters; - } - - @Override - public Map options() { - return options; - } - - @Override - public Map httpHeaders() { - return httpHeaders; - } - }; + if (credentials != null && testCredentials != null) { + throw new IllegalArgumentException( + "credentials and test-credentials cannot be used together"); + } + return GatewayRequestContextImpl.builder() + .tenant(tenant) + .applicationId(applicationId) + .application(application) + .credentials(credentials) + .testCredentials(testCredentials) + .httpHeaders(httpHeaders) + .options(options) + .userParameters(userParameters) + .gateway(gateway) + .build(); } protected void recordCloseableResource( diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java index aa1efd5f7..98f55d3af 100644 --- a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/handlers/ProduceHandler.java @@ -228,7 +228,8 @@ private List
getCommonHeaders( value = principalValues.get(mapping.valueFromAuthentication()); } if (value == null) { - throw new IllegalArgumentException(mapping.key() + "header cannot be empty"); + throw new IllegalArgumentException( + "header " + mapping.key() + " cannot be empty"); } headers.add(SimpleRecord.SimpleHeader.of(mapping.key(), value)); diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java new file mode 100644 index 000000000..9c75f47ce --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/AuthenticatedGatewayRequestContextImpl.java @@ -0,0 +1,86 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.websocket.impl; + +import ai.langstream.api.gateway.GatewayRequestContext; +import ai.langstream.api.model.Application; +import ai.langstream.api.model.Gateway; +import ai.langstream.apigateway.websocket.AuthenticatedGatewayRequestContext; +import java.util.Map; +import lombok.Builder; + +@Builder +public class AuthenticatedGatewayRequestContextImpl implements AuthenticatedGatewayRequestContext { + + private final GatewayRequestContext gatewayRequestContext; + private final Map attributes; + private final Map principalValues; + + @Override + public Map attributes() { + return attributes; + } + + @Override + public Map principalValues() { + return principalValues; + } + + @Override + public String tenant() { + return gatewayRequestContext.tenant(); + } + + @Override + public String applicationId() { + return gatewayRequestContext.applicationId(); + } + + @Override + public Application application() { + return gatewayRequestContext.application(); + } + + @Override + public Gateway gateway() { + return gatewayRequestContext.gateway(); + } + + @Override + public String credentials() { + return gatewayRequestContext.credentials(); + } + + @Override + public boolean isTestMode() { + return gatewayRequestContext.isTestMode(); + } + + @Override + public Map userParameters() { + return gatewayRequestContext.userParameters(); + } + + @Override + public Map options() { + return gatewayRequestContext.options(); + } + + @Override + public Map httpHeaders() { + return gatewayRequestContext.httpHeaders(); + } +} diff --git a/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/GatewayRequestContextImpl.java b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/GatewayRequestContextImpl.java new file mode 100644 index 000000000..233c3e80b --- /dev/null +++ b/langstream-api-gateway/src/main/java/ai/langstream/apigateway/websocket/impl/GatewayRequestContextImpl.java @@ -0,0 +1,86 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.apigateway.websocket.impl; + +import ai.langstream.api.gateway.GatewayRequestContext; +import ai.langstream.api.model.Application; +import ai.langstream.api.model.Gateway; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Builder; + +@Builder +@AllArgsConstructor +public class GatewayRequestContextImpl implements GatewayRequestContext { + + private final String tenant; + private final String applicationId; + private final Application application; + private final Gateway gateway; + private final String credentials; + private final String testCredentials; + private final Map userParameters; + private final Map options; + private final Map httpHeaders; + + @Override + public String tenant() { + return tenant; + } + + @Override + public String applicationId() { + return applicationId; + } + + @Override + public Application application() { + return application; + } + + @Override + public Gateway gateway() { + return gateway; + } + + @Override + public String credentials() { + if (isTestMode()) { + return testCredentials; + } + return credentials; + } + + @Override + public boolean isTestMode() { + return testCredentials != null; + } + + @Override + public Map userParameters() { + return userParameters; + } + + @Override + public Map options() { + return options; + } + + @Override + public Map httpHeaders() { + return httpHeaders; + } +} diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java index fe276c423..c3782954b 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/TestGatewayAuthenticationProvider.java @@ -37,7 +37,7 @@ public GatewayAuthenticationResult authenticate(GatewayRequestContext context) { log.info("Authenticating {}", context.credentials()); if (context.credentials().startsWith("test-user-password")) { return GatewayAuthenticationResult.authenticationSuccessful( - Map.of("user-id", context.credentials())); + Map.of("login", context.credentials())); } else { return GatewayAuthenticationResult.authenticationFailed("Invalid credentials"); } diff --git a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java index 47f3f7c5b..8fd302237 100644 --- a/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java +++ b/langstream-api-gateway/src/test/java/ai/langstream/apigateway/websocket/handlers/ProduceConsumeHandlerTest.java @@ -36,6 +36,7 @@ import ai.langstream.api.runtime.ClusterRuntimeRegistry; import ai.langstream.api.runtime.PluginsRegistry; import ai.langstream.api.storage.ApplicationStore; +import ai.langstream.apigateway.config.GatewayTestAuthenticationProperties; import ai.langstream.apigateway.websocket.api.ConsumePushMessage; import ai.langstream.apigateway.websocket.api.ProduceRequest; import ai.langstream.apigateway.websocket.api.ProduceResponse; @@ -45,6 +46,9 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; import jakarta.websocket.CloseReason; import jakarta.websocket.DeploymentException; import jakarta.websocket.Session; @@ -62,6 +66,7 @@ import org.awaitility.Awaitility; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; @@ -77,7 +82,10 @@ @SpringBootTest( webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, - properties = {"spring.main.allow-bean-definition-overriding=true"}) + properties = { + "spring.main.allow-bean-definition-overriding=true", + }) +@WireMockTest class ProduceConsumeHandlerTest { protected static final ObjectMapper MAPPER = new ObjectMapper(); @@ -110,6 +118,23 @@ public ApplicationStore store() { return mock; } + + @Bean + @Primary + public GatewayTestAuthenticationProperties gatewayTestAuthenticationProperties() { + final GatewayTestAuthenticationProperties props = + new GatewayTestAuthenticationProperties(); + props.setType("http"); + props.setConfiguration( + Map.of( + "base-url", + wireMockBaseUrl, + "path-template", + "/auth/{tenant}", + "headers", + Map.of("h1", "v1"))); + return props; + } } @NotNull @@ -158,8 +183,17 @@ private static Application buildApp() throws Exception { @Autowired ApplicationStore store; + static WireMock wireMock; + static String wireMockBaseUrl; + + @BeforeAll + public static void beforeAll(WireMockRuntimeInfo wmRuntimeInfo) { + wireMock = wmRuntimeInfo.getWireMock(); + wireMockBaseUrl = wmRuntimeInfo.getHttpBaseUrl(); + } + @BeforeEach - public void beforeEach() { + public void beforeEach(WireMockRuntimeInfo wmRuntimeInfo) { testGateways = null; topics = null; Awaitility.setDefaultTimeout(30, TimeUnit.SECONDS); @@ -424,19 +458,19 @@ void testAuthentication() { "produce", Gateway.GatewayType.produce, topic, - new Gateway.Authentication("test-auth", Map.of()), + new Gateway.Authentication("test-auth", Map.of(), true), List.of(), new Gateway.ProduceOptions( List.of( Gateway.KeyValueComparison .valueFromAuthentication( - "header1", "user-id"))), + "header1", "login"))), null), new Gateway( "consume", Gateway.GatewayType.consume, topic, - new Gateway.Authentication("test-auth", Map.of()), + new Gateway.Authentication("test-auth", Map.of(), true), List.of(), null, new Gateway.ConsumeOptions( @@ -445,7 +479,7 @@ void testAuthentication() { Gateway.KeyValueComparison .valueFromAuthentication( "header1", - "user-id"))))))); + "login"))))))); connectAndExpectClose( URI.create( @@ -510,6 +544,98 @@ void testAuthentication() { assertEquals(List.of(), user2Messages); } + @Test + void testTestCredentials() { + wireMock.register( + WireMock.get("/auth/tenant1") + .withHeader("Authorization", WireMock.equalTo("Bearer test-user-password")) + .withHeader("h1", WireMock.equalTo("v1")) + .willReturn(WireMock.ok(""))); + final String topic = genTopic(); + prepareTopicsForTest(topic); + + List user1Messages = new ArrayList<>(); + + testGateways = + new Gateways( + List.of( + new Gateway( + "produce", + Gateway.GatewayType.produce, + topic, + new Gateway.Authentication("test-auth", Map.of(), true), + List.of(), + new Gateway.ProduceOptions( + List.of( + Gateway.KeyValueComparison + .valueFromAuthentication( + "header1", "login"))), + null), + new Gateway( + "consume", + Gateway.GatewayType.consume, + topic, + new Gateway.Authentication("test-auth", Map.of(), true), + List.of(), + null, + new Gateway.ConsumeOptions( + new Gateway.ConsumeOptionsFilters( + List.of( + Gateway.KeyValueComparison + .valueFromAuthentication( + "header1", + "login"))))), + new Gateway( + "consume-no-test", + Gateway.GatewayType.consume, + topic, + new Gateway.Authentication("test-auth", Map.of(), false), + List.of(), + null, + null))); + + @Cleanup + final ClientSession client1 = + connectAndCollectMessages( + URI.create( + "ws://localhost:%d/v1/consume/tenant1/application1/consume?test-credentials=test-user-password" + .formatted(port)), + user1Messages); + + connectAndProduce( + URI.create( + "ws://localhost:%d/v1/produce/tenant1/application1/produce?test-credentials=test-user-password" + .formatted(port)), + new ProduceRequest(null, "hello user", null)); + + Awaitility.await() + .untilAsserted( + () -> + assertMessagesContent( + List.of( + new MsgRecord( + null, + "hello user", + Map.of( + "header1", + "9d75ff199d33e051209b59702de27d1e470eafb58ac6d8865788bf23b48e6818"))), + user1Messages)); + + connectAndExpectClose( + URI.create( + "ws://localhost:%d/v1/consume/tenant1/application1/consume-no-admin?test-credentials=test-user-password" + .formatted(port)), + new CloseReason( + CloseReason.CloseCodes.VIOLATED_POLICY, + "Gateway consume-no-test of tenant tenant1 does not allow test mode.")); + + connectAndExpectClose( + URI.create( + "ws://localhost:%d/v1/produce/tenant1/application1/produce?test-credentials=test-user-password-but-wrong" + .formatted(port)), + new CloseReason(CloseReason.CloseCodes.VIOLATED_POLICY, "Invalid credentials")); + } + private record MsgRecord(Object key, Object value, Map headers) {} private void assertMessagesContent(List expected, List actual) { diff --git a/langstream-api/src/main/java/ai/langstream/api/gateway/GatewayRequestContext.java b/langstream-api/src/main/java/ai/langstream/api/gateway/GatewayRequestContext.java index e25a8451e..c70f70ff0 100644 --- a/langstream-api/src/main/java/ai/langstream/api/gateway/GatewayRequestContext.java +++ b/langstream-api/src/main/java/ai/langstream/api/gateway/GatewayRequestContext.java @@ -31,6 +31,8 @@ public interface GatewayRequestContext { String credentials(); + boolean isTestMode(); + Map userParameters(); Map options(); diff --git a/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java b/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java index 5a9ec9b17..8772ea9bf 100644 --- a/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java +++ b/langstream-api/src/main/java/ai/langstream/api/model/Gateway.java @@ -19,6 +19,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; public record Gateway( String id, @@ -55,10 +58,22 @@ public Gateway( this(id, type, topic, authentication, parameters, produceOptions, consumeOptions, null); } - public record Authentication(String provider, Map configuration) {} + @Data + @NoArgsConstructor + @AllArgsConstructor + public static class Authentication { + private String provider; + private Map configuration; + + @JsonProperty("allow-test-mode") + private boolean allowTestMode = true; + } public record KeyValueComparison( - String key, String value, String valueFromParameters, String valueFromAuthentication) { + String key, + String value, + @JsonAlias({"value-from-parameters"}) String valueFromParameters, + @JsonAlias({"value-from-authentication"}) String valueFromAuthentication) { public static KeyValueComparison value(String key, String value) { return new KeyValueComparison(key, value, null, null); } diff --git a/langstream-auth-jwt/pom.xml b/langstream-auth-jwt/pom.xml new file mode 100644 index 000000000..7d5b4b111 --- /dev/null +++ b/langstream-auth-jwt/pom.xml @@ -0,0 +1,71 @@ + + + + + langstream-ai + ai.langstream + 0.0.16-SNAPSHOT + + 4.0.0 + + langstream-auth-jwt + + + + com.fasterxml.jackson.core + jackson-databind + + + org.slf4j + slf4j-api + provided + + + org.apache.commons + commons-lang3 + + + commons-codec + commons-codec + + + commons-io + commons-io + + + + io.jsonwebtoken + jjwt-api + + + io.jsonwebtoken + jjwt-impl + + + io.jsonwebtoken + jjwt-jackson + + + + + + + \ No newline at end of file diff --git a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/AuthenticationProviderToken.java b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/AuthenticationProviderToken.java similarity index 93% rename from langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/AuthenticationProviderToken.java rename to langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/AuthenticationProviderToken.java index d08d8b3bb..3152e96da 100644 --- a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/AuthenticationProviderToken.java +++ b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/AuthenticationProviderToken.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.webservice.security.infrastructure.primary; +package ai.langstream.auth.jwt; -import ai.langstream.webservice.config.AuthTokenProperties; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jwt; import io.jsonwebtoken.JwtException; @@ -59,7 +58,7 @@ public AuthenticationException(String message) { private final String audienceClaim; private final String audience; - public AuthenticationProviderToken(AuthTokenProperties tokenProperties) + public AuthenticationProviderToken(JwtProperties tokenProperties) throws IOException, IllegalArgumentException { this.publicKeyAlg = getPublicKeyAlgType(tokenProperties); parser = @@ -151,7 +150,7 @@ private String getPrincipal(Jwt jwt) { } } - private Key getValidationKeyFromConfig(AuthTokenProperties tokenProperties) throws IOException { + private Key getValidationKeyFromConfig(JwtProperties tokenProperties) throws IOException { String tokenSecretKey = tokenProperties.secretKey(); String tokenPublicKey = tokenProperties.publicKey(); byte[] validationKey; @@ -196,12 +195,12 @@ private static SecretKey decodeSecretKey(byte[] secretKey) { return Keys.hmacShaKeyFor(secretKey); } - private String getTokenRoleClaim(AuthTokenProperties tokenProperties) { + private String getTokenRoleClaim(JwtProperties tokenProperties) { String tokenAuthClaim = tokenProperties.authClaim(); return StringUtils.isNotBlank(tokenAuthClaim) ? tokenAuthClaim : "sub"; } - private SignatureAlgorithm getPublicKeyAlgType(AuthTokenProperties tokenProperties) + private SignatureAlgorithm getPublicKeyAlgType(JwtProperties tokenProperties) throws IllegalArgumentException { String tokenPublicAlg = tokenProperties.publicAlg(); if (StringUtils.isNotBlank(tokenPublicAlg)) { @@ -238,14 +237,13 @@ private static String keyTypeForSignatureAlgorithm(SignatureAlgorithm alg) { } } - private String getTokenAudienceClaim(AuthTokenProperties tokenProperties) + private String getTokenAudienceClaim(JwtProperties tokenProperties) throws IllegalArgumentException { String tokenAudienceClaim = tokenProperties.audienceClaim(); return StringUtils.isNotBlank(tokenAudienceClaim) ? tokenAudienceClaim : null; } - private String getTokenAudience(AuthTokenProperties tokenProperties) - throws IllegalArgumentException { + private String getTokenAudience(JwtProperties tokenProperties) throws IllegalArgumentException { String tokenAudience = tokenProperties.audience(); return StringUtils.isNotBlank(tokenAudience) ? tokenAudience : null; } diff --git a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/JwksUriSigningKeyResolver.java b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwksUriSigningKeyResolver.java similarity index 99% rename from langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/JwksUriSigningKeyResolver.java rename to langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwksUriSigningKeyResolver.java index 3e06d5eb2..051be17ad 100644 --- a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/JwksUriSigningKeyResolver.java +++ b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwksUriSigningKeyResolver.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.webservice.security.infrastructure.primary; +package ai.langstream.auth.jwt; import com.fasterxml.jackson.databind.ObjectMapper; import io.jsonwebtoken.Claims; diff --git a/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwtProperties.java b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwtProperties.java new file mode 100644 index 000000000..c66919e72 --- /dev/null +++ b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/JwtProperties.java @@ -0,0 +1,25 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.auth.jwt; + +public record JwtProperties( + String secretKey, + String publicKey, + String authClaim, + String publicAlg, + String audienceClaim, + String audience, + String jwksHostsAllowlist) {} diff --git a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/LocalKubernetesJwksUriSigningKeyResolver.java b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/LocalKubernetesJwksUriSigningKeyResolver.java similarity index 99% rename from langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/LocalKubernetesJwksUriSigningKeyResolver.java rename to langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/LocalKubernetesJwksUriSigningKeyResolver.java index 5b226a71f..03067d35e 100644 --- a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/LocalKubernetesJwksUriSigningKeyResolver.java +++ b/langstream-auth-jwt/src/main/java/ai/langstream/auth/jwt/LocalKubernetesJwksUriSigningKeyResolver.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package ai.langstream.webservice.security.infrastructure.primary; +package ai.langstream.auth.jwt; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java index c5e41948f..76e8a4ac2 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/BaseGatewayCmd.java @@ -42,9 +42,12 @@ protected RootCmd getRootCmd() { } private static String computeQueryString( - String credentials, Map userParams, Map options) { + Map systemParams, + Map userParams, + Map options) { String paramsPart = ""; String optionsPart = ""; + String systemParamsPart = ""; if (userParams != null) { paramsPart = userParams.entrySet().stream() @@ -59,12 +62,14 @@ private static String computeQueryString( .collect(Collectors.joining("&")); } - String credentialsPart = ""; - if (credentials != null) { - credentialsPart = encodeParam("credentials", credentials, ""); + if (systemParams != null) { + systemParamsPart = + systemParams.entrySet().stream() + .map(e -> encodeParam(e, "")) + .collect(Collectors.joining("&")); } - return String.join("&", List.of(credentialsPart, paramsPart, optionsPart)); + return String.join("&", List.of(systemParamsPart, paramsPart, optionsPart)); } private static String encodeParam(Map.Entry e, String prefix) { @@ -83,8 +88,18 @@ protected String validateGatewayAndGetUrl( String type, Map params, Map options, - String credentials) { - validateGateway(applicationId, gatewayId, type, params, options, credentials); + String credentials, + String testCredentials) { + validateGateway( + applicationId, gatewayId, type, params, options, credentials, testCredentials); + + Map systemParams = new HashMap<>(); + if (credentials != null) { + systemParams.put("credentials", credentials); + } + if (testCredentials != null) { + systemParams.put("test-credentials", testCredentials); + } return String.format( "%s/v1/%s/%s/%s/%s?%s", getApiGatewayUrl(), @@ -92,7 +107,7 @@ protected String validateGatewayAndGetUrl( getTenant(), applicationId, gatewayId, - computeQueryString(credentials, params, options)); + computeQueryString(systemParams, params, options)); } private String getTenant() { @@ -112,7 +127,8 @@ protected void validateGateway( String type, Map params, Map options, - String credentials) { + String credentials, + String testCredentials) { final AdminClient client = getClient(); @@ -156,10 +172,26 @@ protected void validateGateway( } } if (selectedGateway.getAuthentication() != null) { - if (credentials == null) { + if (credentials == null && testCredentials == null) { throw new IllegalArgumentException( "gateway " + gatewayId + " of type " + type + " requires credentials"); } + if (testCredentials != null) { + final Object allowTestMode = + selectedGateway.getAuthentication().get("allow-test-mode"); + if (allowTestMode != null && allowTestMode.toString().equals("false")) { + throw new IllegalArgumentException( + "gateway " + + gatewayId + + " of type " + + type + + " do not allow test mode."); + } + } + } + if (credentials != null && testCredentials != null) { + throw new IllegalArgumentException( + "credentials and test-credentials cannot be used together"); } } } diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java index 48f9eea52..23585b350 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ChatGatewayCmd.java @@ -62,6 +62,11 @@ public class ChatGatewayCmd extends BaseGatewayCmd { "Credentials for the gateway. Required if the gateway requires authentication.") private String credentials; + @CommandLine.Option( + names = {"-tc", "--test-credentials"}, + description = "Test credentials for the gateway.") + private String testCredentials; + @CommandLine.Option( names = {"--connect-timeout"}, description = "Connect timeout for WebSocket connections in seconds.") @@ -78,7 +83,8 @@ public void run() { Gateways.Gateway.TYPE_CONSUME, params, consumeGatewayOptions, - credentials); + credentials, + testCredentials); final String producePath = validateGatewayAndGetUrl( applicationId, @@ -86,7 +92,8 @@ public void run() { Gateways.Gateway.TYPE_PRODUCE, params, Map.of(), - credentials); + credentials, + testCredentials); final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java index e33d140d4..d09dc839f 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ConsumeGatewayCmd.java @@ -46,6 +46,11 @@ public class ConsumeGatewayCmd extends BaseGatewayCmd { "Credentials for the gateway. Required if the gateway requires authentication.") private String credentials; + @CommandLine.Option( + names = {"-tc", "--test-credentials"}, + description = "Test credentials for the gateway.") + private String testCredentials; + @CommandLine.Option( names = {"--position"}, description = @@ -77,7 +82,8 @@ public void run() { Gateways.Gateway.TYPE_CONSUME, params, options, - credentials); + credentials, + testCredentials); final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; diff --git a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java index b1d57e274..4627b2332 100644 --- a/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java +++ b/langstream-cli/src/main/java/ai/langstream/cli/commands/gateway/ProduceGatewayCmd.java @@ -77,6 +77,11 @@ static class ProduceRequest { description = "Connect timeout for WebSocket connections in seconds.") private long connectTimeoutSeconds = 0; + @CommandLine.Option( + names = {"-tc", "--test-credentials"}, + description = "Test credentials for the gateway.") + private String testCredentials; + @Override @SneakyThrows public void run() { @@ -87,7 +92,8 @@ public void run() { Gateways.Gateway.TYPE_PRODUCE, params, Map.of(), - credentials); + credentials, + testCredentials); final Duration connectTimeout = connectTimeoutSeconds > 0 ? Duration.ofSeconds(connectTimeoutSeconds) : null; CountDownLatch countDownLatch = new CountDownLatch(1); diff --git a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java index f83673a32..51d6ba119 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java +++ b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java @@ -177,11 +177,12 @@ private static Gateways resolveGateways(Application instance, Map gateways = applicationInstance.getGateways().gateways(); + Assertions.assertEquals(1, gateways.size()); + final Gateway gateway = gateways.get(0); + assertEquals("gw", gateway.id()); + assertEquals("t1", gateway.topic()); + assertEquals("google", gateway.authentication().getProvider()); + assertTrue(gateway.authentication().isAllowTestMode()); + } } diff --git a/langstream-webservice/pom.xml b/langstream-webservice/pom.xml index d393e6bb9..92ea5e848 100644 --- a/langstream-webservice/pom.xml +++ b/langstream-webservice/pom.xml @@ -38,6 +38,11 @@ + + ${project.groupId} + langstream-auth-jwt + ${project.version} + ${project.groupId} langstream-core @@ -152,18 +157,7 @@ zip4j - - io.jsonwebtoken - jjwt-api - - - io.jsonwebtoken - jjwt-impl - - - io.jsonwebtoken - jjwt-jackson - + commons-codec diff --git a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/TokenAuthFilter.java b/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/TokenAuthFilter.java index fe57b52d4..cf132c71f 100644 --- a/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/TokenAuthFilter.java +++ b/langstream-webservice/src/main/java/ai/langstream/webservice/security/infrastructure/primary/TokenAuthFilter.java @@ -15,6 +15,8 @@ */ package ai.langstream.webservice.security.infrastructure.primary; +import ai.langstream.auth.jwt.AuthenticationProviderToken; +import ai.langstream.auth.jwt.JwtProperties; import ai.langstream.webservice.config.AuthTokenProperties; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -45,7 +47,18 @@ public class TokenAuthFilter extends GenericFilterBean { @SneakyThrows public TokenAuthFilter(AuthTokenProperties tokenProperties) { this.tokenProperties = tokenProperties; - this.authenticationProvider = new AuthenticationProviderToken(tokenProperties); + + final JwtProperties jwtProperties = + new JwtProperties( + tokenProperties.secretKey(), + tokenProperties.publicKey(), + tokenProperties.authClaim(), + tokenProperties.publicAlg(), + tokenProperties.audienceClaim(), + tokenProperties.audience(), + tokenProperties.jwksHostsAllowlist()); + + this.authenticationProvider = new AuthenticationProviderToken(jwtProperties); } @Override diff --git a/pom.xml b/pom.xml index 39579026d..f11528c4f 100644 --- a/pom.xml +++ b/pom.xml @@ -731,6 +731,7 @@ langstream-admin-client langstream-api + langstream-auth-jwt langstream-cli langstream-core langstream-agents