diff --git a/docs/v1/qdrant.md b/docs/v1/qdrant.md index 0ece18e..26bb5ba 100644 --- a/docs/v1/qdrant.md +++ b/docs/v1/qdrant.md @@ -19,7 +19,7 @@ Core Features: - Vector similarity search with abstracted SearchResult interface - Type\-safe collection creation and existence checks - Support for payload metadata and optional vector retrieval -- Extensible abstraction layer for alternate vector stores \(e.g., Pinecone, Postgres\) +- Extensible abstraction layer for alternate vector stores \(e.g. pgVector\) Basic Usage: diff --git a/v1/qdrant/client.go b/v1/qdrant/client.go index a176403..817cd3f 100644 --- a/v1/qdrant/client.go +++ b/v1/qdrant/client.go @@ -97,6 +97,10 @@ func NewQdrantClient(p QdrantParams) (*QdrantClient, error) { // // It should be lightweight and fast — typically used during startup or readiness probes. func (c *QdrantClient) healthCheck() error { + if !c.started { + return fmt.Errorf("[Qdrant] client not started") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -114,15 +118,25 @@ func (c *QdrantClient) healthCheck() error { return nil } +// Client returns the underlying Qdrant SDK client. +// This is useful for direct access to low-level operations. +func (c *QdrantClient) Client() *qdrant.Client { + return c.api +} + // Close ────────────────────────────────────────────────────────────── // Close // ────────────────────────────────────────────────────────────── // // Close gracefully shuts down the Qdrant client. // -// Since the official Qdrant Go SDK doesn’t maintain persistent connections, +// Since the official Qdrant Go SDK doesn't maintain persistent connections, // this is currently a no-op. It exists for lifecycle symmetry and future safety. 
func (c *QdrantClient) Close() error { + if !c.started { + return nil + } + log.Println("[Qdrant] closing client (no-op)") return nil } diff --git a/v1/qdrant/configs.go b/v1/qdrant/configs.go index cffee45..11a8e8f 100644 --- a/v1/qdrant/configs.go +++ b/v1/qdrant/configs.go @@ -31,6 +31,12 @@ type Config struct { // Optional authentication token for secured deployments. ApiKey string `yaml:"api_key" env:"QDRANT_API_KEY"` + // ──────────────────────────────────────────────────────────────────────────── + // Connection Settings + // ──────────────────────────────────────────────────────────────────────────── + // Note: The following fields are reserved for future use. + // They are not currently passed to the Qdrant SDK client. + // Maximum request duration before timing out. Timeout time.Duration `yaml:"timeout" env:"QDRANT_TIMEOUT"` diff --git a/v1/qdrant/converter.go b/v1/qdrant/converter.go new file mode 100644 index 0000000..0e901e0 --- /dev/null +++ b/v1/qdrant/converter.go @@ -0,0 +1,371 @@ +package qdrant + +import ( + "fmt" + "strings" + "time" + + "github.com/Aleph-Alpha/std/v1/vectordb" + qdrant "github.com/qdrant/go-client/qdrant" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// UserPayloadPrefix is the prefix for user-defined metadata fields. +// User fields are stored under "custom." in the Qdrant payload. +const UserPayloadPrefix = "custom" + +// BuildPayload creates a Qdrant payload with separated internal and user fields. +// Internal fields are stored at the top level, while user fields are stored under +// the "custom" prefix. 
+// +// Example: +// +// payload := BuildPayload( +// map[string]any{"search_store_id": "store123"}, +// map[string]any{"document_id": "doc456"}, +// ) +// // Result: {"search_store_id": "store123", "custom": {"document_id": "doc456"}} +func BuildPayload(internal map[string]any, user map[string]any) map[string]any { + payload := make(map[string]any) + + // Add internal fields at top-level + for k, v := range internal { + payload[k] = v + } + + // Add user fields under "custom" prefix + if len(user) > 0 { + payload[UserPayloadPrefix] = user + } + + return payload +} + +// ── Filter Conversion ──────────────────────────────────────────────────────── + +// convertVectorDBFilterSet converts a vectordb.FilterSet to a Qdrant filter. +func convertVectorDBFilterSet(filters *vectordb.FilterSet) *qdrant.Filter { + if filters == nil { + return nil + } + + filter := &qdrant.Filter{} + + if filters.Must != nil { + filter.Must = convertVectorDBConditionSet(filters.Must) + } + if filters.Should != nil { + filter.Should = convertVectorDBConditionSet(filters.Should) + } + if filters.MustNot != nil { + filter.MustNot = convertVectorDBConditionSet(filters.MustNot) + } + + // Return nil if no conditions were added + if len(filter.Must) == 0 && len(filter.Should) == 0 && len(filter.MustNot) == 0 { + return nil + } + + return filter +} + +// convertVectorDBConditionSet converts a vectordb.ConditionSet to Qdrant conditions. +func convertVectorDBConditionSet(cs *vectordb.ConditionSet) []*qdrant.Condition { + if cs == nil { + return nil + } + + var conditions []*qdrant.Condition + for _, c := range cs.Conditions { + conds := convertVectorDBCondition(c) + for _, cond := range conds { + if cond != nil { + conditions = append(conditions, cond) + } + } + } + return conditions +} + +// convertVectorDBCondition converts a single vectordb.FilterCondition to Qdrant conditions. 
+func convertVectorDBCondition(c vectordb.FilterCondition) []*qdrant.Condition { + switch cond := c.(type) { + case *vectordb.MatchCondition: + return convertVectorDBMatchCondition(cond) + case *vectordb.MatchAnyCondition: + return convertVectorDBMatchAnyCondition(cond) + case *vectordb.MatchExceptCondition: + return convertVectorDBMatchExceptCondition(cond) + case *vectordb.NumericRangeCondition: + return convertVectorDBNumericRangeCondition(cond) + case *vectordb.TimeRangeCondition: + return convertVectorDBTimeRangeCondition(cond) + case *vectordb.IsNullCondition: + return convertVectorDBIsNullCondition(cond) + case *vectordb.IsEmptyCondition: + return convertVectorDBIsEmptyCondition(cond) + default: + return nil + } +} + +func convertVectorDBMatchCondition(c *vectordb.MatchCondition) []*qdrant.Condition { + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + switch v := c.Value.(type) { + case string: + return []*qdrant.Condition{qdrant.NewMatch(key, v)} + case bool: + return []*qdrant.Condition{qdrant.NewMatchBool(key, v)} + case int: + return []*qdrant.Condition{qdrant.NewMatchInt(key, int64(v))} + case int64: + return []*qdrant.Condition{qdrant.NewMatchInt(key, v)} + case float64: + // Handle JSON numbers which are float64 by default + return []*qdrant.Condition{qdrant.NewMatchInt(key, int64(v))} + default: + return nil + } +} + +func convertVectorDBMatchAnyCondition(c *vectordb.MatchAnyCondition) []*qdrant.Condition { + if len(c.Values) == 0 { + return nil + } + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + + // Detect type from first value + switch c.Values[0].(type) { + case string: + strs := make([]string, len(c.Values)) + for i, v := range c.Values { + if s, ok := v.(string); ok { + strs[i] = s + } + } + return []*qdrant.Condition{qdrant.NewMatchKeywords(key, strs...)} + case int, int64, float64: + ints := make([]int64, len(c.Values)) + for i, v := range c.Values { + switch n := v.(type) { + case int: + ints[i] = int64(n) + case int64: + 
ints[i] = n + case float64: + ints[i] = int64(n) + } + } + return []*qdrant.Condition{qdrant.NewMatchInts(key, ints...)} + } + return nil +} + +func convertVectorDBMatchExceptCondition(c *vectordb.MatchExceptCondition) []*qdrant.Condition { + if len(c.Values) == 0 { + return nil + } + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + + // Detect type from first value + switch c.Values[0].(type) { + case string: + strs := make([]string, len(c.Values)) + for i, v := range c.Values { + if s, ok := v.(string); ok { + strs[i] = s + } + } + return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, strs...)} + case int, int64, float64: + ints := make([]int64, len(c.Values)) + for i, v := range c.Values { + switch n := v.(type) { + case int: + ints[i] = int64(n) + case int64: + ints[i] = n + case float64: + ints[i] = int64(n) + } + } + return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, ints...)} + } + return nil +} + +func convertVectorDBNumericRangeCondition(c *vectordb.NumericRangeCondition) []*qdrant.Condition { + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + rangeFilter := &qdrant.Range{ + Gt: c.Range.Gt, + Gte: c.Range.Gte, + Lt: c.Range.Lt, + Lte: c.Range.Lte, + } + + if rangeFilter.Gt == nil && rangeFilter.Gte == nil && + rangeFilter.Lt == nil && rangeFilter.Lte == nil { + return nil + } + + return []*qdrant.Condition{qdrant.NewRange(key, rangeFilter)} +} + +func convertVectorDBTimeRangeCondition(c *vectordb.TimeRangeCondition) []*qdrant.Condition { + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + dateRange := &qdrant.DatetimeRange{ + Gt: toVectorDBTimestamp(c.Range.Gt), + Gte: toVectorDBTimestamp(c.Range.Gte), + Lt: toVectorDBTimestamp(c.Range.Lt), + Lte: toVectorDBTimestamp(c.Range.Lte), + } + + if dateRange.Gt == nil && dateRange.Gte == nil && + dateRange.Lt == nil && dateRange.Lte == nil { + return nil + } + + return []*qdrant.Condition{qdrant.NewDatetimeRange(key, dateRange)} +} + +func convertVectorDBIsNullCondition(c 
*vectordb.IsNullCondition) []*qdrant.Condition { + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + return []*qdrant.Condition{qdrant.NewIsNull(key)} +} + +func convertVectorDBIsEmptyCondition(c *vectordb.IsEmptyCondition) []*qdrant.Condition { + key := resolveVectorDBFieldKey(c.Field, c.FieldType) + return []*qdrant.Condition{qdrant.NewIsEmpty(key)} +} + +// resolveVectorDBFieldKey returns the full field path based on FieldType. +// Internal fields: "search_store_id" -> "search_store_id" +// User fields: "document_id" -> "custom.document_id" +func resolveVectorDBFieldKey(key string, fieldType vectordb.FieldType) string { + if fieldType == vectordb.UserField { + if strings.HasPrefix(key, UserPayloadPrefix+".") { + return key + } + return UserPayloadPrefix + "." + key + } + return key +} + +func toVectorDBTimestamp(t *time.Time) *timestamppb.Timestamp { + if t == nil { + return nil + } + return timestamppb.New(*t) +} + +// ── Result Conversion ──────────────────────────────────────────────────────── + +// parseVectorDBSearchResults converts Qdrant response to vectordb.SearchResult slice. +func parseVectorDBSearchResults(resp []*qdrant.ScoredPoint) ([]vectordb.SearchResult, error) { + results := make([]vectordb.SearchResult, 0, len(resp)) + for _, r := range resp { + id, err := extractVectorDBPointID(r.Id) + if err != nil { + return nil, err + } + + results = append(results, vectordb.SearchResult{ + ID: id, + Score: r.Score, + Payload: convertVectorDBPayload(r.Payload), + }) + } + return results, nil +} + +// extractVectorDBPointID extracts a string ID from Qdrant's PointId type. 
+func extractVectorDBPointID(id *qdrant.PointId) (string, error) { + if id == nil { + return "", fmt.Errorf("nil point ID") + } + switch v := id.PointIdOptions.(type) { + case *qdrant.PointId_Num: + return fmt.Sprintf("%d", v.Num), nil + case *qdrant.PointId_Uuid: + return v.Uuid, nil + default: + return "", fmt.Errorf("unexpected PointId type: %T", v) + } +} + +// convertVectorDBPayload converts Qdrant's protobuf payload to a generic map. +func convertVectorDBPayload(payload map[string]*qdrant.Value) map[string]any { + if payload == nil { + return nil + } + result := make(map[string]any, len(payload)) + for k, v := range payload { + result[k] = extractVectorDBValue(v) + } + return result +} + +// extractVectorDBValue recursively converts a Qdrant Value to a Go native type. +func extractVectorDBValue(v *qdrant.Value) any { + if v == nil { + return nil + } + switch val := v.Kind.(type) { + case *qdrant.Value_StringValue: + return val.StringValue + case *qdrant.Value_IntegerValue: + return val.IntegerValue + case *qdrant.Value_DoubleValue: + return val.DoubleValue + case *qdrant.Value_BoolValue: + return val.BoolValue + case *qdrant.Value_NullValue: + return nil + case *qdrant.Value_StructValue: + if val.StructValue == nil { + return nil + } + return convertVectorDBPayload(val.StructValue.Fields) + case *qdrant.Value_ListValue: + if val.ListValue == nil { + return nil + } + items := make([]any, len(val.ListValue.Values)) + for i, item := range val.ListValue.Values { + items[i] = extractVectorDBValue(item) + } + return items + default: + return nil + } +} + +// convertVectorDBFilterSets converts an array of FilterSets to a single Qdrant filter. +// Multiple filter sets are combined with AND logic (all must match). 
+func convertVectorDBFilterSets(filters []*vectordb.FilterSet) *qdrant.Filter { + if len(filters) == 0 { + return nil + } + + // Single filter set - convert directly + if len(filters) == 1 { + return convertVectorDBFilterSet(filters[0]) + } + + // Multiple filter sets - combine with AND + var allConditions []*qdrant.Condition + for _, fs := range filters { + converted := convertVectorDBFilterSet(fs) + if converted != nil { + allConditions = append(allConditions, &qdrant.Condition{ + ConditionOneOf: &qdrant.Condition_Filter{Filter: converted}, + }) + } + } + + if len(allConditions) == 0 { + return nil + } + + return &qdrant.Filter{Must: allConditions} +} diff --git a/v1/qdrant/doc.go b/v1/qdrant/doc.go index 7ee0879..78cf98a 100644 --- a/v1/qdrant/doc.go +++ b/v1/qdrant/doc.go @@ -5,95 +5,124 @@ // collection management, embedding insertion, similarity search, and deletion. It integrates // seamlessly with the fx dependency injection framework and supports builder-style configuration. // -// Core Features: +// # Core Features // // - Managed Qdrant client lifecycle with Fx integration // - Config struct supporting environment and YAML loading // - Automatic health checks on client initialization // - Safe, batched insertion of embeddings with configurable batch size -// - Vector similarity search with abstracted SearchResult interface +// - Database-agnostic interface via vectordb.Service // - Type-safe collection creation and existence checks // - Support for payload metadata and optional vector retrieval -// - Extensible abstraction layer for alternate vector stores (e.g., Pinecone, Postgres) +// - Extensible abstraction layer for alternate vector stores (e.g. pgVector) // -// Basic Usage: +// # VectorDB Interface // -// import "github.com/Aleph-Alpha/std/v1/qdrant" +// This package includes [Adapter] which implements the database-agnostic +// [vectordb.Service] interface. 
Use this for new projects to enable +// easy switching between vector databases: +// +// import ( +// "github.com/Aleph-Alpha/std/v1/vectordb" +// "github.com/Aleph-Alpha/std/v1/qdrant" +// ) +// +// // Create your existing QdrantClient +// qc, _ := qdrant.NewQdrantClient(qdrant.QdrantParams{ +// Config: &qdrant.Config{ +// Endpoint: "localhost", +// Port: 6334, +// }, +// }) +// +// // Create adapter for DB-agnostic usage +// var db vectordb.Service = qdrant.NewAdapter(qc.Client()) +// +// This allows switching between vector databases (Qdrant, pgVector) without +// changing application code. +// +// # Basic Usage +// +// import ( +// "github.com/Aleph-Alpha/std/v1/qdrant" +// "github.com/Aleph-Alpha/std/v1/vectordb" +// ) // // // Create a new client // client, err := qdrant.NewQdrantClient(qdrant.QdrantParams{ -// Config: &qdrant.Config{ -// Endpoint: "localhost:6334", -// ApiKey: "", -// }, +// Config: &qdrant.Config{ +// Endpoint: "localhost", +// Port: 6334, +// }, // }) // if err != nil { -// log.Fatal(err) +// log.Fatal(err) // } // +// // Create adapter +// adapter := qdrant.NewAdapter(client.Client()) +// // collectionName := "documents" // -// // Insert single embedding -// input := qdrant.EmbeddingInput{ -// ID: "doc_1", -// Vector: []float32{0.12, 0.43, 0.85}, -// Meta: map[string]any{"title": "My Document"}, -// } -// if err := client.Insert(ctx, collectionName, input); err != nil { -// log.Fatal(err) +// // Ensure collection exists +// if err := adapter.EnsureCollection(ctx, collectionName, 1536); err != nil { +// log.Fatal(err) // } // -// // Batch insert embeddings -// batch := []qdrant.EmbeddingInput{input1, input2, input3} -// if err := client.BatchInsert(ctx, collectionName, batch); err != nil { -// log.Fatal(err) +// // Insert embeddings +// inputs := []vectordb.EmbeddingInput{ +// { +// ID: "doc_1", +// Vector: []float32{0.12, 0.43, 0.85, ...}, +// Payload: map[string]any{"title": "My Document"}, +// }, +// } +// if err := adapter.Insert(ctx, 
collectionName, inputs); err != nil { +// log.Fatal(err) // } // // // Perform similarity search -// results, err := client.Search(ctx, qdrant.SearchRequest{ -// CollectionName: collectionName, -// Vector: queryVector, -// TopK: 5, +// results, err := adapter.Search(ctx, vectordb.SearchRequest{ +// CollectionName: collectionName, +// Vector: queryVector, +// TopK: 5, // }) // for _, res := range results[0] { -// fmt.Printf("ID=%s Score=%.4f\n", res.GetID(), res.GetScore()) +// fmt.Printf("ID=%s Score=%.4f\n", res.ID, res.Score) // } // -// FX Module Integration: +// # FX Module Integration // // The package exposes an Fx module for automatic dependency injection: // // app := fx.New( -// qdrant.FXModule, -// // other modules... +// qdrant.FXModule, +// // other modules... // ) // app.Run() // -// Abstractions: +// # Search Results // -// The package defines a lightweight SearchResultInterface that encapsulates -// search results via methods such as GetID(), GetScore(), GetMeta(), and GetCollectionName(). -// The underlying concrete type remains SearchResult, allowing both strong typing internally -// and loose coupling externally. 
+// Search results are returned as [vectordb.SearchResult] structs with public fields: // -// Example: -// -// type SearchResultInterface interface { -// GetID() string -// GetScore() float32 -// GetMeta() map[string]*qdrant.Value -// GetCollectionName() string +// type SearchResult struct { +// ID string // Unique identifier of the matched point +// Score float32 // Similarity score +// Payload map[string]any // Metadata stored with the vector +// Vector []float32 // Stored embedding (if requested) +// CollectionName string // Source collection name // } // -// type SearchResult struct { /* implements SearchResultInterface */ } +// Access fields directly (no getter methods needed): // -// // Function signature: -// func (c *QdrantClient) Search(ctx context.Context, vector []float32, topK int) ([]SearchResultInterface, error) +// for _, result := range results[0] { +// fmt.Println(result.ID, result.Score, result.Payload["title"]) +// } // // # Filtering // -// The package provides a comprehensive, type-safe filtering system for vector searches. -// Filters support boolean logic (AND, OR, NOT) and various condition types. +// Filters are defined in the [vectordb] package and support boolean logic (AND, OR, NOT). +// The qdrant adapter converts these to native Qdrant filters automatically. 
// // Filter Structure: // @@ -103,17 +132,13 @@ // MustNot *ConditionSet // NOT - none of the conditions should match // } // -// Condition Types: -// -// - TextCondition: Exact string match -// - BoolCondition: Exact boolean match -// - IntCondition: Exact integer match -// - TextAnyCondition: String IN operator (match any of values) -// - IntAnyCondition: Integer IN operator -// - TextExceptCondition: String NOT IN operator -// - IntExceptCondition: Integer NOT IN operator -// - TimeRangeCondition: DateTime range filter (gte, lte, gt, lt) -// - NumericRangeCondition: Numeric range filter +// Condition Types (all in vectordb package): +// +// - MatchCondition: Exact match (string, bool, int64) +// - MatchAnyCondition: IN operator (match any of values) +// - MatchExceptCondition: NOT IN operator +// - NumericRangeCondition: Numeric range filter (gt, gte, lt, lte) +// - TimeRangeCondition: DateTime range filter // - IsNullCondition: Check if field is null // - IsEmptyCondition: Check if field is empty, null, or missing // @@ -122,176 +147,172 @@ // The package distinguishes between system-managed and user-defined metadata: // // const ( -// InternalField FieldType = iota // Top-level: "search_store_id" -// UserField // Nested: "custom.document_id" +// InternalField FieldType = iota // Top-level: "status" +// UserField // Prefixed: "custom.document_id" // ) // // User fields are automatically prefixed with "custom." when querying Qdrant. 
// -// Basic Filter Example: +// # Filter Examples Using Convenience Constructors // -// // Filter: city = "London" AND active = true -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "city", Value: "London"}, -// qdrant.BoolCondition{Key: "active", Value: true}, -// }, -// }, -// } +// The vectordb package provides convenience constructors for clean filter creation: +// +// Basic Filter (Must - AND logic): // -// results, err := client.Search(ctx, qdrant.SearchRequest{ +// // Filter: city = "London" AND active = true +// results, err := adapter.Search(ctx, vectordb.SearchRequest{ // CollectionName: "documents", // Vector: queryVector, // TopK: 10, -// Filters: filters, +// Filters: []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewMatch("city", "London"), +// vectordb.NewMatch("active", true), +// ), +// ), +// }, // }) // // OR Conditions (Should): // // // Filter: city = "London" OR city = "Berlin" -// filters := &qdrant.FilterSet{ -// Should: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "city", Value: "London"}, -// qdrant.TextCondition{Key: "city", Value: "Berlin"}, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Should( +// vectordb.NewMatch("city", "London"), +// vectordb.NewMatch("city", "Berlin"), +// ), +// ), // } // // IN Operator (MatchAny): // // // Filter: city IN ["London", "Berlin", "Paris"] -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextAnyCondition{ -// Key: "city", -// Values: []string{"London", "Berlin", "Paris"}, -// }, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewMatchAny("city", "London", "Berlin", "Paris"), +// ), +// ), +// } +// +// Numeric Range Filter: +// +// // Filter: price >= 
100 AND price < 500 +// min, max := float64(100), float64(500) +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewNumericRange("price", vectordb.NumericRange{ +// Gte: &min, +// Lt: &max, +// }), +// ), +// ), // } // // Time Range Filter: // +// // Filter: created_at >= yesterday AND created_at < now // now := time.Now() // yesterday := now.Add(-24 * time.Hour) -// -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TimeRangeCondition{ -// Key: "created_at", -// Value: qdrant.TimeRange{ -// Gte: &yesterday, -// Lt: &now, -// }, -// }, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewTimeRange("created_at", vectordb.TimeRange{ +// Gte: &yesterday, +// Lt: &now, +// }), +// ), +// ), // } // // Complex Filter (Combined Clauses): // -// // Filter: (status = "published") AND (category = "tech" OR "science") AND NOT (deleted = true) -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "status", Value: "published"}, -// }, -// }, -// Should: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "category", Value: "tech", FieldType: qdrant.UserField}, -// qdrant.TextCondition{Key: "category", Value: "science", FieldType: qdrant.UserField}, -// }, -// }, -// MustNot: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.BoolCondition{Key: "deleted", Value: true}, -// }, -// }, +// // Filter: status = "published" AND (tag = "ml" OR tag = "ai") AND NOT deleted = true +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must(vectordb.NewMatch("status", "published")), +// vectordb.Should( +// vectordb.NewMatch("tag", "ml"), +// vectordb.NewMatch("tag", "ai"), +// ), +// vectordb.MustNot(vectordb.NewMatch("deleted", true)), +// ), 
// } // -// UUID Filtering: +// User-Defined Fields: // -// UUIDs are filtered as strings using TextCondition: +// For fields stored under a custom prefix, use the User* constructors: // -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{ -// Key: "document_id", -// Value: "f47ac10b-58cc-4372-a567-0e02b2c3d479", -// FieldType: qdrant.UserField, -// }, -// }, -// }, +// // Filter on user-defined field: custom.document_id = "doc-123" +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewUserMatch("document_id", "doc-123"), +// ), +// ), // } // -// Payload Structure Helper: -// -// The BuildPayload function creates properly structured payloads that separate -// internal and user fields: -// -// payload := qdrant.BuildPayload( -// map[string]any{"search_store_id": "store-123"}, // Internal (top-level) -// map[string]any{"document_id": "doc-456"}, // User (under "custom.") -// ) +// Multiple FilterSets (AND between sets): // -// // Result: -// // { -// // "search_store_id": "store-123", -// // "custom": { -// // "document_id": "doc-456" -// // } -// // } +// When you provide multiple FilterSets, they are combined with AND logic: // -// This structure ensures user-defined filter indexes are created at the correct -// path (custom.field_name). 
+// // Filter: (color = "red") AND (size < 20) +// lt := float64(20) +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet(vectordb.Must(vectordb.NewMatch("color", "red"))), +// vectordb.NewFilterSet(vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Lt: &lt}))), +// } // -// Configuration: +// # Configuration // // Qdrant can be configured via environment variables or YAML: // -// QDRANT_ENDPOINT=http://localhost:6334 +// QDRANT_ENDPOINT=localhost +// QDRANT_PORT=6334 // QDRANT_API_KEY=your-api-key // -// Performance Considerations: +// # Performance Considerations // -// The BatchInsert method automatically splits large embedding inserts into smaller -// upserts (default batch size = 500). This minimizes memory usage and avoids timeouts +// The Insert method automatically splits large embedding batches into smaller +// upserts (default batch size = 100). This minimizes memory usage and avoids timeouts // when ingesting large datasets. // -// Thread Safety: +// # Thread Safety // -// All exported methods on QdrantClient are safe for concurrent use by multiple goroutines. +// All exported methods on the Adapter are safe for concurrent use by multiple goroutines. // -// Testing: +// # Testing // -// For testing and mocking, application code should depend on the public interface types -// (e.g., SearchResultInterface, EmbeddingInput) instead of concrete Qdrant structs. -// This allows replacing the QdrantClient with in-memory or mock implementations in tests. 
+// For testing and mocking, depend on the [vectordb.Service] interface: // -// Example Mock: +// type MockVectorDB struct{} // -// type MockResult struct { -// id string -// score float32 -// meta map[string]any +// func (m *MockVectorDB) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, error) { +// return [][]vectordb.SearchResult{ +// {{ID: "doc-1", Score: 0.95, Payload: map[string]any{"title": "Test"}}}, +// }, nil // } -// func (m MockResult) GetID() string { return m.id } -// func (m MockResult) GetScore() float32 { return m.score } -// func (m MockResult) GetMeta() map[string]any { return m.meta } // -// Package Layout: +// // Use in tests: +// var db vectordb.Service = &MockVectorDB{} +// +// # Package Layout // // qdrant/ -// ├── client.go // Qdrant client implementation -// ├── operations.go // CRUD operations (Insert, Search, Delete, etc.) -// ├── filters.go // Type-safe filtering system -// ├── utils.go // Shared types and interfaces -// ├── configs.go // Configuration struct and builder methods +// ├── client.go // Qdrant client wrapper and lifecycle +// ├── operations.go // Service implementation (Adapter) +// ├── converter.go // vectordb ↔ Qdrant type conversion +// ├── utils.go // Qdrant-specific helper functions +// ├── configs.go // Configuration struct // └── fx_module.go // Fx dependency injection module +// +// # Related Packages +// +// - [vectordb]: Database-agnostic types and interfaces +// - [vectordb.FilterSet]: Filter structures for search queries +// - [vectordb.SearchResult]: Search result type +// - [vectordb.EmbeddingInput]: Input type for inserting vectors package qdrant diff --git a/v1/qdrant/filters.go b/v1/qdrant/filters.go deleted file mode 100644 index 20222e8..0000000 --- a/v1/qdrant/filters.go +++ /dev/null @@ -1,398 +0,0 @@ -package qdrant - -import ( - "encoding/json" - "strings" - "time" - - qdrant "github.com/qdrant/go-client/qdrant" - 
"google.golang.org/protobuf/types/known/timestamppb" -) - -// UserPayloadPrefix is the prefix for user-defined metadata fields -const UserPayloadPrefix = "custom" - -// FilterCondition is the interface for all filter conditions -type FilterCondition interface { - ToQdrantCondition() []*qdrant.Condition -} - -// FieldType indicates whether a field is internal or user-defined -type FieldType int - -const ( - // InternalField - system-managed fields stored at top-level - InternalField FieldType = iota - // UserField - user-defined fields stored under "custom." prefix - UserField -) - -// TimeRange represents a time-based filter condition -type TimeRange struct { - Gt *time.Time `json:"after,omitempty"` // Greater than this time - Gte *time.Time `json:"atOrAfter,omitempty"` // Greater than or equal to this time - Lt *time.Time `json:"before,omitempty"` // Less than this time - Lte *time.Time `json:"atOrBefore,omitempty"` // Less than or equal to this time -} - -// NumericRange represents a numeric range filter condition -type NumericRange struct { - Gt *float64 `json:"greaterThan,omitempty"` // Greater than - Gte *float64 `json:"greaterThanOrEqualTo,omitempty"` // Greater than or equal - Lt *float64 `json:"lessThan,omitempty"` // Less than - Lte *float64 `json:"lessThanOrEqualTo,omitempty"` // Less than or equal -} - -// MatchCondition represents an exact match condition for a field value. -// Supports string, bool, and int64 types. The FieldType defaults to InternalField -// if not specified, meaning the field is stored at the top level of the payload. -// Use UserField to indicate the field is stored under the "custom." prefix. -type MatchCondition[T comparable] struct { - Key string `json:"field"` - Value T `json:"equalTo"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -// ToQdrantCondition converts the MatchCondition to Qdrant conditions. -// Supports string, bool, and int64 types. Returns nil for unsupported types. 
-func (c MatchCondition[T]) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Value).(type) { - case string: - return []*qdrant.Condition{qdrant.NewMatch(key, v)} - case bool: - return []*qdrant.Condition{qdrant.NewMatchBool(key, v)} - case int64: - return []*qdrant.Condition{qdrant.NewMatchInt(key, v)} - default: - // Unsupported type - returns nil - return nil - } -} - -// MatchAnyCondition matches if value is one of the given values (IN operator). -// Applicable to keyword (string) and integer payloads. -// Returns nil if Values is empty. The FieldType defaults to InternalField if not specified. -type MatchAnyCondition[T string | int64] struct { - Key string `json:"field"` - Values []T `json:"anyOf"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c MatchAnyCondition[T]) ToQdrantCondition() []*qdrant.Condition { - if len(c.Values) == 0 { - return nil - } - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Values).(type) { - case []string: - return []*qdrant.Condition{qdrant.NewMatchKeywords(key, v...)} - case []int64: - return []*qdrant.Condition{qdrant.NewMatchInts(key, v...)} - default: - return nil - } -} - -// MatchExceptCondition matches if value is NOT one of the given values (NOT IN operator). -// Applicable to keyword (string) and integer payloads. -// Returns nil if Values is empty. The FieldType defaults to InternalField if not specified. 
-type MatchExceptCondition[T string | int64] struct { - Key string `json:"field"` - Values []T `json:"noneOf"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c MatchExceptCondition[T]) ToQdrantCondition() []*qdrant.Condition { - if len(c.Values) == 0 { - return nil - } - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Values).(type) { - case []string: - return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, v...)} - case []int64: - return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, v...)} - default: - return nil - } -} - -type TextCondition = MatchCondition[string] // Exact string match -type BoolCondition = MatchCondition[bool] // Exact boolean match -type IntCondition = MatchCondition[int64] // Exact integer match -type TextAnyCondition = MatchAnyCondition[string] // String IN operator -type IntAnyCondition = MatchAnyCondition[int64] // Integer IN operator -type TextExceptCondition = MatchExceptCondition[string] // String NOT IN -type IntExceptCondition = MatchExceptCondition[int64] // Integer NOT IN - -// TimeRangeCondition represents a time range filter condition -type TimeRangeCondition struct { - Key string `json:"field"` - Value TimeRange `json:"-"` - FieldType FieldType `json:"-"` -} - -func (c TimeRangeCondition) ToQdrantCondition() []*qdrant.Condition { - return buildDateTimeRangeConditions(resolveFieldKey(c.Key, c.FieldType), c.Value) -} - -func (c TimeRangeCondition) MarshalJSON() ([]byte, error) { - type Alias struct { - Field string `json:"field"` - After *time.Time `json:"after,omitempty"` - AtOrAfter *time.Time `json:"atOrAfter,omitempty"` - Before *time.Time `json:"before,omitempty"` - AtOrBefore *time.Time `json:"atOrBefore,omitempty"` - } - return json.Marshal(Alias{ - Field: c.Key, - After: c.Value.Gt, - AtOrAfter: c.Value.Gte, - Before: c.Value.Lt, - AtOrBefore: c.Value.Lte, - }) -} - -func (c *TimeRangeCondition) UnmarshalJSON(data []byte) error { - type Alias struct 
{ - Field string `json:"field"` - After *time.Time `json:"after,omitempty"` - AtOrAfter *time.Time `json:"atOrAfter,omitempty"` - Before *time.Time `json:"before,omitempty"` - AtOrBefore *time.Time `json:"atOrBefore,omitempty"` - } - var alias Alias - if err := json.Unmarshal(data, &alias); err != nil { - return err - } - c.Key = alias.Field - c.Value = TimeRange{ - Gt: alias.After, - Gte: alias.AtOrAfter, - Lt: alias.Before, - Lte: alias.AtOrBefore, - } - return nil -} - -// NumericRangeCondition represents a numeric range filter -type NumericRangeCondition struct { - Key string `json:"field"` - Value NumericRange `json:"-"` - FieldType FieldType `json:"-"` -} - -func (c NumericRangeCondition) ToQdrantCondition() []*qdrant.Condition { - return buildNumericRangeConditions(resolveFieldKey(c.Key, c.FieldType), c.Value) -} - -func (c NumericRangeCondition) MarshalJSON() ([]byte, error) { - type Alias struct { - Field string `json:"field"` - GreaterThan *float64 `json:"greaterThan,omitempty"` - GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` - LessThan *float64 `json:"lessThan,omitempty"` - LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` - } - return json.Marshal(Alias{ - Field: c.Key, - GreaterThan: c.Value.Gt, - GreaterThanOrEqualTo: c.Value.Gte, - LessThan: c.Value.Lt, - LessThanOrEqualTo: c.Value.Lte, - }) -} - -func (c *NumericRangeCondition) UnmarshalJSON(data []byte) error { - type Alias struct { - Field string `json:"field"` - GreaterThan *float64 `json:"greaterThan,omitempty"` - GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` - LessThan *float64 `json:"lessThan,omitempty"` - LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` - } - var alias Alias - if err := json.Unmarshal(data, &alias); err != nil { - return err - } - c.Key = alias.Field - c.Value = NumericRange{ - Gt: alias.GreaterThan, - Gte: alias.GreaterThanOrEqualTo, - Lt: alias.LessThan, - Lte: alias.LessThanOrEqualTo, - } - return 
nil -} - -// resolveFieldKey returns the full field path based on FieldType -// Internal fields: "search_store_id" -> "search_store_id" -// User fields: "document_id" -> "custom.document_id" -func resolveFieldKey(key string, fieldType FieldType) string { - if fieldType == UserField { - // Prevent double-prefixing - if strings.HasPrefix(key, UserPayloadPrefix+".") { - return key - } - return UserPayloadPrefix + "." + key - } - return key -} - -// ConditionSet holds conditions for a single clause -type ConditionSet struct { - Conditions []FilterCondition `json:"conditions,omitempty"` -} - -// FilterSet supports Must (AND), Should (OR), and MustNot (NOT) clauses. -// Use with SearchRequest.Filters to filter search results. -// -// Example: -// -// filters := &FilterSet{ -// Must: &ConditionSet{ -// Conditions: []FilterCondition{ -// TextCondition{Key: "city", Value: "London"}, -// }, -// }, -// } -type FilterSet struct { - Must *ConditionSet `json:"must,omitempty"` // AND - all conditions must match - Should *ConditionSet `json:"should,omitempty"` // OR - at least one condition must match - MustNot *ConditionSet `json:"mustNot,omitempty"` // NOT - none of the conditions should match -} - -// buildFilter constructs a Qdrant filter from FilterSet -func buildFilter(filters *FilterSet) *qdrant.Filter { - if filters == nil { - return nil - } - - filter := &qdrant.Filter{} - - if filters.Must != nil { - filter.Must = buildConditions(filters.Must) - } - - if filters.Should != nil { - filter.Should = buildConditions(filters.Should) - } - - if filters.MustNot != nil { - filter.MustNot = buildConditions(filters.MustNot) - } - - // Return nil if no conditions were added - if len(filter.Must) == 0 && len(filter.Should) == 0 && len(filter.MustNot) == 0 { - return nil - } - - return filter -} - -// buildConditions converts a ConditionSet to Qdrant conditions -// Filters out nil conditions that may be returned by invalid conditions (e.g., empty ranges) -func buildConditions(cs 
*ConditionSet) []*qdrant.Condition { - if cs == nil { - return nil - } - - var conditions []*qdrant.Condition - for _, c := range cs.Conditions { - conds := c.ToQdrantCondition() - for _, cond := range conds { - if cond != nil { - conditions = append(conditions, cond) - } - } - } - return conditions -} - -// buildDateTimeRangeConditions creates datetime range conditions -func buildDateTimeRangeConditions(key string, tr TimeRange) []*qdrant.Condition { - dateRange := &qdrant.DatetimeRange{ - Gt: toTimestamp(tr.Gt), - Gte: toTimestamp(tr.Gte), - Lt: toTimestamp(tr.Lt), - Lte: toTimestamp(tr.Lte), - } - - // Check if any field is set - if dateRange.Gt == nil && dateRange.Gte == nil && dateRange.Lt == nil && dateRange.Lte == nil { - return nil - } - - return []*qdrant.Condition{qdrant.NewDatetimeRange(key, dateRange)} -} - -// buildNumericRangeConditions creates numeric range conditions -func buildNumericRangeConditions(key string, nr NumericRange) []*qdrant.Condition { - rangeFilter := &qdrant.Range{ - Gt: nr.Gt, - Gte: nr.Gte, - Lt: nr.Lt, - Lte: nr.Lte, - } - - // Check if any field is set - if rangeFilter.Gt == nil && rangeFilter.Gte == nil && rangeFilter.Lt == nil && rangeFilter.Lte == nil { - return nil - } - - return []*qdrant.Condition{qdrant.NewRange(key, rangeFilter)} -} - -// toTimestamp converts a *time.Time to *timestamppb.Timestamp (nil-safe) -func toTimestamp(t *time.Time) *timestamppb.Timestamp { - if t == nil { - return nil - } - return timestamppb.New(*t) -} - -// IsNullCondition checks if a field is null -type IsNullCondition struct { - Key string `json:"field"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c IsNullCondition) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - return []*qdrant.Condition{qdrant.NewIsNull(key)} -} - -// IsEmptyCondition checks if a field is empty (does not exist, null, or []) -type IsEmptyCondition struct { - Key string 
`json:"field"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c IsEmptyCondition) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - return []*qdrant.Condition{qdrant.NewIsEmpty(key)} -} - -// === Payload Helpers === - -// BuildPayload creates a Qdrant payload with separated internal and user fields. -// Internal fields are stored at the top level, while user fields are stored under -// the "custom" prefix. If internal contains a "custom" key, it will be overwritten -// by the user fields map. -func BuildPayload(internal map[string]any, user map[string]any) map[string]any { - payload := make(map[string]any) - - // Add internal fields at top-level - for k, v := range internal { - payload[k] = v - } - - // Add user fields under "custom" prefix - // Note: This will overwrite any "custom" key that was in the internal map - if len(user) > 0 { - payload[UserPayloadPrefix] = user - } - - return payload -} diff --git a/v1/qdrant/filters_test.go b/v1/qdrant/filters_test.go deleted file mode 100644 index 495b583..0000000 --- a/v1/qdrant/filters_test.go +++ /dev/null @@ -1,1148 +0,0 @@ -package qdrant - -import ( - "encoding/json" - "testing" - "time" -) - -func TestBuildFilter_NilFilterSet(t *testing.T) { - result := buildFilter(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_EmptyFilterSet(t *testing.T) { - filters := &FilterSet{} - result := buildFilter(filters) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_EmptyConditionSet(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{}, - }, - } - result := buildFilter(filters) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_MustWithTextCondition(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - 
TextCondition{Key: "city", Value: "London"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } - if len(result.Should) != 0 { - t.Errorf("expected 0 Should conditions, got %d", len(result.Should)) - } - if len(result.MustNot) != 0 { - t.Errorf("expected 0 MustNot conditions, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_ShouldWithMultipleTextConditions(t *testing.T) { - // city = "London" OR city = "Berlin" - filters := &FilterSet{ - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TextCondition{Key: "city", Value: "Berlin"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Should) != 2 { - t.Errorf("expected 2 Should conditions, got %d", len(result.Should)) - } -} - -func TestBuildFilter_MustNotWithBoolCondition(t *testing.T) { - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "archived", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_MixedConditionTypes(t *testing.T) { - // city = "London" AND active = true AND priority = 1 - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - BoolCondition{Key: "active", Value: true}, - IntCondition{Key: "priority", Value: 1}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } -} - -func TestBuildFilter_CombinedClauses(t 
*testing.T) { - // (city = "London" AND active = true) AND NOT archived - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - BoolCondition{Key: "active", Value: true}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "archived", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 2 { - t.Errorf("expected 2 Must conditions, got %d", len(result.Must)) - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_AllThreeClauses(t *testing.T) { - // Must AND Should AND MustNot - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "status", Value: "active"}, - }, - }, - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TextCondition{Key: "city", Value: "Berlin"}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "deleted", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } - if len(result.Should) != 2 { - t.Errorf("expected 2 Should conditions, got %d", len(result.Should)) - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_TimeRangeCondition(t *testing.T) { - startTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - }, - }, - }, - } - result := 
buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestBuildFilter_TimeRangeAllBounds(t *testing.T) { - gt := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) - gte := time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC) - lt := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - lte := time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC) - - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "updated_at", - Value: TimeRange{ - Gt: >, - Gte: >e, - Lt: <, - Lte: <e, - }, - }, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestBuildFilter_EmptyTimeRange(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{}, // All nil - }, - }, - }, - } - result := buildFilter(filters) - - // Empty TimeRange returns nil condition, so filter should be nil - if result != nil { - t.Errorf("expected nil for empty time range, got %v", result) - } -} - -func TestBuildConditions_NilConditionSet(t *testing.T) { - result := buildConditions(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildConditions_FiltersNilConditions(t *testing.T) { - // Test that buildConditions filters out nil conditions - cs := &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TimeRangeCondition{Key: "created_at", Value: TimeRange{}}, // Empty range returns nil - TextAnyCondition{Key: "status", Values: []string{}}, // Empty slice returns nil - BoolCondition{Key: "active", Value: true}, - }, - } - result := buildConditions(cs) - - // Should only have 2 conditions (TextCondition and 
BoolCondition) - // Empty TimeRange and empty TextAnyCondition should be filtered out - if len(result) != 2 { - t.Errorf("expected 2 conditions (nil ones filtered out), got %d", len(result)) - } -} - -func TestTextCondition_ToQdrantCondition(t *testing.T) { - c := TextCondition{Key: "city", Value: "London"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBoolCondition_ToQdrantCondition(t *testing.T) { - c := BoolCondition{Key: "active", Value: true} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntCondition_ToQdrantCondition(t *testing.T) { - c := IntCondition{Key: "priority", Value: 42} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_ToQdrantCondition(t *testing.T) { - now := time.Now() - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{Gte: &now}, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_EmptyRange(t *testing.T) { - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{}, // All nil - } - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty time range, got %v", result) - } -} - -func TestToTimestamp_Nil(t *testing.T) { - result := toTimestamp(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestToTimestamp_ValidTime(t *testing.T) { - now := time.Now() - result := toTimestamp(&now) - - if result == nil { - t.Fatal("expected timestamp, got nil") - } - if result.AsTime().Unix() != now.Unix() { - t.Errorf("timestamp mismatch: expected %v, got %v", now.Unix(), result.AsTime().Unix()) - } -} - -// === FieldType Tests === - -func TestResolveFieldKey_InternalField(t *testing.T) { - 
key := resolveFieldKey("search_store_id", InternalField) - expected := "search_store_id" - if key != expected { - t.Errorf("expected %q, got %q", expected, key) - } -} - -func TestResolveFieldKey_UserField(t *testing.T) { - key := resolveFieldKey("document_id", UserField) - expected := "custom.document_id" - if key != expected { - t.Errorf("expected %q, got %q", expected, key) - } -} - -func TestResolveFieldKey_UserField_PreventDoublePrefix(t *testing.T) { - // If key already has prefix, don't add again - key := resolveFieldKey("custom.document_id", UserField) - expected := "custom.document_id" - if key != expected { - t.Errorf("expected %q, got %q (double prefix detected)", expected, key) - } -} - -func TestTextCondition_InternalField(t *testing.T) { - c := TextCondition{Key: "search_store_id", Value: "store-123", FieldType: InternalField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } - // Internal field should NOT have prefix -} - -func TestTextCondition_UserField(t *testing.T) { - c := TextCondition{Key: "document_id", Value: "doc-456", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } - // User field should have "custom." 
prefix -} - -func TestBoolCondition_UserField(t *testing.T) { - c := BoolCondition{Key: "is_reviewed", Value: true, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntCondition_UserField(t *testing.T) { - c := IntCondition{Key: "version", Value: 2, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_UserField(t *testing.T) { - now := time.Now() - c := TimeRangeCondition{ - Key: "uploaded_at", - Value: TimeRange{Gte: &now}, - FieldType: UserField, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_MixedInternalAndUserFields(t *testing.T) { - // search_store_id = "store-123" (internal) AND custom.category = "reports" (user) - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "search_store_id", Value: "store-123", FieldType: InternalField}, - TextCondition{Key: "category", Value: "reports", FieldType: UserField}, - BoolCondition{Key: "is_published", Value: true, FieldType: UserField}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } -} - -// === BuildPayload Tests === - -func TestBuildPayload_OnlyInternal(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - "modalities": []string{"text"}, - } - payload := BuildPayload(internal, nil) - - if payload["search_store_id"] != "store-123" { - t.Errorf("expected search_store_id at top-level") - } - if _, exists := payload["custom"]; exists { - t.Errorf("custom should not exist when user is nil") - } -} - -func TestBuildPayload_OnlyUser(t *testing.T) { - user := 
map[string]any{ - "document_id": "doc-456", - "author": "John", - } - payload := BuildPayload(nil, user) - - custom, ok := payload["custom"].(map[string]any) - if !ok { - t.Fatal("expected custom field") - } - if custom["document_id"] != "doc-456" { - t.Errorf("expected document_id in custom") - } - if custom["author"] != "John" { - t.Errorf("expected author in custom") - } -} - -func TestBuildPayload_BothInternalAndUser(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - } - user := map[string]any{ - "document_id": "doc-456", - "category": "reports", - } - payload := BuildPayload(internal, user) - - // Check internal at top-level - if payload["search_store_id"] != "store-123" { - t.Errorf("expected search_store_id at top-level") - } - - // Check user under custom - custom, ok := payload["custom"].(map[string]any) - if !ok { - t.Fatal("expected custom field") - } - if custom["document_id"] != "doc-456" { - t.Errorf("expected document_id in custom") - } - if custom["category"] != "reports" { - t.Errorf("expected category in custom") - } -} - -func TestBuildPayload_EmptyUser(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - } - user := map[string]any{} // Empty, not nil - payload := BuildPayload(internal, user) - - if _, exists := payload["custom"]; exists { - t.Errorf("custom should not exist when user is empty") - } -} - -func TestResolveFieldKey_ActualPath(t *testing.T) { - tests := []struct { - key string - fieldType FieldType - expected string - }{ - {"city", InternalField, "city"}, - {"city", UserField, "custom.city"}, - {"custom.city", UserField, "custom.city"}, // No double prefix - } - - for _, tt := range tests { - result := resolveFieldKey(tt.key, tt.fieldType) - if result != tt.expected { - t.Errorf("resolveFieldKey(%q, %v) = %q, want %q", - tt.key, tt.fieldType, result, tt.expected) - } - } -} - -// === MatchAnyCondition Tests === - -func TestTextAnyCondition_ToQdrantCondition(t *testing.T) { 
- c := TextAnyCondition{Key: "city", Values: []string{"London", "Berlin", "Paris"}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntAnyCondition_ToQdrantCondition(t *testing.T) { - c := IntAnyCondition{Key: "priority", Values: []int64{1, 2, 3}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTextAnyCondition_UserField(t *testing.T) { - c := TextAnyCondition{Key: "category", Values: []string{"tech", "science"}, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithTextAnyCondition(t *testing.T) { - // city IN ("London", "Berlin") - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{"London", "Berlin"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestTextAnyCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := TextAnyCondition{Key: "city", Values: []string{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestIntAnyCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := IntAnyCondition{Key: "priority", Values: []int64{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestBuildFilter_WithEmptyTextAnyCondition(t *testing.T) { - // Empty TextAnyCondition should be filtered out - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{}}, - 
TextCondition{Key: "status", Value: "active"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - // Should only have the TextCondition, empty TextAnyCondition should be filtered out - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition (empty one filtered out), got %d", len(result.Must)) - } -} - -// === MatchExceptCondition Tests === - -func TestTextExceptCondition_ToQdrantCondition(t *testing.T) { - c := TextExceptCondition{Key: "city", Values: []string{"Paris", "Madrid"}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntExceptCondition_ToQdrantCondition(t *testing.T) { - c := IntExceptCondition{Key: "priority", Values: []int64{0, -1}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTextExceptCondition_UserField(t *testing.T) { - c := TextExceptCondition{Key: "status", Values: []string{"draft", "deleted"}, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithTextExceptCondition(t *testing.T) { - // city NOT IN ("Paris", "Madrid") - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - TextExceptCondition{Key: "city", Values: []string{"Paris", "Madrid"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestTextExceptCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := TextExceptCondition{Key: "city", Values: []string{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func 
TestIntExceptCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := IntExceptCondition{Key: "priority", Values: []int64{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestBuildFilter_WithEmptyTextExceptCondition(t *testing.T) { - // Empty TextExceptCondition should be filtered out - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - TextExceptCondition{Key: "city", Values: []string{}}, - BoolCondition{Key: "deleted", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - // Should only have the BoolCondition, empty TextExceptCondition should be filtered out - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition (empty one filtered out), got %d", len(result.MustNot)) - } -} - -// === NumericRangeCondition Tests === - -func TestNumericRangeCondition_ToQdrantCondition(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestNumericRangeCondition_AllBounds(t *testing.T) { - gt := 10.0 - gte := 20.0 - lt := 100.0 - lte := 90.0 - c := NumericRangeCondition{ - Key: "score", - Value: NumericRange{ - Gt: >, - Gte: >e, - Lt: <, - Lte: <e, - }, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestNumericRangeCondition_EmptyRange(t *testing.T) { - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{}, // All nil - } - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty numeric range, got %v", result) - } -} - -func TestNumericRangeCondition_UserField(t *testing.T) { - 
minPrice := 50.0 - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{Gte: &minPrice}, - FieldType: UserField, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithNumericRangeCondition(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - }, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -// === IsNullCondition Tests === - -func TestIsNullCondition_ToQdrantCondition(t *testing.T) { - c := IsNullCondition{Key: "deleted_at"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIsNullCondition_UserField(t *testing.T) { - c := IsNullCondition{Key: "review_date", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithIsNullCondition(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - IsNullCondition{Key: "deleted_at"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -// === IsEmptyCondition Tests === - -func TestIsEmptyCondition_ToQdrantCondition(t *testing.T) { - c := IsEmptyCondition{Key: "tags"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIsEmptyCondition_UserField(t *testing.T) { - c := 
IsEmptyCondition{Key: "categories", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithIsEmptyCondition(t *testing.T) { - // Find documents where tags is NOT empty (using MustNot) - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - IsEmptyCondition{Key: "tags"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -// === MarshalJSON Tests === - -func TestTimeRangeCondition_MarshalJSON(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC) - - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - } - - data, err := c.MarshalJSON() - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - jsonStr := string(data) - // Check that it contains expected fields - if !contains(jsonStr, `"field":"created_at"`) { - t.Errorf("expected field in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"atOrAfter"`) { - t.Errorf("expected atOrAfter in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"before"`) { - t.Errorf("expected before in JSON, got %s", jsonStr) - } -} - -func TestNumericRangeCondition_MarshalJSON(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - - data, err := c.MarshalJSON() - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - jsonStr := string(data) - if !contains(jsonStr, `"field":"price"`) { - t.Errorf("expected field in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"greaterThanOrEqualTo"`) { - t.Errorf("expected greaterThanOrEqualTo in JSON, got 
%s", jsonStr) - } - if !contains(jsonStr, `"lessThanOrEqualTo"`) { - t.Errorf("expected lessThanOrEqualTo in JSON, got %s", jsonStr) - } -} - -func TestTimeRangeCondition_UnmarshalJSON(t *testing.T) { - jsonData := `{ - "field": "created_at", - "atOrAfter": "2024-01-01T00:00:00Z", - "before": "2024-12-31T00:00:00Z" - }` - - var c TimeRangeCondition - err := json.Unmarshal([]byte(jsonData), &c) - if err != nil { - t.Fatalf("UnmarshalJSON failed: %v", err) - } - - if c.Key != "created_at" { - t.Errorf("expected field 'created_at', got %q", c.Key) - } - if c.Value.Gte == nil { - t.Error("expected Gte to be set") - } - if c.Value.Lt == nil { - t.Error("expected Lt to be set") - } - if c.Value.Gte != nil { - expected := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - if !c.Value.Gte.Equal(expected) { - t.Errorf("expected Gte %v, got %v", expected, c.Value.Gte) - } - } -} - -func TestTimeRangeCondition_RoundTripJSON(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC) - - original := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - - var unmarshaled TimeRangeCondition - err = json.Unmarshal(data, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if unmarshaled.Key != original.Key { - t.Errorf("field mismatch: expected %q, got %q", original.Key, unmarshaled.Key) - } - if unmarshaled.Value.Gte == nil || !unmarshaled.Value.Gte.Equal(*original.Value.Gte) { - t.Errorf("Gte mismatch: expected %v, got %v", original.Value.Gte, unmarshaled.Value.Gte) - } - if unmarshaled.Value.Lt == nil || !unmarshaled.Value.Lt.Equal(*original.Value.Lt) { - t.Errorf("Lt mismatch: expected %v, got %v", original.Value.Lt, unmarshaled.Value.Lt) - } -} - -func TestNumericRangeCondition_UnmarshalJSON(t *testing.T) { - jsonData := `{ - 
"field": "price", - "greaterThanOrEqualTo": 100.0, - "lessThanOrEqualTo": 500.0 - }` - - var c NumericRangeCondition - err := json.Unmarshal([]byte(jsonData), &c) - if err != nil { - t.Fatalf("UnmarshalJSON failed: %v", err) - } - - if c.Key != "price" { - t.Errorf("expected field 'price', got %q", c.Key) - } - if c.Value.Gte == nil || *c.Value.Gte != 100.0 { - t.Errorf("expected Gte to be 100.0, got %v", c.Value.Gte) - } - if c.Value.Lte == nil || *c.Value.Lte != 500.0 { - t.Errorf("expected Lte to be 500.0, got %v", c.Value.Lte) - } -} - -func TestNumericRangeCondition_RoundTripJSON(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - - original := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - - var unmarshaled NumericRangeCondition - err = json.Unmarshal(data, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if unmarshaled.Key != original.Key { - t.Errorf("field mismatch: expected %q, got %q", original.Key, unmarshaled.Key) - } - if unmarshaled.Value.Gte == nil || *unmarshaled.Value.Gte != *original.Value.Gte { - t.Errorf("Gte mismatch: expected %v, got %v", original.Value.Gte, unmarshaled.Value.Gte) - } - if unmarshaled.Value.Lte == nil || *unmarshaled.Value.Lte != *original.Value.Lte { - t.Errorf("Lte mismatch: expected %v, got %v", original.Value.Lte, unmarshaled.Value.Lte) - } -} - -// === Complex Filter Tests === - -func TestBuildFilter_ComplexCombination(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - minPrice := 100.0 - - // Complex filter: - // (search_store_id = "store-123" AND created_at >= 2024-01-01 AND price >= 100) - // AND (city IN ("London", "Berlin")) - // AND NOT (deleted = true OR status IN ("draft", "archived")) - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: 
"search_store_id", Value: "store-123"}, - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{Gte: &startTime}, - }, - NumericRangeCondition{ - Key: "price", - Value: NumericRange{Gte: &minPrice}, - FieldType: UserField, - }, - }, - }, - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{"London", "Berlin"}}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "deleted", Value: true}, - TextAnyCondition{Key: "status", Values: []string{"draft", "archived"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } - if len(result.Should) != 1 { - t.Errorf("expected 1 Should condition, got %d", len(result.Should)) - } - if len(result.MustNot) != 2 { - t.Errorf("expected 2 MustNot conditions, got %d", len(result.MustNot)) - } -} - -// Helper function for string contains check -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/v1/qdrant/operations.go b/v1/qdrant/operations.go index fbf3e2b..7a51989 100644 --- a/v1/qdrant/operations.go +++ b/v1/qdrant/operations.go @@ -5,452 +5,207 @@ import ( "fmt" "log" "slices" + "sync" + "github.com/Aleph-Alpha/std/v1/vectordb" qdrant "github.com/qdrant/go-client/qdrant" - "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" ) -// EnsureCollection ────────────────────────────────────────────────────────────── -// EnsureCollection -// ────────────────────────────────────────────────────────────── +// Ensure Adapter implements Service at compile time +var _ vectordb.Service = (*Adapter)(nil) + +// 
══════════════════════════════════════════════════════════════════════════════ +// Adapter - implements vectordb.Service interface +// ══════════════════════════════════════════════════════════════════════════════ + +// Adapter implements vectordb.Service for Qdrant. +// It wraps a Qdrant client and converts between generic vectordb types +// and Qdrant-specific protobuf types. // -// EnsureCollection verifies if a given collection exists, and creates it if missing. +// This is the recommended way to use Qdrant - it provides a database-agnostic +// interface that allows switching between different vector databases. +type Adapter struct { + client *qdrant.Client +} + +// NewAdapter creates a new Qdrant adapter for the vectordb interface. +// Pass the underlying SDK client via QdrantClient.Client(). // -// It’s safe to call this multiple times — if the collection already exists, -// the function exits early. This pattern simplifies startup logic for embedding -// services that may bootstrap their own Qdrant collections. -func (c *QdrantClient) EnsureCollection(ctx context.Context, name string) error { - if name == "" { - return fmt.Errorf("collection name cannot be empty") - } +// Example: +// +// qc, _ := qdrant.NewQdrantClient(params) +// adapter := qdrant.NewAdapter(qc.Client()) +// var db vectordb.Service = adapter +func NewAdapter(client *qdrant.Client) *Adapter { + return &Adapter{client: client} +} - collections, err := c.api.ListCollections(ctx) - if err != nil { - return fmt.Errorf("[Qdrant] failed to list collections: %w", err) +// Search performs similarity search across one or more requests. 
+func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, []error, error) { + if len(requests) == 0 { + return nil, nil, fmt.Errorf("at least one search request is required") } - if slices.Contains(collections, name) { - log.Printf("[Qdrant] Collection '%s' already exists", name) - return nil + log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) + + // Validate all requests first + for i, searchReq := range requests { + if err := validateSearchInput(searchReq.CollectionName, searchReq.Vector, searchReq.TopK); err != nil { + return nil, nil, fmt.Errorf("request [%d]: %w", i, err) + } } - log.Printf("[Qdrant] Collection '%s' not found, creating it...", name) + results := make([][]vectordb.SearchResult, len(requests)) + errs := make([]error, len(requests)) - req := &qdrant.CreateCollection{ - CollectionName: name, - VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ - Size: 1536, // default dimension (model-dependent) - Distance: qdrant.Distance_Cosine, // cosine similarity - }), - } + // Use WaitGroup for partial results + var wg sync.WaitGroup - if err := c.api.CreateCollection(ctx, req); err != nil { - return fmt.Errorf("[Qdrant] failed to create collection '%s': %w", name, err) + // Create semaphore to limit concurrent searches + sem := semaphore.NewWeighted(maxConcurrentSearches) + + for i, searchReq := range requests { + i, searchReq := i, searchReq // Capture loop variables + wg.Add(1) + go func() { + defer wg.Done() + // Acquire semaphore (blocks if at max concurrency) + if err := sem.Acquire(ctx, 1); err != nil { + errs[i] = fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) + results[i] = []vectordb.SearchResult{} + return + } + defer sem.Release(1) + + res, err := searchInternal(ctx, a.client, searchReq) + if err != nil { + errs[i] = fmt.Errorf("request [%d]: search failed: %w", i, err) + results[i] = []vectordb.SearchResult{} + return + } + results[i] = 
res + log.Printf("[Qdrant] Search request [%d] returned %d results", i, len(res)) + }() } - log.Printf("[Qdrant] Created collection '%s' successfully", name) - return nil -} + wg.Wait() -// Insert ────────────────────────────────────────────────────────────── -// Insert -// ────────────────────────────────────────────────────────────── -// -// Insert adds a single embedding to Qdrant. -// -// Internally, it reuses the BatchInsert logic to ensure consistent handling -// of payload serialization and error management. -func (c *QdrantClient) Insert(ctx context.Context, collectionName string, input EmbeddingInput) error { - return c.BatchInsert(ctx, collectionName, []EmbeddingInput{input}) + // Check for systemic failure (context cancelled) + if ctx.Err() != nil { + return results, errs, fmt.Errorf("search batch interrupted: %w", ctx.Err()) + } + return results, errs, nil } -// BatchInsert ────────────────────────────────────────────────────────────── -// BatchInsert -// ────────────────────────────────────────────────────────────── -// -// BatchInsert efficiently inserts multiple embeddings in batches -// to reduce network overhead. -// -// This method is safe to call for large datasets — it will automatically -// split inserts into smaller chunks (`defaultBatchSize`) and perform -// multiple upserts sequentially. -// -// Logs batch indices and collection name for debugging. -func (c *QdrantClient) BatchInsert(ctx context.Context, collectionName string, inputs []EmbeddingInput) error { +// Insert adds embeddings to a collection using batch processing. 
+func (a *Adapter) Insert(ctx context.Context, collectionName string, inputs []vectordb.EmbeddingInput) error { if len(inputs) == 0 { return nil } - if collectionName == "" { return fmt.Errorf("collection name cannot be empty") } - - // Convert all inputs into internal embeddings - embeddings := make([]Embedding, len(inputs)) - for i, in := range inputs { - embeddings[i] = NewEmbedding(in) - } - - for start := 0; start < len(embeddings); start += defaultBatchSize { - end := start + defaultBatchSize - if end > len(embeddings) { - end = len(embeddings) - } - batch := embeddings[start:end] - - if err := c.upsertBatch(ctx, batch, collectionName); err != nil { - return fmt.Errorf("[Qdrant] batch upsert failed at [%d:%d]: %w", start, end, err) - } - log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collectionName) - } - - return nil + return insertInternal(ctx, a.client, collectionName, inputs) } -// ────────────────────────────────────────────────────────────── -// upsertBatch -// ────────────────────────────────────────────────────────────── -// -// upsertBatch sends a single `Upsert` request for a slice of embeddings. -// -// Converts Embedding structs into Qdrant’s `PointStruct` objects and -// triggers a blocking insert (`Wait=true`) to ensure data persistence -// before returning. -func (c *QdrantClient) upsertBatch(ctx context.Context, batch []Embedding, collectionName string) error { - points := make([]*qdrant.PointStruct, 0, len(batch)) - for _, e := range batch { - points = append(points, &qdrant.PointStruct{ - Id: qdrant.NewID(e.ID), - Vectors: qdrant.NewVectors(e.Vector...), - Payload: qdrant.NewValueMap(e.Meta), - }) - } - - wait := true - req := &qdrant.UpsertPoints{ - CollectionName: collectionName, - Points: points, - Wait: &wait, - } - - if _, err := c.api.Upsert(ctx, req); err != nil { - return fmt.Errorf("[Qdrant] upsert failed: %w", err) - } - return nil +// Delete removes points by their IDs from a collection. 
+func (a *Adapter) Delete(ctx context.Context, collectionName string, ids []string) error { + return deleteInternal(ctx, a.client, collectionName, ids) } -// GetCollection ────────────────────────────────────────────────────────────── -// GetCollection -// ────────────────────────────────────────────────────────────── -// -// GetCollection retrieves detailed metadata about a specific collection -// from the connected Qdrant instance. -// -// It returns a high-level, decoupled `Collection` struct containing -// core details such as: -// • Collection name -// • Status (e.g., "Green", "Yellow") -// • Total vectors and points -// • Vector size (embedding dimension) -// • Distance metric (e.g., "Cosine", "Dot", "Euclid") -// -// This abstraction intentionally hides Qdrant SDK internals (`qdrant.CollectionInfo`) -// so that the application layer remains independent of Qdrant’s client library. -// -// Example: -// -// collection, err := client.GetCollection(ctx, "my_collection") -// if err != nil { -// log.Printf("Failed to fetch collection info: %v", err) -// return -// } -// log.Printf("Collection '%s': vectors=%d, points=%d, vector_size=%d, distance=%s", -// collection.Name, -// collection.Vectors, -// collection.Points, -// collection.VectorSize, -// collection.Distance, -// ) - -func (c *QdrantClient) GetCollection(ctx context.Context, name string) (*Collection, error) { - if c.api == nil { - return nil, fmt.Errorf("[Qdrant] client not initialized") - } +// EnsureCollection creates a collection if it doesn't exist. +func (a *Adapter) EnsureCollection(ctx context.Context, name string, vectorSize uint64) error { + return ensureCollectionInternal(ctx, a.client, name, vectorSize) +} +// GetCollection retrieves metadata about a collection. 
+func (a *Adapter) GetCollection(ctx context.Context, name string) (*vectordb.Collection, error) { if name == "" { return nil, fmt.Errorf("collection name cannot be empty") } - info, err := c.api.GetCollectionInfo(ctx, name) + info, err := a.client.GetCollectionInfo(ctx, name) if err != nil { - return nil, fmt.Errorf("[Qdrant] failed to get collection '%s': %w", name, err) + return nil, fmt.Errorf("failed to get collection '%s': %w", name, err) } size, distance := extractVectorDetails(info) - collection := &Collection{ - Name: name, - Status: info.Status.String(), - Vectors: derefUint64(info.IndexedVectorsCount), - Points: derefUint64(info.PointsCount), - VectorSize: size, - Distance: distance, + collection := &vectordb.Collection{ + Name: name, + Status: info.Status.String(), + VectorSize: size, + Distance: distance, + VectorCount: derefUint64(info.IndexedVectorsCount), + PointCount: derefUint64(info.PointsCount), } return collection, nil } -// ListCollections ────────────────────────────────────────────────────────────── -// ListCollections -// ────────────────────────────────────────────────────────────── -// -// ListCollections retrieves all existing collections from Qdrant and returns -// their names as a string slice. This can be extended to preload metadata -// using GetCollection for each name if needed. -// -// Example: -// -// names, err := client.ListCollections(ctx) -// if err != nil { -// log.Fatalf("failed to list collections: %v", err) -// } -// log.Printf("Found collections: %v", names) -func (c *QdrantClient) ListCollections(ctx context.Context) ([]string, error) { - if c.api == nil { - return nil, fmt.Errorf("[Qdrant] client not initialized") - } - - names, err := c.api.ListCollections(ctx) - if err != nil { - return nil, fmt.Errorf("[Qdrant] failed to list collections: %w", err) - } - - log.Printf("[Qdrant] Found %d collections", len(names)) - return names, nil +// ListCollections returns names of all collections. 
+func (a *Adapter) ListCollections(ctx context.Context) ([]string, error) { + return listCollectionsInternal(ctx, a.client) } -// Search ────────────────────────────────────────────────────────────── -// Search -// ────────────────────────────────────────────────────────────── -// -// Search performs a similarity search in the configured collection. -// -// Parameters: -// - collectionName — the collection to search in -// - vector — the query embedding to search against. -// - topK — maximum number of nearest results to return. -// -// Returns: -// -// A slice of `SearchResultInterface` instances representing the -// most similar stored embeddings. -// func (c *QdrantClient) Search(ctx context.Context, collectionName string, vector []float32, topK int) ([]SearchResultInterface, error) { -// if err := validateSearchInput(collectionName, vector, topK); err != nil { -// return nil, err -// } - -// limit := uint64(topK) -// req := &qdrant.QueryPoints{ -// CollectionName: collectionName, -// Query: qdrant.NewQuery(vector...), -// Limit: &limit, -// WithPayload: qdrant.NewWithPayload(true), -// } - -// resp, err := c.api.Query(ctx, req) -// if err != nil { -// return nil, fmt.Errorf("[Qdrant] search failed: %w", err) -// } - -// results, err := c.parseSearchResults(resp) -// if err != nil { -// return nil, err -// } - -// log.Printf("[Qdrant] Search returned %d results", len(results)) -// return results, nil -// } - -// SearchWithFilter ────────────────────────────────────────────────────────────── -// SearchWithFilter -// ────────────────────────────────────────────────────────────── -// -// SearchWithFilter performs a similarity search with required filters (AND logic). -// Returns error if filters is nil or empty - use Search() for unfiltered searches. 
-// -// Parameters: -// - collectionName — the collection to search in -// - vector — the query embedding to search against -// - topK — maximum number of nearest results to return -// - filters — required key-value filters (all must match) -// -// Returns: -// -// A slice of SearchResultInterface instances matching the filter criteria. -// func (c *QdrantClient) SearchWithFilter(ctx context.Context, collectionName string, vector []float32, topK int, filters map[string]string) ([]SearchResultInterface, error) { -// if err := validateSearchInput(collectionName, vector, topK); err != nil { -// return nil, err -// } - -// if len(filters) == 0 { -// return nil, fmt.Errorf("filters cannot be empty, use Search() for unfiltered searches") -// } - -// limit := uint64(topK) -// req := &qdrant.QueryPoints{ -// CollectionName: collectionName, -// Query: qdrant.NewQuery(vector...), -// Limit: &limit, -// WithPayload: qdrant.NewWithPayload(true), -// Filter: buildFilter(filters), -// } - -// resp, err := c.api.Query(ctx, req) -// if err != nil { -// return nil, fmt.Errorf("[Qdrant] search failed: %w", err) -// } - -// results, err := c.parseSearchResults(resp) -// if err != nil { -// return nil, err -// } - -// log.Printf("[Qdrant] SearchWithFilter returned %d results", len(results)) -// return results, nil -// } - -// executeSearch performs a single search request against Qdrant -func (c *QdrantClient) executeSearch(ctx context.Context, searchReq SearchRequest) ([]SearchResultInterface, error) { +// ══════════════════════════════════════════════════════════════════════════════ +// Internal Functions +// ══════════════════════════════════════════════════════════════════════════════ + +func searchInternal(ctx context.Context, client *qdrant.Client, searchReq vectordb.SearchRequest) ([]vectordb.SearchResult, error) { limit := uint64(searchReq.TopK) - req := &qdrant.QueryPoints{ + queryReq := &qdrant.QueryPoints{ CollectionName: searchReq.CollectionName, Query: 
qdrant.NewQuery(searchReq.Vector...), Limit: &limit, WithPayload: qdrant.NewWithPayload(true), - Filter: buildFilter(searchReq.Filters), + Filter: convertVectorDBFilterSets(searchReq.Filters), } - resp, err := c.api.Query(ctx, req) + resp, err := client.Query(ctx, queryReq) if err != nil { - return nil, fmt.Errorf("search failed: %w", err) + return nil, fmt.Errorf("query failed: %w", err) } - - return c.parseSearchResults(resp) + return parseVectorDBSearchResults(resp) } -// Search ────────────────────────────────────────────────────────────── -// Search -// ────────────────────────────────────────────────────────────── -// -// Search performs multiple searches and returns results for each request. -// Each request can optionally include filters. -// -// Parameters: -// - requests — variadic SearchRequest structs, each containing: -// - CollectionName: the collection to search in -// - Vector: the query embedding -// - TopK: maximum number of results per request -// - Filters: optional key-value filters (AND logic) -// -// Returns: -// -// A slice of result slices — one []SearchResultInterface per request. 
-// -// Example: -// -// results, err := client.Search(ctx, -// SearchRequest{CollectionName: "docs", Vector: vec1, TopK: 10, Filters: map[string]string{"partition_id": "store-A"}}, -// SearchRequest{CollectionName: "docs", Vector: vec2, TopK: 5}, -// ) -// // results[0] = results for first request -// // results[1] = results for second request -func (c *QdrantClient) Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResultInterface, error) { - if len(requests) == 0 { - return nil, fmt.Errorf("at least one search request is required") - } - - log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) - - // Validate all requests first (fail fast) - for i, searchReq := range requests { - if err := validateSearchInput(searchReq.CollectionName, searchReq.Vector, searchReq.TopK); err != nil { - return nil, fmt.Errorf("request [%d]: %w", i, err) +func insertInternal(ctx context.Context, client *qdrant.Client, collectionName string, inputs []vectordb.EmbeddingInput) error { + for start := 0; start < len(inputs); start += defaultBatchSize { + end := start + defaultBatchSize + if end > len(inputs) { + end = len(inputs) + } + batch := inputs[start:end] + + points := make([]*qdrant.PointStruct, 0, len(batch)) + for _, e := range batch { + points = append(points, &qdrant.PointStruct{ + Id: qdrant.NewID(e.ID), + Vectors: qdrant.NewVectors(e.Vector...), + Payload: qdrant.NewValueMap(e.Payload), + }) } - } - - results := make([][]SearchResultInterface, len(requests)) - - // Create errgroup with context - g, ctx := errgroup.WithContext(ctx) - - // Semaphore to limit concurrency - sem := semaphore.NewWeighted(maxConcurrentSearches) - - for i, searchReq := range requests { - i, searchReq := i, searchReq // Capture loop variables - - g.Go(func() error { - // Acquire semaphore (blocks if at max concurrency) - if err := sem.Acquire(ctx, 1); err != nil { - return fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) - } - defer 
sem.Release(1) - - res, err := c.executeSearch(ctx, searchReq) - if err != nil { - return fmt.Errorf("request [%d]: search failed: %w", i, err) - } - - results[i] = res - log.Printf("[Qdrant] Search request [%d] returned %d results", i, len(res)) - return nil - }) - } - - if err := g.Wait(); err != nil { - return nil, fmt.Errorf("search batch failed: %w", err) - } - - return results, nil -} -// parseSearchResults converts Qdrant response to SearchResultInterface slice -func (c *QdrantClient) parseSearchResults(resp []*qdrant.ScoredPoint) ([]SearchResultInterface, error) { - results := make([]SearchResultInterface, 0, len(resp)) - for _, r := range resp { - var id string - switch v := r.Id.PointIdOptions.(type) { - case *qdrant.PointId_Num: - id = fmt.Sprintf("%d", v.Num) - case *qdrant.PointId_Uuid: - id = v.Uuid - default: - return nil, fmt.Errorf("[Qdrant] unexpected PointId type: %T", v) + wait := true + req := &qdrant.UpsertPoints{ + CollectionName: collectionName, + Points: points, + Wait: &wait, } - results = append(results, SearchResult{ - ID: id, - Score: r.Score, - Meta: r.Payload, - }) + if _, err := client.Upsert(ctx, req); err != nil { + return fmt.Errorf("[Qdrant] batch upsert failed at [%d:%d]: %w", start, end, err) + } + log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collectionName) } - return results, nil + return nil } -// Delete ────────────────────────────────────────────────────────────── -// Delete -// ────────────────────────────────────────────────────────────── -// -// Delete removes embeddings from a collection by their IDs. -// -// It constructs a `DeletePoints` request containing a list of `PointId`s, -// waits synchronously for completion, and logs the operation status. 
-func (c *QdrantClient) Delete(ctx context.Context, collectionName string, ids []string) error { +func deleteInternal(ctx context.Context, client *qdrant.Client, collectionName string, ids []string) error { if len(ids) == 0 { return nil } - if collectionName == "" { return fmt.Errorf("collection name cannot be empty") } @@ -471,12 +226,53 @@ func (c *QdrantClient) Delete(ctx context.Context, collectionName string, ids [] Wait: &wait, } - resp, err := c.api.Delete(ctx, req) + resp, err := client.Delete(ctx, req) if err != nil { return fmt.Errorf("[Qdrant] delete failed: %w", err) } - log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", - resp.Status.String(), collectionName) + log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", resp.Status.String(), collectionName) + return nil +} + +func ensureCollectionInternal(ctx context.Context, client *qdrant.Client, name string, vectorSize uint64) error { + if name == "" { + return fmt.Errorf("collection name cannot be empty") + } + + collections, err := client.ListCollections(ctx) + if err != nil { + return fmt.Errorf("[Qdrant] failed to list collections: %w", err) + } + + if slices.Contains(collections, name) { + log.Printf("[Qdrant] Collection '%s' already exists", name) + return nil + } + + log.Printf("[Qdrant] Collection '%s' not found, creating it...", name) + + req := &qdrant.CreateCollection{ + CollectionName: name, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + } + + if err := client.CreateCollection(ctx, req); err != nil { + return fmt.Errorf("[Qdrant] failed to create collection '%s': %w", name, err) + } + + log.Printf("[Qdrant] Created collection '%s' successfully", name) return nil } + +func listCollectionsInternal(ctx context.Context, client *qdrant.Client) ([]string, error) { + names, err := client.ListCollections(ctx) + if err != nil { + return nil, fmt.Errorf("[Qdrant] failed to list collections: %w", 
err) + } + log.Printf("[Qdrant] Found %d collections", len(names)) + return names, nil +} diff --git a/v1/qdrant/qdrant_integration_test.go b/v1/qdrant/qdrant_integration_test.go index 100e707..69bf199 100644 --- a/v1/qdrant/qdrant_integration_test.go +++ b/v1/qdrant/qdrant_integration_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/Aleph-Alpha/std/v1/vectordb" "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" "github.com/stretchr/testify/assert" @@ -189,76 +190,80 @@ func TestQdrantWithFXModule(t *testing.T) { err = qdrantClient.healthCheck() assert.NoError(t, err) + // Create adapter for operations + adapter := NewAdapter(qdrantClient.api) + // Test collection operations t.Run("EnsureCollection", func(t *testing.T) { // First call should create the collection - err := qdrantClient.EnsureCollection(ctx, "test_collection_1") + err := adapter.EnsureCollection(ctx, "test_collection_1", 1536) assert.NoError(t, err) // Second call should be idempotent - err = qdrantClient.EnsureCollection(ctx, "test_collection_1") + err = adapter.EnsureCollection(ctx, "test_collection_1", 1536) assert.NoError(t, err) // Empty collection name should fail - err = qdrantClient.EnsureCollection(ctx, "") + err = adapter.EnsureCollection(ctx, "", 1536) assert.Error(t, err) }) // Test basic CRUD operations t.Run("BasicCRUDOperations", func(t *testing.T) { collectionName := "test_crud" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) - // Insert single embedding (use numeric ID or UUID format) - embedding := EmbeddingInput{ - ID: "00000000-0000-0000-0000-000000000001", // UUID format + // Insert single embedding (use UUID format) + embedding := vectordb.EmbeddingInput{ + ID: "00000000-0000-0000-0000-000000000001", Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": "Test Document 1", "content": "This is a 
test document", }, } - err = qdrantClient.Insert(ctx, collectionName, embedding) + err = adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{embedding}) assert.NoError(t, err) // Search for the inserted embedding time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := qdrantClient.Search(ctx, SearchRequest{ + batchResults, errs, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 5, }) assert.NoError(t, err) + assert.Nil(t, errs[0]) assert.Greater(t, len(batchResults), 0) results := batchResults[0] assert.Greater(t, len(results), 0) // Verify the result if len(results) > 0 { - assert.Equal(t, embedding.ID, results[0].GetID()) - assert.Greater(t, results[0].GetScore(), float32(0.9)) // Should be very similar + assert.Equal(t, embedding.ID, results[0].ID) + assert.Greater(t, results[0].Score, float32(0.9)) // Should be very similar } // Delete the embedding - err = qdrantClient.Delete(ctx, collectionName, []string{embedding.ID}) + err = adapter.Delete(ctx, collectionName, []string{embedding.ID}) assert.NoError(t, err) }) // Test batch insert t.Run("BatchInsert", func(t *testing.T) { collectionName := "test_batch" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Create multiple embeddings (use UUID format) - embeddings := make([]EmbeddingInput, 10) + embeddings := make([]vectordb.EmbeddingInput, 10) for i := 0; i < 10; i++ { - embeddings[i] = EmbeddingInput{ + embeddings[i] = vectordb.EmbeddingInput{ ID: fmt.Sprintf("00000000-0000-0000-0000-%012d", i+1), Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": fmt.Sprintf("Document %d", i), "index": i, }, @@ -266,12 +271,12 @@ func TestQdrantWithFXModule(t *testing.T) { } // Batch insert - err = qdrantClient.BatchInsert(ctx, collectionName, embeddings) + err = adapter.Insert(ctx, 
collectionName, embeddings) assert.NoError(t, err) // Search and verify time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := qdrantClient.Search(ctx, SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -284,22 +289,22 @@ func TestQdrantWithFXModule(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = qdrantClient.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) // Test empty operations t.Run("EmptyOperations", func(t *testing.T) { collectionName := "test_empty" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Empty batch insert should be no-op - err = qdrantClient.BatchInsert(ctx, collectionName, []EmbeddingInput{}) + err = adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{}) assert.NoError(t, err) // Empty delete should be no-op - err = qdrantClient.Delete(ctx, collectionName, []string{}) + err = adapter.Delete(ctx, collectionName, []string{}) assert.NoError(t, err) }) @@ -307,8 +312,8 @@ func TestQdrantWithFXModule(t *testing.T) { require.NoError(t, app.Stop(ctx)) } -// TestQdrantClientOperations tests various client operations -func TestQdrantClientOperations(t *testing.T) { +// TestVectorDBAdapterOperations tests various adapter operations +func TestVectorDBAdapterOperations(t *testing.T) { // Skip if running in short mode if testing.Short() { t.Skip("Skipping integration test in short mode") @@ -339,20 +344,23 @@ func TestQdrantClientOperations(t *testing.T) { require.NotNil(t, client) defer client.Close() + // Create adapter + adapter := NewAdapter(client.api) + collectionName := "test_operations" // Ensure collection exists - err = client.EnsureCollection(ctx, collectionName) + err = adapter.EnsureCollection(ctx, collectionName, 1536) 
require.NoError(t, err) t.Run("GetCollectionByName", func(t *testing.T) { // Fetch collection info using GetCollection - col, err := client.GetCollection(ctx, collectionName) + col, err := adapter.GetCollection(ctx, collectionName) assert.NoError(t, err, "expected GetCollection to succeed") assert.NotNil(t, col, "expected non-nil collection info") // Validate expected metadata fields - assert.GreaterOrEqual(t, int(col.Vectors), 0, "vector count should be >= 0") - assert.GreaterOrEqual(t, int(col.Points), 0, "points count should be >= 0") + assert.GreaterOrEqual(t, int(col.VectorCount), 0, "vector count should be >= 0") + assert.GreaterOrEqual(t, int(col.PointCount), 0, "points count should be >= 0") // Validate vector config details (size and distance) assert.NotZero(t, col.VectorSize, "vector size should not be zero") @@ -362,8 +370,8 @@ func TestQdrantClientOperations(t *testing.T) { t.Logf("Collection '%s': status=%s, vectors=%d, points=%d, vectorSize=%d, distance=%s", col.Name, col.Status, - col.Vectors, - col.Points, + col.VectorCount, + col.PointCount, col.VectorSize, col.Distance, ) @@ -371,22 +379,22 @@ func TestQdrantClientOperations(t *testing.T) { t.Run("SearchReturnsTopK", func(t *testing.T) { // Insert multiple embeddings (use UUID format) - embeddings := make([]EmbeddingInput, 20) + embeddings := make([]vectordb.EmbeddingInput, 20) for i := 0; i < 20; i++ { - embeddings[i] = EmbeddingInput{ - ID: fmt.Sprintf("00000000-0000-0000-0001-%012d", i+1), - Vector: generateRandomVector(1536), - Meta: map[string]any{"index": i}, + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0001-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": i}, } } - err := client.BatchInsert(ctx, collectionName, embeddings) + err := adapter.Insert(ctx, collectionName, embeddings) require.NoError(t, err) time.Sleep(1 * time.Second) // Allow time for indexing // Search with topK = 5 - batchResults, err := 
client.Search(ctx, SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 5, @@ -395,7 +403,7 @@ func TestQdrantClientOperations(t *testing.T) { assert.LessOrEqual(t, len(batchResults[0]), 5) // Search with topK = 10 - batchResults, err = client.Search(ctx, SearchRequest{ + batchResults, _, err = adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -408,30 +416,30 @@ func TestQdrantClientOperations(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = client.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) t.Run("SearchWithMetadata", func(t *testing.T) { // Insert embedding with rich metadata (UUID format, simple types only) - embedding := EmbeddingInput{ + embedding := vectordb.EmbeddingInput{ ID: "00000000-0000-0000-0002-000000000001", Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": "Test Title", "author": "Test Author", "timestamp": time.Now().Unix(), - "category": "test", // Use simple types instead of arrays + "category": "test", }, } - err := client.Insert(ctx, collectionName, embedding) + err := adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{embedding}) require.NoError(t, err) time.Sleep(1 * time.Second) // Search and verify metadata - batchResults, err := client.Search(ctx, SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 1, @@ -440,39 +448,39 @@ func TestQdrantClientOperations(t *testing.T) { assert.Greater(t, len(batchResults[0]), 0) if len(batchResults[0]) > 0 { - meta := batchResults[0][0].GetMeta() - assert.NotNil(t, meta) + payload := batchResults[0][0].Payload + assert.NotNil(t, payload) } // Clean up - err = client.Delete(ctx, collectionName, 
[]string{embedding.ID}) + err = adapter.Delete(ctx, collectionName, []string{embedding.ID}) assert.NoError(t, err) }) t.Run("LargeBatchInsert", func(t *testing.T) { collectionName := "test_large_batch" - err := client.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Create a large batch (more than defaultBatchSize, use UUID format) largeCount := 500 - embeddings := make([]EmbeddingInput, largeCount) + embeddings := make([]vectordb.EmbeddingInput, largeCount) for i := 0; i < largeCount; i++ { - embeddings[i] = EmbeddingInput{ - ID: fmt.Sprintf("00000000-0000-0000-0003-%012d", i+1), - Vector: generateRandomVector(1536), - Meta: map[string]any{"index": i}, + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0003-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": i}, } } // Should handle batching automatically - err = client.BatchInsert(ctx, collectionName, embeddings) + err = adapter.Insert(ctx, collectionName, embeddings) assert.NoError(t, err) time.Sleep(2 * time.Second) // Verify some embeddings exist - batchResults, err := client.Search(ctx, SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -485,7 +493,7 @@ func TestQdrantClientOperations(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = client.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) } @@ -522,6 +530,8 @@ func TestQdrantErrorHandling(t *testing.T) { require.NotNil(t, client) defer client.Close() + adapter := NewAdapter(client.api) + t.Run("InvalidEndpoint", func(t *testing.T) { invalidCfg := &Config{ Endpoint: "invalid-host:9999", @@ -534,19 +544,21 @@ func TestQdrantErrorHandling(t *testing.T) { }) t.Run("EmptyCollectionName", func(t *testing.T) { - err := 
client.EnsureCollection(ctx, "") + err := adapter.EnsureCollection(ctx, "", 1536) assert.Error(t, err) assert.Contains(t, err.Error(), "collection name cannot be empty") }) t.Run("SearchOnNonExistentCollection", func(t *testing.T) { vector := generateRandomVector(1536) - _, err := client.Search(ctx, SearchRequest{ + _, errs, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: "non_existent_collection", Vector: vector, TopK: 5, }) - assert.Error(t, err) + // With partial results, systemic error is nil but per-request error exists + assert.NoError(t, err) + assert.NotNil(t, errs[0]) }) } @@ -586,9 +598,10 @@ func TestQdrantLifecycleAndHealthCheck(t *testing.T) { err = client.healthCheck() require.NoError(t, err, "Qdrant health check failed") - // Ensure collection exists + // Create adapter and ensure collection exists + adapter := NewAdapter(client.api) collectionName := "test_collection" - err = client.EnsureCollection(context.Background(), collectionName) + err = adapter.EnsureCollection(context.Background(), collectionName, 1536) require.NoError(t, err, "failed to ensure collection") // Close client @@ -598,6 +611,396 @@ func TestQdrantLifecycleAndHealthCheck(t *testing.T) { t.Log("Qdrant client lifecycle test passed successfully") } +// TestFilterOperations tests various filter scenarios +func TestFilterOperations(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + // Convert port to uint + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, err := 
NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + t.Run("SearchWithMustFilter", func(t *testing.T) { + collectionName := "test_must_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + // Insert embeddings with different metadata + embeddings := []vectordb.EmbeddingInput{ + { + ID: "00000000-0000-0000-0004-000000000001", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "red", "size": int64(10)}, + }, + { + ID: "00000000-0000-0000-0004-000000000002", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "blue", "size": int64(20)}, + }, + { + ID: "00000000-0000-0000-0004-000000000003", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "red", "size": int64(30)}, + }, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Test Must filter (color == "red") + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Must(vectordb.NewMatch("color", "red")), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // Only red items + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithNumericRangeFilter", func(t *testing.T) { + collectionName := "test_numeric_range" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0010-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"size": int64(10)}}, + {ID: "00000000-0000-0000-0010-000000000002", Vector: 
generateRandomVector(1536), Payload: map[string]any{"size": int64(20)}}, + {ID: "00000000-0000-0000-0010-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"size": int64(30)}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Test numeric range filter (size >= 20) + gte := float64(20) + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Gte: &gte})), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // size 20 and 30 + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMustNotFilter", func(t *testing.T) { + collectionName := "test_must_not" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0005-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"status": "published"}}, + {ID: "00000000-0000-0000-0005-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"status": "draft"}}, + {ID: "00000000-0000-0000-0005-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"status": "archived"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Exclude archived items + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.MustNot(vectordb.NewMatch("status", "archived")), + ), + }, + }) +
assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // published and draft only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMultipleFilterSets", func(t *testing.T) { + collectionName := "test_multi_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0006-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red", "size": int64(10)}}, + {ID: "00000000-0000-0000-0006-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red", "size": int64(50)}}, + {ID: "00000000-0000-0000-0006-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "blue", "size": int64(10)}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // color == red AND size < 20 + lt := float64(20) + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet(vectordb.Must(vectordb.NewMatch("color", "red"))), + vectordb.NewFilterSet(vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Lt: &lt}))), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 1, len(batchResults[0])) // Only red with size < 20 + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMatchAnyFilter", func(t *testing.T) { + collectionName := "test_match_any" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID:
"00000000-0000-0000-0008-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "ml"}}, + {ID: "00000000-0000-0000-0008-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "nlp"}}, + {ID: "00000000-0000-0000-0008-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "cv"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Match any of ["ml", "nlp"] + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Must(vectordb.NewMatchAny("tag", "ml", "nlp")), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // ml and nlp only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithShouldFilter", func(t *testing.T) { + collectionName := "test_should_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0009-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red"}}, + {ID: "00000000-0000-0000-0009-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "blue"}}, + {ID: "00000000-0000-0000-0009-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "green"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Should match red OR blue + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + 
vectordb.NewFilterSet( + vectordb.Should( + vectordb.NewMatch("color", "red"), + vectordb.NewMatch("color", "blue"), + ), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // red and blue only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) +} + +// TestBatchSearch tests batch search with multiple queries +func TestBatchSearch(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, err := NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + collectionName := "test_batch_search" + err = adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := make([]vectordb.EmbeddingInput, 5) + for i := 0; i < 5; i++ { + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0007-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": int64(i)}, + } + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Batch search with multiple queries + batchResults, _, err := adapter.Search(ctx, + vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 3}, + vectordb.SearchRequest{CollectionName: collectionName, Vector: 
embeddings[1].Vector, TopK: 3}, + vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[2].Vector, TopK: 3}, + ) + assert.NoError(t, err) + assert.Equal(t, 3, len(batchResults)) // 3 result sets + + // Each result set should have results + for i, results := range batchResults { + assert.Greater(t, len(results), 0, "result set %d should have results", i) + } + + // Clean up + ids := make([]string, len(embeddings)) + for i, emb := range embeddings { + ids[i] = emb.ID + } + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) +} + +// TestListCollections tests collection listing +func TestListCollections(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, err := NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + // Create a few collections + err = adapter.EnsureCollection(ctx, "list_test_1", 1536) + require.NoError(t, err) + err = adapter.EnsureCollection(ctx, "list_test_2", 1536) + require.NoError(t, err) + + // List collections + collections, err := adapter.ListCollections(ctx) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(collections), 2) + + // Verify our collections are in the list + found1, found2 := false, false + for _, name := range collections { + if name == "list_test_1" { + found1 = true + } + if name == "list_test_2" { + found2 = true + } + } + assert.True(t, found1, 
"list_test_1 should be in collections") + assert.True(t, found2, "list_test_2 should be in collections") +} + // Helper function to generate random vectors for testing func generateRandomVector(size int) []float32 { vector := make([]float32, size) diff --git a/v1/qdrant/utils.go b/v1/qdrant/utils.go index 9909631..964f0a0 100644 --- a/v1/qdrant/utils.go +++ b/v1/qdrant/utils.go @@ -6,103 +6,6 @@ import ( qdrant "github.com/qdrant/go-client/qdrant" ) -// EmbeddingInput is the type the application provides to insert embeddings. -// Keeps the app decoupled from internal Qdrant SDK structs. -type EmbeddingInput struct { - ID string // Unique identifier for the embedding (e.g., document ID) - Vector []float32 // Dense vector representation of the embedding - Meta map[string]any // Optional metadata associated with the embedding -} - -// Embedding represents a dense embedding vector. -type Embedding struct { - ID string // Unique identifier (same as Qdrant point ID) - Vector []float32 // Vector representation of the embedding - Meta map[string]any // Optional metadata associated with the embedding -} - -// SearchResult holds results from a similarity search. -type SearchResult struct { - ID string - Score float32 - Meta map[string]*qdrant.Value - Vector []float32 - Collection string -} - -// GetID returns the result's unique identifier. -func (r SearchResult) GetID() string { return r.ID } - -// GetScore returns the similarity score associated with the result. -func (r SearchResult) GetScore() float32 { return r.Score } - -// GetMeta returns the metadata stored with the vector. -func (r SearchResult) GetMeta() map[string]*qdrant.Value { return r.Meta } - -// GetVector returns the dense embedding vector if available. -func (r SearchResult) GetVector() []float32 { return r.Vector } - -// HasVector reports whether the result contains a non-empty vector payload. 
-func (r SearchResult) HasVector() bool { return len(r.Vector) > 0 } - -// GetCollectionName returns the name of the collection from which the result originated. -func (r SearchResult) GetCollectionName() string { return r.Collection } - -// NewEmbedding converts a high-level EmbeddingInput into the internal Embedding type. -// Having this builder allows for future validation or normalization logic. -// For now, it performs a shallow copy. -func NewEmbedding(input EmbeddingInput) Embedding { - return Embedding(input) -} - -// SearchResultInterface is the public interface for search results. -// It provides a consistent way to access search results regardless of the underlying implementation. -type SearchResultInterface interface { - GetID() string // Unique result identifier - GetScore() float32 // Similarity score - GetMeta() map[string]*qdrant.Value // Metadata associated with the result - GetVector() []float32 // Optional embedding vector - HasVector() bool // Whether a vector payload is present - GetCollectionName() string // Name of the Qdrant collection -} - -// Collection ────────────────────────────────────────────────────────────── -// Collection -// ────────────────────────────────────────────────────────────── -// -// Collection represents a high-level, decoupled view of a Qdrant collection. -// -// It provides essential metadata about a vector collection without exposing -// Qdrant SDK types, allowing the application layer to remain independent -// of the underlying database implementation. -// -// Fields: -// - Name — The unique name of the collection. -// - Status — Current operational state (e.g., "Green", "Yellow"). -// - VectorSize — The dimension of stored vectors (e.g., 1536). -// - Distance — The similarity metric used ("Cosine", "Dot", "Euclid"). -// - Vectors — Total number of stored vectors in the collection. -// - Points — Total number of indexed points/documents in the collection. 
-// -// This struct serves as an abstraction layer between Qdrant's low-level -// protobuf models and the higher-level application logic. -type Collection struct { - Name string - Status string - VectorSize int - Distance string - Vectors uint64 - Points uint64 -} - -// SearchRequest represents a single search request for batch operations -type SearchRequest struct { - CollectionName string - Vector []float32 - TopK int - Filters *FilterSet // Optional: key-value filters -} - // validateSearchInput validates common search parameters func validateSearchInput(collectionName string, vector []float32, topK int) error { if collectionName == "" { diff --git a/v1/vectordb/doc.go b/v1/vectordb/doc.go new file mode 100644 index 0000000..756c31f --- /dev/null +++ b/v1/vectordb/doc.go @@ -0,0 +1,138 @@ +// Package vectordb provides a database-agnostic abstraction for vector similarity search. +// +// # Overview +// +// This package defines a common interface [Service] that can be implemented +// by different vector database adapters (Qdrant, pgVector, etc.), allowing +// applications to switch between databases without changing application code. +// +// # Architecture +// +// ┌─────────────────────────────────────────────────────────────┐ +// │ Application Layer │ +// │ (uses vectordb.Service - no DB-specific imports) │ +// └──────────────────────────┬──────────────────────────────────┘ +// │ +// ▼ +// ┌─────────────────────────────────────────────────────────────┐ +// │ vectordb.Service │ +// │ (common interface + DB-agnostic types) │ +// └──────────────────────────┬──────────────────────────────────┘ +// │ +// ┌────────────┴────────────┐ +// ▼ ▼ +// ┌───────────────┐ ┌───────────────-┐ +// │ qdrant.Adapter│ │pgvector.Adapter│ +// │ (implements) │ │ (planned) │ +// └───────────────┘ └───────────────-┘ +// +// # Benefits +// +// - Single Source of Truth: Filter types, search interfaces, and result types defined once. 
+// - Easy to Add New DBs: Just add a new adapter; consuming projects don't change. +// - Consistent API: All projects using std get the same interface. +// - Testability: Mock the interface once, works for all DBs. +// +// # Usage +// +// In your application, depend only on the vectordb interface: +// +// import "github.com/Aleph-Alpha/std/v1/vectordb" +// +// type SearchService struct { +// db vectordb.Service +// } +// +// func NewSearchService(db vectordb.Service) *SearchService { +// return &SearchService{db: db} +// } +// +// func (s *SearchService) Search(ctx context.Context, query string, vector []float32) ([]vectordb.SearchResult, error) { +// results, _, err := s.db.Search(ctx, vectordb.SearchRequest{ +// CollectionName: "documents", +// Vector: vector, +// TopK: 10, +// Filters: []*vectordb.FilterSet{ +// { +// Must: &vectordb.ConditionSet{ +// Conditions: []vectordb.FilterCondition{ +// vectordb.NewMatch("status", "published"), +// }, +// }, +// }, +// }, +// }) +// if err != nil { +// return nil, err +// } +// return results[0], nil +// } +// +// # Wire Up with Qdrant +// +// In your main setup: +// +// import ( +// "github.com/Aleph-Alpha/std/v1/vectordb" +// "github.com/Aleph-Alpha/std/v1/qdrant" +// ) +// +// func main() { +// // Create Qdrant client (with health checks, config, etc.) +// qc, _ := qdrant.NewQdrantClient(qdrant.QdrantParams{ +// Config: &qdrant.Config{Endpoint: "localhost", Port: 6334}, +// }) +// +// // Create adapter for DB-agnostic usage +// db := qdrant.NewAdapter(qc.Client()) +// +// // Use in application +// svc := NewSearchService(db) +// // ...
+// } +// +// # Package Layout +// +// vectordb/ +// ├── interface.go # Service interface +// ├── types.go # SearchRequest, SearchResult, EmbeddingInput, Collection +// ├── filters.go # FilterSet, FilterCondition, and condition types +// ├── utils.go # Convenience constructors (New*) and JSON helpers +// └── doc.go # This file +// +// qdrant/ # Qdrant package (includes adapter) +// ├── client.go # QdrantClient wrapper +// ├── operations.go # Adapter - implements Service +// ├── converter.go # vectordb types → qdrant types +// └── ... +// +// Future adapters would live in their own packages: +// +// pgvector/ # (planned) PostgreSQL pgvector adapter +// +// # Filter Types +// +// The package provides DB-agnostic filter conditions: +// +// | Type | Description | SQL Equivalent | +// |-----------------------|------------------------------|-----------------------------------| +// | MatchCondition | Exact value match | WHERE field = value | +// | MatchAnyCondition | Value in set | WHERE field IN (...) | +// | MatchExceptCondition | Value not in set | WHERE field NOT IN (...) | +// | NumericRangeCondition | Numeric range | WHERE field >= min AND field <= max| +// | TimeRangeCondition | Datetime range | WHERE created_at BETWEEN ... | +// | IsNullCondition | Field is null | WHERE field IS NULL | +// | IsEmptyCondition | Field is empty/null/missing | WHERE field IS NULL OR field = '' | +// +// Use convenience constructors for cleaner code: +// +// // Internal field (top-level in payload) +// vectordb.NewMatch("status", "published") +// +// // User-defined field (stored under "custom." 
prefix) +// vectordb.NewUserMatch("category", "research") +// +// // Range conditions with NumericRange/TimeRange structs +// vectordb.NewNumericRange("price", vectordb.NumericRange{Gte: &min, Lt: &max}) +// vectordb.NewTimeRange("created_at", vectordb.TimeRange{Gte: &start, Lt: &end}) +package vectordb diff --git a/v1/vectordb/filters.go b/v1/vectordb/filters.go new file mode 100644 index 0000000..9fab712 --- /dev/null +++ b/v1/vectordb/filters.go @@ -0,0 +1,222 @@ +package vectordb + +import ( + "encoding/json" + "time" +) + +// FieldType indicates whether a field is internal (system-managed) +// or user-defined (stored under a prefix like "custom."). +type FieldType int + +const ( + // InternalField - system-managed fields stored at top-level + InternalField FieldType = iota + // UserField - user-defined fields stored under a prefix (e.g., "custom.") + UserField +) + +// FilterCondition is the interface all filter conditions must implement. +// Each database adapter converts these to its native filter format. +type FilterCondition interface { + // IsFilterCondition is a marker method to ensure type safety + IsFilterCondition() +} + +// FilterSet supports Must (AND), Should (OR), and MustNot (NOT) clauses. +// Use with SearchRequest.Filters to filter search results. +// +// Example: +// +// filters := &FilterSet{ +// Must: &ConditionSet{ +// Conditions: []FilterCondition{ +// &MatchCondition{Field: "city", Value: "London"}, +// }, +// }, +// } +type FilterSet struct { + // Must: All conditions must match (AND) + Must *ConditionSet `json:"must,omitempty"` + // Should: At least one condition must match (OR) + Should *ConditionSet `json:"should,omitempty"` + // MustNot: None of the conditions should match (NOT) + MustNot *ConditionSet `json:"mustNot,omitempty"` +} + +// ConditionSet holds a group of conditions for a single clause.
+type ConditionSet struct { + Conditions []FilterCondition `json:"conditions,omitempty"` +} + +// ── Match Conditions ───────────────────────────────────────────────────────── + +// MatchCondition represents an exact match filter (WHERE field = value). +// Supports string, bool, and int64 values. +type MatchCondition struct { + Field string `json:"field"` + Value any `json:"equalTo"` + FieldType FieldType `json:"-"` +} + +func (c *MatchCondition) IsFilterCondition() {} + +// MatchAnyCondition matches if value is one of the given values (IN operator). +// SQL equivalent: WHERE field IN (value1, value2, ...) +type MatchAnyCondition struct { + Field string `json:"field"` + Values []any `json:"anyOf"` + FieldType FieldType `json:"-"` +} + +func (c *MatchAnyCondition) IsFilterCondition() {} + +// MatchExceptCondition matches if value is NOT one of the given values (NOT IN). +// SQL equivalent: WHERE field NOT IN (value1, value2, ...) +type MatchExceptCondition struct { + Field string `json:"field"` + Values []any `json:"noneOf"` + FieldType FieldType `json:"-"` +} + +func (c *MatchExceptCondition) IsFilterCondition() {} + +// ── Range Types ────────────────────────────────────────────────────────────── + +// NumericRange defines bounds for numeric filtering. +// Used with NewNumericRange for cleaner constructor calls. +type NumericRange struct { + Gt *float64 `json:"greaterThan,omitempty"` // GreaterThan (exclusive) + Gte *float64 `json:"greaterThanOrEqualTo,omitempty"` // GreaterThanOrEqualTo (inclusive) + Lt *float64 `json:"lessThan,omitempty"` // LessThan (exclusive) + Lte *float64 `json:"lessThanOrEqualTo,omitempty"` // LessThanOrEqualTo (inclusive) +} + +// TimeRange defines bounds for time filtering. +// Used with NewTimeRange for cleaner constructor calls. 
+type TimeRange struct { + Gt *time.Time `json:"after,omitempty"` // After (exclusive) + Gte *time.Time `json:"atOrAfter,omitempty"` // AtOrAfter (inclusive) + Lt *time.Time `json:"before,omitempty"` // Before (exclusive) + Lte *time.Time `json:"atOrBefore,omitempty"` // AtOrBefore (inclusive) +} + +// ── Range Conditions ───────────────────────────────────────────────────────── + +// NumericRangeCondition filters by numeric range. +// SQL equivalent: WHERE field >= min AND field <= max +type NumericRangeCondition struct { + Field string `json:"field"` + Range NumericRange `json:"-"` + FieldType FieldType `json:"-"` +} + +func (c *NumericRangeCondition) IsFilterCondition() {} + +func (c *NumericRangeCondition) MarshalJSON() ([]byte, error) { + type Alias struct { + Field string `json:"field"` + GreaterThan *float64 `json:"greaterThan,omitempty"` + GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` + LessThan *float64 `json:"lessThan,omitempty"` + LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` + } + return json.Marshal(Alias{ + Field: c.Field, + GreaterThan: c.Range.Gt, + GreaterThanOrEqualTo: c.Range.Gte, + LessThan: c.Range.Lt, + LessThanOrEqualTo: c.Range.Lte, + }) +} + +func (c *NumericRangeCondition) UnmarshalJSON(data []byte) error { + type Alias struct { + Field string `json:"field"` + GreaterThan *float64 `json:"greaterThan,omitempty"` + GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` + LessThan *float64 `json:"lessThan,omitempty"` + LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` + } + var alias Alias + if err := json.Unmarshal(data, &alias); err != nil { + return err + } + c.Field = alias.Field + c.Range = NumericRange{ + Gt: alias.GreaterThan, + Gte: alias.GreaterThanOrEqualTo, + Lt: alias.LessThan, + Lte: alias.LessThanOrEqualTo, + } + return nil +} + +// TimeRangeCondition filters by datetime range. 
+// SQL equivalent: WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' +type TimeRangeCondition struct { + Field string `json:"field"` + Range TimeRange `json:"-"` + FieldType FieldType `json:"-"` +} + +func (c *TimeRangeCondition) IsFilterCondition() {} + +func (c TimeRangeCondition) MarshalJSON() ([]byte, error) { + type Alias struct { + Field string `json:"field"` + After *time.Time `json:"after,omitempty"` + AtOrAfter *time.Time `json:"atOrAfter,omitempty"` + Before *time.Time `json:"before,omitempty"` + AtOrBefore *time.Time `json:"atOrBefore,omitempty"` + } + return json.Marshal(Alias{ + Field: c.Field, + After: c.Range.Gt, + AtOrAfter: c.Range.Gte, + Before: c.Range.Lt, + AtOrBefore: c.Range.Lte, + }) +} + +func (c *TimeRangeCondition) UnmarshalJSON(data []byte) error { + type Alias struct { + Field string `json:"field"` + After *time.Time `json:"after,omitempty"` + AtOrAfter *time.Time `json:"atOrAfter,omitempty"` + Before *time.Time `json:"before,omitempty"` + AtOrBefore *time.Time `json:"atOrBefore,omitempty"` + } + var alias Alias + if err := json.Unmarshal(data, &alias); err != nil { + return err + } + c.Field = alias.Field + c.Range = TimeRange{ + Gt: alias.After, + Gte: alias.AtOrAfter, + Lt: alias.Before, + Lte: alias.AtOrBefore, + } + return nil +} + +// ── Null/Empty Conditions ──────────────────────────────────────────────────── + +// IsNullCondition checks if a field has a NULL value. +// SQL equivalent: WHERE field IS NULL +type IsNullCondition struct { + Field string `json:"field"` + FieldType FieldType `json:"-"` +} + +func (c *IsNullCondition) IsFilterCondition() {} + +// IsEmptyCondition checks if a field is empty (doesn't exist, null, or []). 
+// SQL equivalent: WHERE field IS NULL OR field = '' OR field = [] +type IsEmptyCondition struct { + Field string `json:"field"` + FieldType FieldType `json:"-"` +} + +func (c *IsEmptyCondition) IsFilterCondition() {} diff --git a/v1/vectordb/interface.go b/v1/vectordb/interface.go new file mode 100644 index 0000000..b461d90 --- /dev/null +++ b/v1/vectordb/interface.go @@ -0,0 +1,60 @@ +package vectordb + +import "context" + +// Service is the common interface for all vector databases. +// It provides a database-agnostic abstraction for vector similarity search, +// allowing applications to switch between different vector databases +// (Qdrant, pgVector, etc.) without changing application code. +// +// Example usage: +// +// func NewSearchService(db vectordb.Service) *SearchService { +// return &SearchService{db: db} +// } +// +// // Works with any implementation: +// // - vectordb.NewQdrantAdapter(qdrantClient) +// // - vectordb.NewPgVectorAdapter(pgVectorClient) +type Service interface { + // Search performs similarity search across one or more requests. + // Each request can target a different collection with different filters. + // Returns: + // - results: slice of result slices—one []SearchResult per request + // - errs: per-request errors (errs[i] corresponds to requests[i]) + // - err: systemic error (context cancelled, etc.) + // + // Example: + // results, errs, err := db.Search(ctx, + // SearchRequest{CollectionName: "docs", Vector: vec1, TopK: 10}, + // SearchRequest{CollectionName: "docs", Vector: vec2, TopK: 5, Filters: filters}, + // ) + // if err != nil { + // return err // systemic failure + // } + // for i, res := range results { + // if errs[i] != nil { + // log.Printf("request %d failed: %v", i, errs[i]) + // continue + // } + // // use res... + // } + Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResult, []error, error) + + // Insert adds embeddings to a collection. + // Uses batch processing internally for efficiency.
+ Insert(ctx context.Context, collectionName string, inputs []EmbeddingInput) error + + // Delete removes points by their IDs from a collection. + Delete(ctx context.Context, collection string, ids []string) error + + // EnsureCollection creates a collection if it doesn't exist. + // Safe to call multiple times—no-op if collection already exists. + EnsureCollection(ctx context.Context, name string, vectorSize uint64) error + + // GetCollection retrieves metadata about a collection. + GetCollection(ctx context.Context, name string) (*Collection, error) + + // ListCollections returns names of all collections. + ListCollections(ctx context.Context) ([]string, error) +} diff --git a/v1/vectordb/types.go b/v1/vectordb/types.go new file mode 100644 index 0000000..d0ef969 --- /dev/null +++ b/v1/vectordb/types.go @@ -0,0 +1,69 @@ +package vectordb + +// SearchRequest represents a single similarity search query. +// Use with Service.Search() for single or batch queries. +type SearchRequest struct { + // CollectionName is the target collection to search in + CollectionName string `json:"collectionName"` + + // Vector is the query embedding to find similar vectors for + Vector []float32 `json:"vector"` + + // TopK is the maximum number of results to return (its JSON key is "maxResults", not "topK") + TopK int `json:"maxResults"` + + // Filters is optional metadata filtering (AND/OR/NOT logic); the key is omitted from JSON when the slice is empty + Filters []*FilterSet `json:"filters,omitempty"` +} + +// SearchResult represents a single search result with its similarity score. +// This is database-agnostic—payload is converted to map[string]any.
+type SearchResult struct { + // ID is the unique identifier of the matched point + ID string `json:"id"` + + // Score is the similarity score (higher = more similar for cosine) + Score float32 `json:"score"` + + // Payload contains the metadata stored with the vector + Payload map[string]any `json:"payload"` + + // Vector is the stored embedding (only populated if requested; omitted from JSON when empty) + Vector []float32 `json:"vector,omitempty"` + + // CollectionName identifies which collection this result came from (omitted from JSON when empty) + CollectionName string `json:"collectionName,omitempty"` +} + +// EmbeddingInput is the input for inserting vectors into a collection. +type EmbeddingInput struct { + // ID is the unique identifier for this embedding + ID string `json:"id"` + + // Vector is the dense embedding representation + Vector []float32 `json:"vector"` + + // Payload is optional metadata to store with the vector (omitted from JSON when empty) + Payload map[string]any `json:"payload,omitempty"` +} + +// Collection contains metadata about a vector collection. +type Collection struct { + // Name is the unique identifier of the collection + Name string `json:"name"` + + // Status indicates the operational state (e.g., "Green", "Yellow") + Status string `json:"status"` + + // VectorSize is the dimension of vectors in this collection (NOTE(review): int here but uint64 in Service.EnsureCollection; confirm adapters convert safely) + VectorSize int `json:"vectorSize"` + + // Distance is the similarity metric (e.g., "Cosine", "Dot", "Euclid") + Distance string `json:"distance"` + + // VectorCount is the number of indexed vectors + VectorCount uint64 `json:"vectorCount"` + + // PointCount is the number of stored points/documents + PointCount uint64 `json:"pointCount"` +} diff --git a/v1/vectordb/utils.go b/v1/vectordb/utils.go new file mode 100644 index 0000000..e1328ba --- /dev/null +++ b/v1/vectordb/utils.go @@ -0,0 +1,276 @@ +package vectordb + +import ( + "encoding/json" + "fmt" +) + +// ── FilterSet Constructors ─────────────────────────────────────────────────── + +// NewFilterSet creates a FilterSet with the given clauses.
+// Use with Must(), Should(), and MustNot() helpers. +// +// Example: +// +// vectordb.NewFilterSet( +// vectordb.Must(vectordb.NewMatch("status", "published")), +// vectordb.Should(vectordb.NewMatch("tag", "ml"), vectordb.NewMatch("tag", "ai")), +// ) +func NewFilterSet(clauses ...func(*FilterSet)) *FilterSet { + fs := &FilterSet{} + for _, clause := range clauses { + clause(fs) + } + return fs +} + +// Must creates a Must clause (AND logic) with the given conditions. +// All conditions must match for a document to be included. +func Must(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.Must = &ConditionSet{Conditions: conditions} + } +} + +// Should creates a Should clause (OR logic) with the given conditions. +// At least one condition must match for a document to be included. +func Should(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.Should = &ConditionSet{Conditions: conditions} + } +} + +// MustNot creates a MustNot clause (NOT logic) with the given conditions. +// Documents matching any of these conditions are excluded. +func MustNot(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.MustNot = &ConditionSet{Conditions: conditions} + } +} + +// ── Condition Constructors ─────────────────────────────────────────────────── + +// NewMatch creates a match condition for internal fields. +func NewMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: InternalField} +} + +// NewUserMatch creates a match condition for user-defined fields. +func NewUserMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: UserField} +} + +// NewMatchAny creates an IN condition for internal fields. 
+func NewMatchAny(field string, values ...any) *MatchAnyCondition { + validateHomogeneousTypes(values) + return &MatchAnyCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchAny creates an IN condition for user-defined fields. +func NewUserMatchAny(field string, values ...any) *MatchAnyCondition { + validateHomogeneousTypes(values) + return &MatchAnyCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewMatchExcept creates a NOT IN condition for internal fields. +func NewMatchExcept(field string, values ...any) *MatchExceptCondition { + validateHomogeneousTypes(values) + return &MatchExceptCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchExcept creates a NOT IN condition for user-defined fields. +func NewUserMatchExcept(field string, values ...any) *MatchExceptCondition { + validateHomogeneousTypes(values) + return &MatchExceptCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewNumericRange creates a numeric range condition for internal fields. +func NewNumericRange(field string, r NumericRange) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + Range: r, + FieldType: InternalField, + } +} + +// NewUserNumericRange creates a numeric range condition for user-defined fields. +func NewUserNumericRange(field string, r NumericRange) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + Range: r, + FieldType: UserField, + } +} + +// NewTimeRange creates a time range condition for internal fields. +func NewTimeRange(field string, t TimeRange) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + Range: t, + FieldType: InternalField, + } +} + +// NewUserTimeRange creates a time range condition for user-defined fields. 
+func NewUserTimeRange(field string, t TimeRange) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + Range: t, + FieldType: UserField, + } +} + +// NewIsNull creates an IS NULL condition for internal fields. +func NewIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsNull creates an IS NULL condition for user-defined fields. +func NewUserIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: UserField} +} + +// NewIsEmpty creates an IS EMPTY condition for internal fields. +func NewIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsEmpty creates an IS EMPTY condition for user-defined fields. +func NewUserIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: UserField} +} + +// ── JSON Serialization ─────────────────────────────────────────────────────── + +// MarshalJSON implements custom JSON marshaling for ConditionSet. +// This is needed because FilterCondition is an interface. +func (cs *ConditionSet) MarshalJSON() ([]byte, error) { + return json.Marshal(cs.Conditions) +} + +// UnmarshalJSON implements custom JSON unmarshaling for ConditionSet. +// It detects the condition type based on JSON keys and deserializes +// into the appropriate concrete type (MatchCondition, NumericRangeCondition, etc.) +func (cs *ConditionSet) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + cs.Conditions = make([]FilterCondition, 0, len(raw)) + + for _, r := range raw { + cond, err := parseCondition(r) + if err != nil { + return err + } + cs.Conditions = append(cs.Conditions, cond) + } + + return nil +} + +// parseCondition detects and parses a single FilterCondition from JSON. 
+// It examines the JSON keys to determine the condition type: +// - "equalTo" → MatchCondition +// - "anyOf" → MatchAnyCondition +// - "noneOf" → MatchExceptCondition +// - "greaterThan", "lessThan", etc. → NumericRangeCondition +// - "after", "before", etc. → TimeRangeCondition +func parseCondition(data []byte) (FilterCondition, error) { + // Extract field names to determine condition type + var fields map[string]json.RawMessage + if err := json.Unmarshal(data, &fields); err != nil { + return nil, err + } + + switch { + case hasKey(fields, "equalTo"): + var c MatchCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "anyOf"): + var c MatchAnyCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "noneOf"): + var c MatchExceptCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "greaterThan"), hasKey(fields, "greaterThanOrEqualTo"), + hasKey(fields, "lessThan"), hasKey(fields, "lessThanOrEqualTo"): + var c NumericRangeCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "after"), hasKey(fields, "atOrAfter"), + hasKey(fields, "before"), hasKey(fields, "atOrBefore"): + var c TimeRangeCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + default: + return nil, fmt.Errorf("unknown filter condition type: %s", string(data)) + } +} + +// hasKey checks if a JSON object contains a specific key. +// Used by parseCondition to detect filter condition types. +func hasKey(m map[string]json.RawMessage, key string) bool { + _, ok := m[key] + return ok +} + +// validateHomogeneousTypes ensures all values are of the same type category. +// Panics if mixed types are detected - this catches programming errors early. 
//
// Every value is classified with getType. The first value fixes the expected
// category, and each later value must be supported and fall in that category.
// Slices of length zero or one return immediately without any check.
//
// NOTE(review): because of that early return a single unsupported value is
// accepted while two of them panic. Confirm whether that is intended.
//
// TODO: Consider whether panic is appropriate here, or if we should:
//   - Return an error instead (for runtime data validation)
//   - Add a separate NewMatchAnyChecked() that returns error
//   - Keep panic for constructor calls, error for JSON unmarshaling
func validateHomogeneousTypes(values []any) {
	if len(values) < 2 {
		return
	}

	want := getType(values[0])
	if want == "" {
		panic(fmt.Sprintf("vectordb: unsupported value type: %T", values[0]))
	}

	// Compare every remaining value against the category of the first one.
	for idx := 1; idx < len(values); idx++ {
		got := getType(values[idx])
		switch {
		case got == "":
			panic(fmt.Sprintf("vectordb: unsupported value type at index %d: %T", idx, values[idx]))
		case got != want:
			panic(fmt.Sprintf("vectordb: mixed types not allowed in MatchAny/MatchExcept: expected %s but got %s at index %d", want, got, idx))
		}
	}
}

// getType maps a value to its type category: "string", "numeric" (int, int64,
// float64) or "boolean". Any other dynamic type, including nil, yields "".
func getType(value any) string {
	switch value.(type) {
	case bool:
		return "boolean"
	case string:
		return "string"
	case int, int64, float64:
		return "numeric"
	default:
		return ""
	}
}