Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vector database sink config validation and documentation #583

Merged
merged 7 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import ai.langstream.api.database.VectorDatabaseWriterProvider;
import ai.langstream.api.runner.code.Record;
import ai.langstream.api.util.ConfigurationUtils;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.ArrayList;
Expand All @@ -38,9 +36,6 @@
@Slf4j
public class JdbcWriter implements VectorDatabaseWriterProvider {

private static final ObjectMapper MAPPER =
new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

@Override
public boolean supports(Map<String, Object> dataSourceConfig) {
return "jdbc".equals(dataSourceConfig.get("service"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
@AllArgsConstructor
public class AgentConfigurationModel {

private String type;
private String name;
private String description;
private Map<String, ConfigPropertyModel> properties;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,8 @@ public static void validateAgentModelFromClass(
Class modelClazz,
Map<String, Object> asMap,
boolean allowUnknownProperties) {
final EntityRef ref =
() ->
"agent configuration (agent: '%s', type: '%s')"
.formatted(
agentConfiguration.getName() == null
? agentConfiguration.getId()
: agentConfiguration.getName(),
agentConfiguration.getType());
validateModelFromClass(ref, modelClazz, asMap, allowUnknownProperties);
validateModelFromClass(
new AgentEntityRef(agentConfiguration), modelClazz, asMap, allowUnknownProperties);
}

@AllArgsConstructor
Expand Down Expand Up @@ -199,6 +192,22 @@ public String ref() {
}
}

@AllArgsConstructor
public static class AgentEntityRef implements EntityRef {

private final AgentConfiguration agentConfiguration;

@Override
public String ref() {
return "agent configuration (agent: '%s', type: '%s')"
.formatted(
agentConfiguration.getName() == null
? agentConfiguration.getId()
: agentConfiguration.getName(),
agentConfiguration.getType());
}
}

@SneakyThrows
public static void validateAssetModelFromClass(
AssetDefinition asset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ai.langstream.runtime.impl.k8s.agents;

import ai.langstream.api.doc.AgentConfig;
import ai.langstream.api.doc.AgentConfigurationModel;
import ai.langstream.api.doc.ConfigProperty;
import ai.langstream.api.model.AgentConfiguration;
import ai.langstream.api.model.Application;
Expand All @@ -28,18 +29,58 @@
import ai.langstream.api.runtime.PluginsRegistry;
import ai.langstream.impl.agents.AbstractComposableAgentProvider;
import ai.langstream.impl.agents.ai.steps.QueryConfiguration;
import ai.langstream.impl.uti.ClassConfigValidator;
import ai.langstream.runtime.impl.k8s.KubernetesClusterRuntime;
import ai.langstream.runtime.impl.k8s.agents.vectors.CassandraVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.JDBCVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.MilvusVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.OpenSearchVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.PineconeVectorDatabaseWriterConfig;
import ai.langstream.runtime.impl.k8s.agents.vectors.SolrVectorDatabaseWriterConfig;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class QueryVectorDBAgentProvider extends AbstractComposableAgentProvider {

protected static final ObjectMapper MAPPER = new ObjectMapper();

@Getter
@Setter
public abstract static class VectorDatabaseWriterConfig {
@ConfigProperty(
description =
"""
The defined datasource ID to use to store the vectors.
""",
required = true)
String datasource;

public abstract Class getAgentConfigModelClass();

public abstract boolean isAgentConfigModelAllowUnknownProperties();
}

protected static final String QUERY_VECTOR_DB = "query-vector-db";
protected static final String VECTOR_DB_SINK = "vector-db-sink";
protected static final Map<String, VectorDatabaseWriterConfig>
SUPPORTED_VECTOR_DB_SINK_DATASOURCES =
Map.of(
"cassandra", CassandraVectorDatabaseWriterConfig.CASSANDRA,
"astra", CassandraVectorDatabaseWriterConfig.ASTRA,
"jdbc", JDBCVectorDatabaseWriterConfig.INSTANCE,
"pinecone", PineconeVectorDatabaseWriterConfig.INSTANCE,
"opensearch", OpenSearchVectorDatabaseWriterConfig.INSTANCE,
"solr", SolrVectorDatabaseWriterConfig.INSTANCE,
"milvus", MilvusVectorDatabaseWriterConfig.INSTANCE);

public QueryVectorDBAgentProvider() {
super(
Expand Down Expand Up @@ -76,30 +117,79 @@ protected Map<String, Object> computeAgentConfiguration(
// get the datasource configuration and inject it into the agent configuration
String resourceId = (String) originalConfiguration.remove("datasource");
if (resourceId == null) {
throw new IllegalStateException(
"datasource is required but this exception should have been raised before ?");
throw new IllegalArgumentException(
ClassConfigValidator.formatErrString(
new ClassConfigValidator.AgentEntityRef(agentConfiguration),
"datasource",
"is required"));
}
generateDataSourceConfiguration(
resourceId,
executionPlan.getApplication(),
originalConfiguration,
clusterRuntime,
pluginsRegistry);
pluginsRegistry,
agentConfiguration);

return originalConfiguration;
}

private boolean isAgentConfigModelAllowUnknownProperties(String type, String service) {
switch (type) {
case QUERY_VECTOR_DB:
return false;
case VECTOR_DB_SINK:
{
final VectorDatabaseWriterConfig vectorDatabaseSinkConfig =
SUPPORTED_VECTOR_DB_SINK_DATASOURCES.get(service);
if (vectorDatabaseSinkConfig == null) {
throw new IllegalArgumentException(
"Unsupported vector database service: "
+ service
+ ". Supported services are: "
+ SUPPORTED_VECTOR_DB_SINK_DATASOURCES.keySet());
}
return vectorDatabaseSinkConfig.isAgentConfigModelAllowUnknownProperties();
}
default:
throw new IllegalStateException();
}
}

private Class getAgentConfigModelClass(String type, String service) {
switch (type) {
case QUERY_VECTOR_DB:
return QueryVectorDBConfig.class;
case VECTOR_DB_SINK:
{
final VectorDatabaseWriterConfig vectorDatabaseSinkConfig =
SUPPORTED_VECTOR_DB_SINK_DATASOURCES.get(service);
if (vectorDatabaseSinkConfig == null) {
throw new IllegalArgumentException(
"Unsupported vector database service: "
+ service
+ ". Supported services are: "
+ SUPPORTED_VECTOR_DB_SINK_DATASOURCES.keySet());
}
return vectorDatabaseSinkConfig.getAgentConfigModelClass();
}
default:
throw new IllegalStateException();
}
}

private void generateDataSourceConfiguration(
String resourceId,
Application applicationInstance,
Map<String, Object> configuration,
ComputeClusterRuntime computeClusterRuntime,
PluginsRegistry pluginsRegistry) {
PluginsRegistry pluginsRegistry,
AgentConfiguration agentConfiguration) {

Resource resource = applicationInstance.getResources().get(resourceId);
log.info("Generating datasource configuration for {}", resourceId);
if (resource != null) {
Map<String, Object> resourceImplementation =
Map<String, Object> resourceConfiguration =
computeClusterRuntime.getResourceImplementation(resource, pluginsRegistry);
if (!resource.type().equals("datasource")
&& !resource.type().equals("vector-database")) {
Expand All @@ -108,57 +198,60 @@ private void generateDataSourceConfiguration(
+ resourceId
+ "' is not type=datasource or type=vector-database");
}
if (configuration.containsKey("datasource")) {
throw new IllegalArgumentException("Only one datasource is supported");
configuration.put("datasource", resourceConfiguration);
final String type = agentConfiguration.getType();
final String service = (String) resourceConfiguration.get("service");
final Class modelClass = getAgentConfigModelClass(type, service);
if (modelClass != null) {
ClassConfigValidator.validateAgentModelFromClass(
agentConfiguration,
modelClass,
agentConfiguration.getConfiguration(),
isAgentConfigModelAllowUnknownProperties(type, service));
}
configuration.put("datasource", resourceImplementation);
} else {
throw new IllegalArgumentException("Resource '" + resourceId + "' not found");
}
}

@Override
protected Class getAgentConfigModelClass(String type) {
return switch (type) {
case QUERY_VECTOR_DB -> QueryVectorDBConfig.class;
case VECTOR_DB_SINK -> VectorDBSinkConfig.class;
default -> throw new IllegalStateException(type);
};
}

@Override
protected boolean isAgentConfigModelAllowUnknownProperties(String type) {
return switch (type) {
case QUERY_VECTOR_DB -> false;
case VECTOR_DB_SINK -> true;
default -> throw new IllegalStateException(type);
};
}

@AgentConfig(
name = "Query a vector database",
description =
"""
Query a vector database using Vector Search capabilities.
""")
Query a vector database using Vector Search capabilities.
""")
@Data
public static class QueryVectorDBConfig extends QueryConfiguration {}

@AgentConfig(
name = "Vector database sink",
description =
"""
Store vectors in a vector database.
Configuration properties depends on the vector database implementation, specified by the "datasource" property.
""")
@Data
public static class VectorDBSinkConfig {
@ConfigProperty(
description =
"""
The defined datasource ID to use to store the vectors.
""",
required = true)
private String datasource;
@Override
public Map<String, AgentConfigurationModel> generateSupportedTypesDocumentation() {
Map<String, AgentConfigurationModel> result = new LinkedHashMap<>();
result.put(
QUERY_VECTOR_DB,
ClassConfigValidator.generateAgentModelFromClass(QueryVectorDBConfig.class));

for (Map.Entry<String, VectorDatabaseWriterConfig> datasource :
SUPPORTED_VECTOR_DB_SINK_DATASOURCES.entrySet()) {
final String service = datasource.getKey();
AgentConfigurationModel value =
ClassConfigValidator.generateAgentModelFromClass(
datasource.getValue().getAgentConfigModelClass());
value = deepCopy(value);
value.getProperties()
.get("datasource")
.setDescription(
"Resource id. The target resource must be type: 'datasource' or 'vector-database' and "
+ "service: '"
+ service
+ "'.");
value.setType(VECTOR_DB_SINK);
result.put(VECTOR_DB_SINK + "_" + service, value);
}
return result;
}

@SneakyThrows
private static AgentConfigurationModel deepCopy(AgentConfigurationModel instance) {
return MAPPER.readValue(MAPPER.writeValueAsBytes(instance), AgentConfigurationModel.class);
}
}
Loading
Loading