From fe8512fba0155264bdebf481c735b65c0a000b98 Mon Sep 17 00:00:00 2001 From: Shinu Joseph Date: Thu, 11 Dec 2025 14:48:45 +0100 Subject: [PATCH 1/4] feat: Decouple filter types from Qdrant-specific implementation --- v1/qdrant/client.go | 2 +- v1/qdrant/converter.go | 346 ++++++++ v1/qdrant/doc.go | 23 +- v1/qdrant/filters.go | 398 --------- v1/qdrant/filters_test.go | 1148 -------------------------- v1/qdrant/operations.go | 564 ++++--------- v1/qdrant/qdrant_integration_test.go | 132 +-- v1/qdrant/utils.go | 3 +- v1/vectordb/doc.go | 138 ++++ v1/vectordb/filters.go | 240 ++++++ v1/vectordb/interface.go | 49 ++ v1/vectordb/types.go | 69 ++ 12 files changed, 1113 insertions(+), 1999 deletions(-) create mode 100644 v1/qdrant/converter.go delete mode 100644 v1/qdrant/filters.go delete mode 100644 v1/qdrant/filters_test.go create mode 100644 v1/vectordb/doc.go create mode 100644 v1/vectordb/filters.go create mode 100644 v1/vectordb/interface.go create mode 100644 v1/vectordb/types.go diff --git a/v1/qdrant/client.go b/v1/qdrant/client.go index a176403..8d50596 100644 --- a/v1/qdrant/client.go +++ b/v1/qdrant/client.go @@ -120,7 +120,7 @@ func (c *QdrantClient) healthCheck() error { // // Close gracefully shuts down the Qdrant client. // -// Since the official Qdrant Go SDK doesn’t maintain persistent connections, +// Since the official Qdrant Go SDK doesn't maintain persistent connections, // this is currently a no-op. It exists for lifecycle symmetry and future safety. 
func (c *QdrantClient) Close() error { log.Println("[Qdrant] closing client (no-op)") diff --git a/v1/qdrant/converter.go b/v1/qdrant/converter.go new file mode 100644 index 0000000..cd4ebad --- /dev/null +++ b/v1/qdrant/converter.go @@ -0,0 +1,346 @@ +package qdrant + +import ( + "fmt" + "strings" + "time" + + "github.com/Aleph-Alpha/std/v1/vectordb" + qdrant "github.com/qdrant/go-client/qdrant" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// UserPayloadPrefix is the prefix for user-defined metadata fields. +// User fields are stored under "custom." in the Qdrant payload. +const UserPayloadPrefix = "custom" + +// BuildPayload creates a Qdrant payload with separated internal and user fields. +// Internal fields are stored at the top level, while user fields are stored under +// the "custom" prefix. If internal itself contains a "custom" key, that entry is +// overwritten by the user map whenever user is non-empty. The user map is stored +// as-is (not copied), so later mutations of it are visible through the payload. +// A new map is always returned, even when both arguments are nil or empty. +// +// Example: +// +// payload := BuildPayload( +// map[string]any{"search_store_id": "store123"}, +// map[string]any{"document_id": "doc456"}, +// ) +// // Result: {"search_store_id": "store123", "custom": {"document_id": "doc456"}} +func BuildPayload(internal map[string]any, user map[string]any) map[string]any { + payload := make(map[string]any) + + // Add internal fields at top-level + for k, v := range internal { + payload[k] = v + } + + // Add user fields under "custom" prefix + if len(user) > 0 { + payload[UserPayloadPrefix] = user + } + + return payload +} + +// ── Filter Conversion ──────────────────────────────────────────────────────── + +// convertVectorDBFilterSet converts a vectordb.FilterSet to a Qdrant filter.
+// It returns nil when filters is nil or when none of the Must, Should and
+// MustNot clauses produces at least one Qdrant condition.
+func convertVectorDBFilterSet(filters *vectordb.FilterSet) *qdrant.Filter {
+	if filters == nil {
+		return nil
+	}
+
+	// convertVectorDBConditionSet accepts a nil set and returns nil for it, so
+	// every clause can be converted unconditionally.
+	must := convertVectorDBConditionSet(filters.Must)
+	should := convertVectorDBConditionSet(filters.Should)
+	mustNot := convertVectorDBConditionSet(filters.MustNot)
+
+	// A filter without any condition is reported as "no filter".
+	if len(must)+len(should)+len(mustNot) == 0 {
+		return nil
+	}
+
+	return &qdrant.Filter{
+		Must:    must,
+		Should:  should,
+		MustNot: mustNot,
+	}
+}
+
+// convertVectorDBConditionSet flattens every condition of cs into a single
+// slice of Qdrant conditions. A nil set yields nil, and nil entries produced
+// by individual conditions are left out of the result.
+func convertVectorDBConditionSet(cs *vectordb.ConditionSet) []*qdrant.Condition {
+	if cs == nil {
+		return nil
+	}
+
+	var out []*qdrant.Condition
+	for _, fc := range cs.Conditions {
+		for _, qc := range convertVectorDBCondition(fc) {
+			if qc == nil {
+				continue
+			}
+			out = append(out, qc)
+		}
+	}
+	return out
+}
+
+// convertVectorDBCondition converts a single vectordb.FilterCondition to Qdrant conditions.
+// A nil interface value or a condition type unknown to this package yields nil,
+// so such a condition contributes nothing to the resulting filter clause.
+func convertVectorDBCondition(c vectordb.FilterCondition) []*qdrant.Condition {
+	switch cond := c.(type) {
+	case *vectordb.MatchCondition:
+		return convertVectorDBMatchCondition(cond)
+	case *vectordb.MatchAnyCondition:
+		return convertVectorDBMatchAnyCondition(cond)
+	case *vectordb.MatchExceptCondition:
+		return convertVectorDBMatchExceptCondition(cond)
+	case *vectordb.NumericRangeCondition:
+		return convertVectorDBNumericRangeCondition(cond)
+	case *vectordb.TimeRangeCondition:
+		return convertVectorDBTimeRangeCondition(cond)
+	case *vectordb.IsNullCondition:
+		return convertVectorDBIsNullCondition(cond)
+	case *vectordb.IsEmptyCondition:
+		return convertVectorDBIsEmptyCondition(cond)
+	default:
+		return nil
+	}
+}
+
+// convertVectorDBMatchCondition builds an exact-match condition.
+//
+// Supported value types are string, bool, int, int64 and float64. A float64 is
+// accepted only when it holds an exact int64 value (the form in which
+// JSON-decoded integers arrive). A fractional, non-finite or out-of-range
+// float is treated as unsupported rather than truncated, because truncation
+// would turn a filter for 1.5 into a filter for 1. A nil condition or an
+// unsupported value yields nil.
+func convertVectorDBMatchCondition(c *vectordb.MatchCondition) []*qdrant.Condition {
+	if c == nil {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	switch v := c.Value.(type) {
+	case string:
+		return []*qdrant.Condition{qdrant.NewMatch(key, v)}
+	case bool:
+		return []*qdrant.Condition{qdrant.NewMatchBool(key, v)}
+	default:
+		if n, ok := toVectorDBInt64(v); ok {
+			return []*qdrant.Condition{qdrant.NewMatchInt(key, n)}
+		}
+		return nil
+	}
+}
+
+// convertVectorDBMatchAnyCondition builds an "IN" condition.
+//
+// Values must be either all strings or all integers (int, int64, or float64
+// holding an exact integer). A nil condition, an empty list, or a list that
+// mixes types or contains an unsupported value yields nil; elements are never
+// replaced by "" or 0, which would make the filter match values the caller
+// did not ask for.
+func convertVectorDBMatchAnyCondition(c *vectordb.MatchAnyCondition) []*qdrant.Condition {
+	if c == nil || len(c.Values) == 0 {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	if strs, ok := toVectorDBStrings(c.Values); ok {
+		return []*qdrant.Condition{qdrant.NewMatchKeywords(key, strs...)}
+	}
+	if ints, ok := toVectorDBInt64s(c.Values); ok {
+		return []*qdrant.Condition{qdrant.NewMatchInts(key, ints...)}
+	}
+	return nil
+}
+
+// convertVectorDBMatchExceptCondition builds a "NOT IN" condition.
+//
+// The accepted value lists and the nil results are the same as for
+// convertVectorDBMatchAnyCondition.
+func convertVectorDBMatchExceptCondition(c *vectordb.MatchExceptCondition) []*qdrant.Condition {
+	if c == nil || len(c.Values) == 0 {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	if strs, ok := toVectorDBStrings(c.Values); ok {
+		return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, strs...)}
+	}
+	if ints, ok := toVectorDBInt64s(c.Values); ok {
+		return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, ints...)}
+	}
+	return nil
+}
+
+// toVectorDBStrings returns values as a []string. ok is false as soon as one
+// element is not a string.
+func toVectorDBStrings(values []any) (strs []string, ok bool) {
+	strs = make([]string, 0, len(values))
+	for _, v := range values {
+		s, isString := v.(string)
+		if !isString {
+			return nil, false
+		}
+		strs = append(strs, s)
+	}
+	return strs, true
+}
+
+// toVectorDBInt64s returns values as a []int64. ok is false as soon as one
+// element cannot be converted by toVectorDBInt64.
+func toVectorDBInt64s(values []any) (ints []int64, ok bool) {
+	ints = make([]int64, 0, len(values))
+	for _, v := range values {
+		n, isInt := toVectorDBInt64(v)
+		if !isInt {
+			return nil, false
+		}
+		ints = append(ints, n)
+	}
+	return ints, true
+}
+
+// toVectorDBInt64 converts an int, int64, or integral float64 to int64.
+// ok is false for every other type and for floats that are fractional,
+// NaN, infinite, or outside the int64 range.
+func toVectorDBInt64(v any) (n int64, ok bool) {
+	switch x := v.(type) {
+	case int:
+		return int64(x), true
+	case int64:
+		return x, true
+	case float64:
+		// Converting a float outside [-2^63, 2^63) to int64 is
+		// implementation-defined in Go, so range-check first. The test is
+		// written positively so that NaN, for which every comparison is
+		// false, is rejected too.
+		if !(x >= -9223372036854775808.0 && x < 9223372036854775808.0) {
+			return 0, false
+		}
+		i := int64(x)
+		if float64(i) != x {
+			return 0, false // a fractional part would be lost
+		}
+		return i, true
+	default:
+		return 0, false
+	}
+}
+
+// convertVectorDBNumericRangeCondition builds a numeric range condition.
+// A nil condition or a range without any bound yields nil.
+func convertVectorDBNumericRangeCondition(c *vectordb.NumericRangeCondition) []*qdrant.Condition {
+	if c == nil {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	rangeFilter := &qdrant.Range{
+		Gt:  c.GreaterThan,
+		Gte: c.GreaterThanOrEqualTo,
+		Lt:  c.LessThan,
+		Lte: c.LessThanOrEqualTo,
+	}
+
+	if rangeFilter.Gt == nil && rangeFilter.Gte == nil &&
+		rangeFilter.Lt == nil && rangeFilter.Lte == nil {
+		return nil
+	}
+
+	return []*qdrant.Condition{qdrant.NewRange(key, rangeFilter)}
+}
+
+// convertVectorDBTimeRangeCondition builds a datetime range condition.
+// A nil condition or a range without any bound yields nil.
+func convertVectorDBTimeRangeCondition(c *vectordb.TimeRangeCondition) []*qdrant.Condition {
+	if c == nil {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	dateRange := &qdrant.DatetimeRange{
+		Gt:  toVectorDBTimestamp(c.After),
+		Gte: toVectorDBTimestamp(c.AtOrAfter),
+		Lt:  toVectorDBTimestamp(c.Before),
+		Lte: toVectorDBTimestamp(c.AtOrBefore),
+	}
+
+	if dateRange.Gt == nil && dateRange.Gte == nil &&
+		dateRange.Lt == nil && dateRange.Lte == nil {
+		return nil
+	}
+
+	return []*qdrant.Condition{qdrant.NewDatetimeRange(key, dateRange)}
+}
+
+// convertVectorDBIsNullCondition builds an "is null" condition for the field.
+// A nil condition yields nil.
+func convertVectorDBIsNullCondition(c *vectordb.IsNullCondition) []*qdrant.Condition {
+	if c == nil {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	return []*qdrant.Condition{qdrant.NewIsNull(key)}
+}
+
+// convertVectorDBIsEmptyCondition builds an "is empty" condition for the field.
+// A nil condition yields nil.
+func convertVectorDBIsEmptyCondition(c *vectordb.IsEmptyCondition) []*qdrant.Condition {
+	if c == nil {
+		return nil
+	}
+	key := resolveVectorDBFieldKey(c.Field, c.FieldType)
+	return []*qdrant.Condition{qdrant.NewIsEmpty(key)}
+}
+
+// resolveVectorDBFieldKey returns the full field path based on FieldType.
+// Internal fields: "search_store_id" -> "search_store_id"
+// User fields: "document_id" -> "custom.document_id"
+//
+// A user field key that already starts with "custom." is returned unchanged,
+// so resolving a key twice does not double the prefix.
+func resolveVectorDBFieldKey(key string, fieldType vectordb.FieldType) string {
+	if fieldType != vectordb.UserField {
+		return key
+	}
+	prefix := UserPayloadPrefix + "."
+	if strings.HasPrefix(key, prefix) {
+		return key
+	}
+	return prefix + key
+}
+
+// toVectorDBTimestamp converts t to a protobuf timestamp; nil stays nil.
+func toVectorDBTimestamp(t *time.Time) *timestamppb.Timestamp {
+	if t == nil {
+		return nil
+	}
+	return timestamppb.New(*t)
+}
+
+// ── Result Conversion ────────────────────────────────────────────────────────
+
+// parseVectorDBSearchResults converts Qdrant response to vectordb.SearchResult slice.
+//
+// Every result is tagged with collectionName. If the ID of any point cannot be
+// extracted, no results are returned and the error names the position of the
+// offending entry and the collection; the underlying error stays reachable
+// through errors.Is / errors.As.
+func parseVectorDBSearchResults(resp []*qdrant.ScoredPoint, collectionName string) ([]vectordb.SearchResult, error) {
+	results := make([]vectordb.SearchResult, 0, len(resp))
+	for i, r := range resp {
+		// The generated protobuf getters are nil-safe, so a nil entry in resp
+		// surfaces as a point ID error instead of a nil-pointer panic.
+		id, err := extractVectorDBPointID(r.GetId())
+		if err != nil {
+			return nil, fmt.Errorf("parsing search result %d from collection %q: %w", i, collectionName, err)
+		}
+
+		results = append(results, vectordb.SearchResult{
+			ID:             id,
+			Score:          r.GetScore(),
+			Payload:        convertVectorDBPayload(r.GetPayload()),
+			CollectionName: collectionName,
+		})
+	}
+	return results, nil
+}
+
+// extractVectorDBPointID extracts a string ID from Qdrant's PointId type.
+func extractVectorDBPointID(id *qdrant.PointId) (string, error) { + if id == nil { + return "", fmt.Errorf("nil point ID") + } + switch v := id.PointIdOptions.(type) { + case *qdrant.PointId_Num: + return fmt.Sprintf("%d", v.Num), nil + case *qdrant.PointId_Uuid: + return v.Uuid, nil + default: + return "", fmt.Errorf("unexpected PointId type: %T", v) + } +} + +// convertVectorDBPayload converts Qdrant's protobuf payload to a generic map. +// A nil payload yields nil. Otherwise every entry keeps its key and has its value +// converted by extractVectorDBValue. +func convertVectorDBPayload(payload map[string]*qdrant.Value) map[string]any { + if payload == nil { + return nil + } + result := make(map[string]any, len(payload)) + for k, v := range payload { + result[k] = extractVectorDBValue(v) + } + return result +} + +// extractVectorDBValue recursively converts a Qdrant Value to a Go native type. +// String, integer, double and bool values are returned as stored. Struct values +// become a map[string]any and list values a []any, both converted element by +// element. A nil Value, a null value, a struct or list whose inner message is nil, +// and any unrecognized kind all yield nil. +func extractVectorDBValue(v *qdrant.Value) any { + if v == nil { + return nil + } + switch val := v.Kind.(type) { + case *qdrant.Value_StringValue: + return val.StringValue + case *qdrant.Value_IntegerValue: + return val.IntegerValue + case *qdrant.Value_DoubleValue: + return val.DoubleValue + case *qdrant.Value_BoolValue: + return val.BoolValue + case *qdrant.Value_NullValue: + return nil + case *qdrant.Value_StructValue: + if val.StructValue == nil { + return nil + } + return convertVectorDBPayload(val.StructValue.Fields) + case *qdrant.Value_ListValue: + if val.ListValue == nil { + return nil + } + items := make([]any, len(val.ListValue.Values)) + for i, item := range val.ListValue.Values { + items[i] = extractVectorDBValue(item) + } + return items + default: + return nil + } +} diff --git a/v1/qdrant/doc.go b/v1/qdrant/doc.go index 7ee0879..71d93bb 100644 --- a/v1/qdrant/doc.go +++ b/v1/qdrant/doc.go @@ -15,8 +15,29 @@ // - Type-safe collection creation and existence checks // - Support for payload metadata and optional vector retrieval // - Extensible abstraction layer for alternate vector stores (e.g., Pinecone, Postgres) +// - VectorDBAdapter implementing vectordb.VectorDBService for
DB-agnostic usage // -// Basic Usage: +// # VectorDB Interface +// +// This package includes [VectorDBAdapter] which implements the database-agnostic +// [vectordb.VectorDBService] interface. Use this for new projects to enable +// easy switching between vector databases: +// +// import ( +// "github.com/Aleph-Alpha/std/v1/vectordb" +// "github.com/Aleph-Alpha/std/v1/qdrant" +// ) +// +// // Create your existing QdrantClient +// qc, _ := qdrant.NewQdrantClient(params) +// +// // Create adapter for DB-agnostic usage +// var db vectordb.VectorDBService = qdrant.NewVectorDBAdapter(qc.API()) +// +// This allows switching between vector databases (Qdrant, Weaviate, Pinecone) without +// changing application code. +// +// # Basic Usage // // import "github.com/Aleph-Alpha/std/v1/qdrant" // diff --git a/v1/qdrant/filters.go b/v1/qdrant/filters.go deleted file mode 100644 index 20222e8..0000000 --- a/v1/qdrant/filters.go +++ /dev/null @@ -1,398 +0,0 @@ -package qdrant - -import ( - "encoding/json" - "strings" - "time" - - qdrant "github.com/qdrant/go-client/qdrant" - "google.golang.org/protobuf/types/known/timestamppb" -) - -// UserPayloadPrefix is the prefix for user-defined metadata fields -const UserPayloadPrefix = "custom" - -// FilterCondition is the interface for all filter conditions -type FilterCondition interface { - ToQdrantCondition() []*qdrant.Condition -} - -// FieldType indicates whether a field is internal or user-defined -type FieldType int - -const ( - // InternalField - system-managed fields stored at top-level - InternalField FieldType = iota - // UserField - user-defined fields stored under "custom." 
prefix - UserField -) - -// TimeRange represents a time-based filter condition -type TimeRange struct { - Gt *time.Time `json:"after,omitempty"` // Greater than this time - Gte *time.Time `json:"atOrAfter,omitempty"` // Greater than or equal to this time - Lt *time.Time `json:"before,omitempty"` // Less than this time - Lte *time.Time `json:"atOrBefore,omitempty"` // Less than or equal to this time -} - -// NumericRange represents a numeric range filter condition -type NumericRange struct { - Gt *float64 `json:"greaterThan,omitempty"` // Greater than - Gte *float64 `json:"greaterThanOrEqualTo,omitempty"` // Greater than or equal - Lt *float64 `json:"lessThan,omitempty"` // Less than - Lte *float64 `json:"lessThanOrEqualTo,omitempty"` // Less than or equal -} - -// MatchCondition represents an exact match condition for a field value. -// Supports string, bool, and int64 types. The FieldType defaults to InternalField -// if not specified, meaning the field is stored at the top level of the payload. -// Use UserField to indicate the field is stored under the "custom." prefix. -type MatchCondition[T comparable] struct { - Key string `json:"field"` - Value T `json:"equalTo"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -// ToQdrantCondition converts the MatchCondition to Qdrant conditions. -// Supports string, bool, and int64 types. Returns nil for unsupported types. -func (c MatchCondition[T]) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Value).(type) { - case string: - return []*qdrant.Condition{qdrant.NewMatch(key, v)} - case bool: - return []*qdrant.Condition{qdrant.NewMatchBool(key, v)} - case int64: - return []*qdrant.Condition{qdrant.NewMatchInt(key, v)} - default: - // Unsupported type - returns nil - return nil - } -} - -// MatchAnyCondition matches if value is one of the given values (IN operator). -// Applicable to keyword (string) and integer payloads. 
-// Returns nil if Values is empty. The FieldType defaults to InternalField if not specified. -type MatchAnyCondition[T string | int64] struct { - Key string `json:"field"` - Values []T `json:"anyOf"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c MatchAnyCondition[T]) ToQdrantCondition() []*qdrant.Condition { - if len(c.Values) == 0 { - return nil - } - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Values).(type) { - case []string: - return []*qdrant.Condition{qdrant.NewMatchKeywords(key, v...)} - case []int64: - return []*qdrant.Condition{qdrant.NewMatchInts(key, v...)} - default: - return nil - } -} - -// MatchExceptCondition matches if value is NOT one of the given values (NOT IN operator). -// Applicable to keyword (string) and integer payloads. -// Returns nil if Values is empty. The FieldType defaults to InternalField if not specified. -type MatchExceptCondition[T string | int64] struct { - Key string `json:"field"` - Values []T `json:"noneOf"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c MatchExceptCondition[T]) ToQdrantCondition() []*qdrant.Condition { - if len(c.Values) == 0 { - return nil - } - key := resolveFieldKey(c.Key, c.FieldType) - switch v := any(c.Values).(type) { - case []string: - return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, v...)} - case []int64: - return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, v...)} - default: - return nil - } -} - -type TextCondition = MatchCondition[string] // Exact string match -type BoolCondition = MatchCondition[bool] // Exact boolean match -type IntCondition = MatchCondition[int64] // Exact integer match -type TextAnyCondition = MatchAnyCondition[string] // String IN operator -type IntAnyCondition = MatchAnyCondition[int64] // Integer IN operator -type TextExceptCondition = MatchExceptCondition[string] // String NOT IN -type IntExceptCondition = 
MatchExceptCondition[int64] // Integer NOT IN - -// TimeRangeCondition represents a time range filter condition -type TimeRangeCondition struct { - Key string `json:"field"` - Value TimeRange `json:"-"` - FieldType FieldType `json:"-"` -} - -func (c TimeRangeCondition) ToQdrantCondition() []*qdrant.Condition { - return buildDateTimeRangeConditions(resolveFieldKey(c.Key, c.FieldType), c.Value) -} - -func (c TimeRangeCondition) MarshalJSON() ([]byte, error) { - type Alias struct { - Field string `json:"field"` - After *time.Time `json:"after,omitempty"` - AtOrAfter *time.Time `json:"atOrAfter,omitempty"` - Before *time.Time `json:"before,omitempty"` - AtOrBefore *time.Time `json:"atOrBefore,omitempty"` - } - return json.Marshal(Alias{ - Field: c.Key, - After: c.Value.Gt, - AtOrAfter: c.Value.Gte, - Before: c.Value.Lt, - AtOrBefore: c.Value.Lte, - }) -} - -func (c *TimeRangeCondition) UnmarshalJSON(data []byte) error { - type Alias struct { - Field string `json:"field"` - After *time.Time `json:"after,omitempty"` - AtOrAfter *time.Time `json:"atOrAfter,omitempty"` - Before *time.Time `json:"before,omitempty"` - AtOrBefore *time.Time `json:"atOrBefore,omitempty"` - } - var alias Alias - if err := json.Unmarshal(data, &alias); err != nil { - return err - } - c.Key = alias.Field - c.Value = TimeRange{ - Gt: alias.After, - Gte: alias.AtOrAfter, - Lt: alias.Before, - Lte: alias.AtOrBefore, - } - return nil -} - -// NumericRangeCondition represents a numeric range filter -type NumericRangeCondition struct { - Key string `json:"field"` - Value NumericRange `json:"-"` - FieldType FieldType `json:"-"` -} - -func (c NumericRangeCondition) ToQdrantCondition() []*qdrant.Condition { - return buildNumericRangeConditions(resolveFieldKey(c.Key, c.FieldType), c.Value) -} - -func (c NumericRangeCondition) MarshalJSON() ([]byte, error) { - type Alias struct { - Field string `json:"field"` - GreaterThan *float64 `json:"greaterThan,omitempty"` - GreaterThanOrEqualTo *float64 
`json:"greaterThanOrEqualTo,omitempty"` - LessThan *float64 `json:"lessThan,omitempty"` - LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` - } - return json.Marshal(Alias{ - Field: c.Key, - GreaterThan: c.Value.Gt, - GreaterThanOrEqualTo: c.Value.Gte, - LessThan: c.Value.Lt, - LessThanOrEqualTo: c.Value.Lte, - }) -} - -func (c *NumericRangeCondition) UnmarshalJSON(data []byte) error { - type Alias struct { - Field string `json:"field"` - GreaterThan *float64 `json:"greaterThan,omitempty"` - GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` - LessThan *float64 `json:"lessThan,omitempty"` - LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` - } - var alias Alias - if err := json.Unmarshal(data, &alias); err != nil { - return err - } - c.Key = alias.Field - c.Value = NumericRange{ - Gt: alias.GreaterThan, - Gte: alias.GreaterThanOrEqualTo, - Lt: alias.LessThan, - Lte: alias.LessThanOrEqualTo, - } - return nil -} - -// resolveFieldKey returns the full field path based on FieldType -// Internal fields: "search_store_id" -> "search_store_id" -// User fields: "document_id" -> "custom.document_id" -func resolveFieldKey(key string, fieldType FieldType) string { - if fieldType == UserField { - // Prevent double-prefixing - if strings.HasPrefix(key, UserPayloadPrefix+".") { - return key - } - return UserPayloadPrefix + "." + key - } - return key -} - -// ConditionSet holds conditions for a single clause -type ConditionSet struct { - Conditions []FilterCondition `json:"conditions,omitempty"` -} - -// FilterSet supports Must (AND), Should (OR), and MustNot (NOT) clauses. -// Use with SearchRequest.Filters to filter search results. 
-// -// Example: -// -// filters := &FilterSet{ -// Must: &ConditionSet{ -// Conditions: []FilterCondition{ -// TextCondition{Key: "city", Value: "London"}, -// }, -// }, -// } -type FilterSet struct { - Must *ConditionSet `json:"must,omitempty"` // AND - all conditions must match - Should *ConditionSet `json:"should,omitempty"` // OR - at least one condition must match - MustNot *ConditionSet `json:"mustNot,omitempty"` // NOT - none of the conditions should match -} - -// buildFilter constructs a Qdrant filter from FilterSet -func buildFilter(filters *FilterSet) *qdrant.Filter { - if filters == nil { - return nil - } - - filter := &qdrant.Filter{} - - if filters.Must != nil { - filter.Must = buildConditions(filters.Must) - } - - if filters.Should != nil { - filter.Should = buildConditions(filters.Should) - } - - if filters.MustNot != nil { - filter.MustNot = buildConditions(filters.MustNot) - } - - // Return nil if no conditions were added - if len(filter.Must) == 0 && len(filter.Should) == 0 && len(filter.MustNot) == 0 { - return nil - } - - return filter -} - -// buildConditions converts a ConditionSet to Qdrant conditions -// Filters out nil conditions that may be returned by invalid conditions (e.g., empty ranges) -func buildConditions(cs *ConditionSet) []*qdrant.Condition { - if cs == nil { - return nil - } - - var conditions []*qdrant.Condition - for _, c := range cs.Conditions { - conds := c.ToQdrantCondition() - for _, cond := range conds { - if cond != nil { - conditions = append(conditions, cond) - } - } - } - return conditions -} - -// buildDateTimeRangeConditions creates datetime range conditions -func buildDateTimeRangeConditions(key string, tr TimeRange) []*qdrant.Condition { - dateRange := &qdrant.DatetimeRange{ - Gt: toTimestamp(tr.Gt), - Gte: toTimestamp(tr.Gte), - Lt: toTimestamp(tr.Lt), - Lte: toTimestamp(tr.Lte), - } - - // Check if any field is set - if dateRange.Gt == nil && dateRange.Gte == nil && dateRange.Lt == nil && dateRange.Lte == nil 
{ - return nil - } - - return []*qdrant.Condition{qdrant.NewDatetimeRange(key, dateRange)} -} - -// buildNumericRangeConditions creates numeric range conditions -func buildNumericRangeConditions(key string, nr NumericRange) []*qdrant.Condition { - rangeFilter := &qdrant.Range{ - Gt: nr.Gt, - Gte: nr.Gte, - Lt: nr.Lt, - Lte: nr.Lte, - } - - // Check if any field is set - if rangeFilter.Gt == nil && rangeFilter.Gte == nil && rangeFilter.Lt == nil && rangeFilter.Lte == nil { - return nil - } - - return []*qdrant.Condition{qdrant.NewRange(key, rangeFilter)} -} - -// toTimestamp converts a *time.Time to *timestamppb.Timestamp (nil-safe) -func toTimestamp(t *time.Time) *timestamppb.Timestamp { - if t == nil { - return nil - } - return timestamppb.New(*t) -} - -// IsNullCondition checks if a field is null -type IsNullCondition struct { - Key string `json:"field"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c IsNullCondition) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - return []*qdrant.Condition{qdrant.NewIsNull(key)} -} - -// IsEmptyCondition checks if a field is empty (does not exist, null, or []) -type IsEmptyCondition struct { - Key string `json:"field"` - FieldType FieldType `json:"-"` // Internal or User field (default: InternalField) -} - -func (c IsEmptyCondition) ToQdrantCondition() []*qdrant.Condition { - key := resolveFieldKey(c.Key, c.FieldType) - return []*qdrant.Condition{qdrant.NewIsEmpty(key)} -} - -// === Payload Helpers === - -// BuildPayload creates a Qdrant payload with separated internal and user fields. -// Internal fields are stored at the top level, while user fields are stored under -// the "custom" prefix. If internal contains a "custom" key, it will be overwritten -// by the user fields map. 
-func BuildPayload(internal map[string]any, user map[string]any) map[string]any { - payload := make(map[string]any) - - // Add internal fields at top-level - for k, v := range internal { - payload[k] = v - } - - // Add user fields under "custom" prefix - // Note: This will overwrite any "custom" key that was in the internal map - if len(user) > 0 { - payload[UserPayloadPrefix] = user - } - - return payload -} diff --git a/v1/qdrant/filters_test.go b/v1/qdrant/filters_test.go deleted file mode 100644 index 495b583..0000000 --- a/v1/qdrant/filters_test.go +++ /dev/null @@ -1,1148 +0,0 @@ -package qdrant - -import ( - "encoding/json" - "testing" - "time" -) - -func TestBuildFilter_NilFilterSet(t *testing.T) { - result := buildFilter(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_EmptyFilterSet(t *testing.T) { - filters := &FilterSet{} - result := buildFilter(filters) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_EmptyConditionSet(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{}, - }, - } - result := buildFilter(filters) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildFilter_MustWithTextCondition(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } - if len(result.Should) != 0 { - t.Errorf("expected 0 Should conditions, got %d", len(result.Should)) - } - if len(result.MustNot) != 0 { - t.Errorf("expected 0 MustNot conditions, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_ShouldWithMultipleTextConditions(t *testing.T) { - // city = "London" OR city = "Berlin" - filters := 
&FilterSet{ - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TextCondition{Key: "city", Value: "Berlin"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Should) != 2 { - t.Errorf("expected 2 Should conditions, got %d", len(result.Should)) - } -} - -func TestBuildFilter_MustNotWithBoolCondition(t *testing.T) { - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "archived", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_MixedConditionTypes(t *testing.T) { - // city = "London" AND active = true AND priority = 1 - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - BoolCondition{Key: "active", Value: true}, - IntCondition{Key: "priority", Value: 1}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } -} - -func TestBuildFilter_CombinedClauses(t *testing.T) { - // (city = "London" AND active = true) AND NOT archived - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - BoolCondition{Key: "active", Value: true}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "archived", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 2 { - t.Errorf("expected 2 Must conditions, got %d", len(result.Must)) - } - if len(result.MustNot) != 1 { - 
t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_AllThreeClauses(t *testing.T) { - // Must AND Should AND MustNot - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "status", Value: "active"}, - }, - }, - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TextCondition{Key: "city", Value: "Berlin"}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "deleted", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } - if len(result.Should) != 2 { - t.Errorf("expected 2 Should conditions, got %d", len(result.Should)) - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestBuildFilter_TimeRangeCondition(t *testing.T) { - startTime := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - }, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestBuildFilter_TimeRangeAllBounds(t *testing.T) { - gt := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) - gte := time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC) - lt := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - lte := time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC) - - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "updated_at", - Value: TimeRange{ - 
Gt: &gt, - Gte: &gte, - Lt: &lt, - Lte: &lte, - }, - }, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestBuildFilter_EmptyTimeRange(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{}, // All nil - }, - }, - }, - } - result := buildFilter(filters) - - // Empty TimeRange returns nil condition, so filter should be nil - if result != nil { - t.Errorf("expected nil for empty time range, got %v", result) - } -} - -func TestBuildConditions_NilConditionSet(t *testing.T) { - result := buildConditions(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestBuildConditions_FiltersNilConditions(t *testing.T) { - // Test that buildConditions filters out nil conditions - cs := &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "city", Value: "London"}, - TimeRangeCondition{Key: "created_at", Value: TimeRange{}}, // Empty range returns nil - TextAnyCondition{Key: "status", Values: []string{}}, // Empty slice returns nil - BoolCondition{Key: "active", Value: true}, - }, - } - result := buildConditions(cs) - - // Should only have 2 conditions (TextCondition and BoolCondition) - // Empty TimeRange and empty TextAnyCondition should be filtered out - if len(result) != 2 { - t.Errorf("expected 2 conditions (nil ones filtered out), got %d", len(result)) - } -} - -func TestTextCondition_ToQdrantCondition(t *testing.T) { - c := TextCondition{Key: "city", Value: "London"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBoolCondition_ToQdrantCondition(t *testing.T) { - c := BoolCondition{Key: "active", Value: true} - result := c.ToQdrantCondition() - - if len(result) != 1 { -
t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntCondition_ToQdrantCondition(t *testing.T) { - c := IntCondition{Key: "priority", Value: 42} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_ToQdrantCondition(t *testing.T) { - now := time.Now() - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{Gte: &now}, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_EmptyRange(t *testing.T) { - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{}, // All nil - } - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty time range, got %v", result) - } -} - -func TestToTimestamp_Nil(t *testing.T) { - result := toTimestamp(nil) - if result != nil { - t.Errorf("expected nil, got %v", result) - } -} - -func TestToTimestamp_ValidTime(t *testing.T) { - now := time.Now() - result := toTimestamp(&now) - - if result == nil { - t.Fatal("expected timestamp, got nil") - } - if result.AsTime().Unix() != now.Unix() { - t.Errorf("timestamp mismatch: expected %v, got %v", now.Unix(), result.AsTime().Unix()) - } -} - -// === FieldType Tests === - -func TestResolveFieldKey_InternalField(t *testing.T) { - key := resolveFieldKey("search_store_id", InternalField) - expected := "search_store_id" - if key != expected { - t.Errorf("expected %q, got %q", expected, key) - } -} - -func TestResolveFieldKey_UserField(t *testing.T) { - key := resolveFieldKey("document_id", UserField) - expected := "custom.document_id" - if key != expected { - t.Errorf("expected %q, got %q", expected, key) - } -} - -func TestResolveFieldKey_UserField_PreventDoublePrefix(t *testing.T) { - // If key already has prefix, don't add again - key := resolveFieldKey("custom.document_id", UserField) - expected := "custom.document_id" - 
if key != expected { - t.Errorf("expected %q, got %q (double prefix detected)", expected, key) - } -} - -func TestTextCondition_InternalField(t *testing.T) { - c := TextCondition{Key: "search_store_id", Value: "store-123", FieldType: InternalField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } - // Internal field should NOT have prefix -} - -func TestTextCondition_UserField(t *testing.T) { - c := TextCondition{Key: "document_id", Value: "doc-456", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } - // User field should have "custom." prefix -} - -func TestBoolCondition_UserField(t *testing.T) { - c := BoolCondition{Key: "is_reviewed", Value: true, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntCondition_UserField(t *testing.T) { - c := IntCondition{Key: "version", Value: 2, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTimeRangeCondition_UserField(t *testing.T) { - now := time.Now() - c := TimeRangeCondition{ - Key: "uploaded_at", - Value: TimeRange{Gte: &now}, - FieldType: UserField, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_MixedInternalAndUserFields(t *testing.T) { - // search_store_id = "store-123" (internal) AND custom.category = "reports" (user) - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "search_store_id", Value: "store-123", FieldType: InternalField}, - TextCondition{Key: "category", Value: "reports", FieldType: UserField}, - BoolCondition{Key: "is_published", Value: true, FieldType: UserField}, - }, - }, - } - 
result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } -} - -// === BuildPayload Tests === - -func TestBuildPayload_OnlyInternal(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - "modalities": []string{"text"}, - } - payload := BuildPayload(internal, nil) - - if payload["search_store_id"] != "store-123" { - t.Errorf("expected search_store_id at top-level") - } - if _, exists := payload["custom"]; exists { - t.Errorf("custom should not exist when user is nil") - } -} - -func TestBuildPayload_OnlyUser(t *testing.T) { - user := map[string]any{ - "document_id": "doc-456", - "author": "John", - } - payload := BuildPayload(nil, user) - - custom, ok := payload["custom"].(map[string]any) - if !ok { - t.Fatal("expected custom field") - } - if custom["document_id"] != "doc-456" { - t.Errorf("expected document_id in custom") - } - if custom["author"] != "John" { - t.Errorf("expected author in custom") - } -} - -func TestBuildPayload_BothInternalAndUser(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - } - user := map[string]any{ - "document_id": "doc-456", - "category": "reports", - } - payload := BuildPayload(internal, user) - - // Check internal at top-level - if payload["search_store_id"] != "store-123" { - t.Errorf("expected search_store_id at top-level") - } - - // Check user under custom - custom, ok := payload["custom"].(map[string]any) - if !ok { - t.Fatal("expected custom field") - } - if custom["document_id"] != "doc-456" { - t.Errorf("expected document_id in custom") - } - if custom["category"] != "reports" { - t.Errorf("expected category in custom") - } -} - -func TestBuildPayload_EmptyUser(t *testing.T) { - internal := map[string]any{ - "search_store_id": "store-123", - } - user := map[string]any{} // Empty, not nil - payload := BuildPayload(internal, user) - - 
if _, exists := payload["custom"]; exists { - t.Errorf("custom should not exist when user is empty") - } -} - -func TestResolveFieldKey_ActualPath(t *testing.T) { - tests := []struct { - key string - fieldType FieldType - expected string - }{ - {"city", InternalField, "city"}, - {"city", UserField, "custom.city"}, - {"custom.city", UserField, "custom.city"}, // No double prefix - } - - for _, tt := range tests { - result := resolveFieldKey(tt.key, tt.fieldType) - if result != tt.expected { - t.Errorf("resolveFieldKey(%q, %v) = %q, want %q", - tt.key, tt.fieldType, result, tt.expected) - } - } -} - -// === MatchAnyCondition Tests === - -func TestTextAnyCondition_ToQdrantCondition(t *testing.T) { - c := TextAnyCondition{Key: "city", Values: []string{"London", "Berlin", "Paris"}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntAnyCondition_ToQdrantCondition(t *testing.T) { - c := IntAnyCondition{Key: "priority", Values: []int64{1, 2, 3}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTextAnyCondition_UserField(t *testing.T) { - c := TextAnyCondition{Key: "category", Values: []string{"tech", "science"}, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithTextAnyCondition(t *testing.T) { - // city IN ("London", "Berlin") - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{"London", "Berlin"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -func TestTextAnyCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := 
TextAnyCondition{Key: "city", Values: []string{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestIntAnyCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := IntAnyCondition{Key: "priority", Values: []int64{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestBuildFilter_WithEmptyTextAnyCondition(t *testing.T) { - // Empty TextAnyCondition should be filtered out - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{}}, - TextCondition{Key: "status", Value: "active"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - // Should only have the TextCondition, empty TextAnyCondition should be filtered out - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition (empty one filtered out), got %d", len(result.Must)) - } -} - -// === MatchExceptCondition Tests === - -func TestTextExceptCondition_ToQdrantCondition(t *testing.T) { - c := TextExceptCondition{Key: "city", Values: []string{"Paris", "Madrid"}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIntExceptCondition_ToQdrantCondition(t *testing.T) { - c := IntExceptCondition{Key: "priority", Values: []int64{0, -1}} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestTextExceptCondition_UserField(t *testing.T) { - c := TextExceptCondition{Key: "status", Values: []string{"draft", "deleted"}, FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithTextExceptCondition(t *testing.T) { - // city NOT IN 
("Paris", "Madrid") - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - TextExceptCondition{Key: "city", Values: []string{"Paris", "Madrid"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -func TestTextExceptCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := TextExceptCondition{Key: "city", Values: []string{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestIntExceptCondition_EmptySlice(t *testing.T) { - // Empty slice should return nil - c := IntExceptCondition{Key: "priority", Values: []int64{}} - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty slice, got %v", result) - } -} - -func TestBuildFilter_WithEmptyTextExceptCondition(t *testing.T) { - // Empty TextExceptCondition should be filtered out - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - TextExceptCondition{Key: "city", Values: []string{}}, - BoolCondition{Key: "deleted", Value: true}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - // Should only have the BoolCondition, empty TextExceptCondition should be filtered out - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition (empty one filtered out), got %d", len(result.MustNot)) - } -} - -// === NumericRangeCondition Tests === - -func TestNumericRangeCondition_ToQdrantCondition(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func 
TestNumericRangeCondition_AllBounds(t *testing.T) { - gt := 10.0 - gte := 20.0 - lt := 100.0 - lte := 90.0 - c := NumericRangeCondition{ - Key: "score", - Value: NumericRange{ - Gt: &gt, - Gte: &gte, - Lt: &lt, - Lte: &lte, - }, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestNumericRangeCondition_EmptyRange(t *testing.T) { - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{}, // All nil - } - result := c.ToQdrantCondition() - - if result != nil { - t.Errorf("expected nil for empty numeric range, got %v", result) - } -} - -func TestNumericRangeCondition_UserField(t *testing.T) { - minPrice := 50.0 - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{Gte: &minPrice}, - FieldType: UserField, - } - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithNumericRangeCondition(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - }, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -// === IsNullCondition Tests === - -func TestIsNullCondition_ToQdrantCondition(t *testing.T) { - c := IsNullCondition{Key: "deleted_at"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIsNullCondition_UserField(t *testing.T) { - c := IsNullCondition{Key: "review_date", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func
TestBuildFilter_WithIsNullCondition(t *testing.T) { - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - IsNullCondition{Key: "deleted_at"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 1 { - t.Errorf("expected 1 Must condition, got %d", len(result.Must)) - } -} - -// === IsEmptyCondition Tests === - -func TestIsEmptyCondition_ToQdrantCondition(t *testing.T) { - c := IsEmptyCondition{Key: "tags"} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestIsEmptyCondition_UserField(t *testing.T) { - c := IsEmptyCondition{Key: "categories", FieldType: UserField} - result := c.ToQdrantCondition() - - if len(result) != 1 { - t.Errorf("expected 1 condition, got %d", len(result)) - } -} - -func TestBuildFilter_WithIsEmptyCondition(t *testing.T) { - // Find documents where tags is NOT empty (using MustNot) - filters := &FilterSet{ - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - IsEmptyCondition{Key: "tags"}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.MustNot) != 1 { - t.Errorf("expected 1 MustNot condition, got %d", len(result.MustNot)) - } -} - -// === MarshalJSON Tests === - -func TestTimeRangeCondition_MarshalJSON(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC) - - c := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - } - - data, err := c.MarshalJSON() - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - jsonStr := string(data) - // Check that it contains expected fields - if !contains(jsonStr, `"field":"created_at"`) { - t.Errorf("expected field in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"atOrAfter"`) { - 
t.Errorf("expected atOrAfter in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"before"`) { - t.Errorf("expected before in JSON, got %s", jsonStr) - } -} - -func TestNumericRangeCondition_MarshalJSON(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - - c := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - - data, err := c.MarshalJSON() - if err != nil { - t.Fatalf("MarshalJSON failed: %v", err) - } - - jsonStr := string(data) - if !contains(jsonStr, `"field":"price"`) { - t.Errorf("expected field in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"greaterThanOrEqualTo"`) { - t.Errorf("expected greaterThanOrEqualTo in JSON, got %s", jsonStr) - } - if !contains(jsonStr, `"lessThanOrEqualTo"`) { - t.Errorf("expected lessThanOrEqualTo in JSON, got %s", jsonStr) - } -} - -func TestTimeRangeCondition_UnmarshalJSON(t *testing.T) { - jsonData := `{ - "field": "created_at", - "atOrAfter": "2024-01-01T00:00:00Z", - "before": "2024-12-31T00:00:00Z" - }` - - var c TimeRangeCondition - err := json.Unmarshal([]byte(jsonData), &c) - if err != nil { - t.Fatalf("UnmarshalJSON failed: %v", err) - } - - if c.Key != "created_at" { - t.Errorf("expected field 'created_at', got %q", c.Key) - } - if c.Value.Gte == nil { - t.Error("expected Gte to be set") - } - if c.Value.Lt == nil { - t.Error("expected Lt to be set") - } - if c.Value.Gte != nil { - expected := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - if !c.Value.Gte.Equal(expected) { - t.Errorf("expected Gte %v, got %v", expected, c.Value.Gte) - } - } -} - -func TestTimeRangeCondition_RoundTripJSON(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - endTime := time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC) - - original := TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{ - Gte: &startTime, - Lt: &endTime, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - - var 
unmarshaled TimeRangeCondition - err = json.Unmarshal(data, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if unmarshaled.Key != original.Key { - t.Errorf("field mismatch: expected %q, got %q", original.Key, unmarshaled.Key) - } - if unmarshaled.Value.Gte == nil || !unmarshaled.Value.Gte.Equal(*original.Value.Gte) { - t.Errorf("Gte mismatch: expected %v, got %v", original.Value.Gte, unmarshaled.Value.Gte) - } - if unmarshaled.Value.Lt == nil || !unmarshaled.Value.Lt.Equal(*original.Value.Lt) { - t.Errorf("Lt mismatch: expected %v, got %v", original.Value.Lt, unmarshaled.Value.Lt) - } -} - -func TestNumericRangeCondition_UnmarshalJSON(t *testing.T) { - jsonData := `{ - "field": "price", - "greaterThanOrEqualTo": 100.0, - "lessThanOrEqualTo": 500.0 - }` - - var c NumericRangeCondition - err := json.Unmarshal([]byte(jsonData), &c) - if err != nil { - t.Fatalf("UnmarshalJSON failed: %v", err) - } - - if c.Key != "price" { - t.Errorf("expected field 'price', got %q", c.Key) - } - if c.Value.Gte == nil || *c.Value.Gte != 100.0 { - t.Errorf("expected Gte to be 100.0, got %v", c.Value.Gte) - } - if c.Value.Lte == nil || *c.Value.Lte != 500.0 { - t.Errorf("expected Lte to be 500.0, got %v", c.Value.Lte) - } -} - -func TestNumericRangeCondition_RoundTripJSON(t *testing.T) { - minPrice := 100.0 - maxPrice := 500.0 - - original := NumericRangeCondition{ - Key: "price", - Value: NumericRange{ - Gte: &minPrice, - Lte: &maxPrice, - }, - } - - data, err := json.Marshal(original) - if err != nil { - t.Fatalf("Marshal failed: %v", err) - } - - var unmarshaled NumericRangeCondition - err = json.Unmarshal(data, &unmarshaled) - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - if unmarshaled.Key != original.Key { - t.Errorf("field mismatch: expected %q, got %q", original.Key, unmarshaled.Key) - } - if unmarshaled.Value.Gte == nil || *unmarshaled.Value.Gte != *original.Value.Gte { - t.Errorf("Gte mismatch: expected %v, got %v", 
original.Value.Gte, unmarshaled.Value.Gte) - } - if unmarshaled.Value.Lte == nil || *unmarshaled.Value.Lte != *original.Value.Lte { - t.Errorf("Lte mismatch: expected %v, got %v", original.Value.Lte, unmarshaled.Value.Lte) - } -} - -// === Complex Filter Tests === - -func TestBuildFilter_ComplexCombination(t *testing.T) { - startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - minPrice := 100.0 - - // Complex filter: - // (search_store_id = "store-123" AND created_at >= 2024-01-01 AND price >= 100) - // AND (city IN ("London", "Berlin")) - // AND NOT (deleted = true OR status IN ("draft", "archived")) - filters := &FilterSet{ - Must: &ConditionSet{ - Conditions: []FilterCondition{ - TextCondition{Key: "search_store_id", Value: "store-123"}, - TimeRangeCondition{ - Key: "created_at", - Value: TimeRange{Gte: &startTime}, - }, - NumericRangeCondition{ - Key: "price", - Value: NumericRange{Gte: &minPrice}, - FieldType: UserField, - }, - }, - }, - Should: &ConditionSet{ - Conditions: []FilterCondition{ - TextAnyCondition{Key: "city", Values: []string{"London", "Berlin"}}, - }, - }, - MustNot: &ConditionSet{ - Conditions: []FilterCondition{ - BoolCondition{Key: "deleted", Value: true}, - TextAnyCondition{Key: "status", Values: []string{"draft", "archived"}}, - }, - }, - } - result := buildFilter(filters) - - if result == nil { - t.Fatal("expected filter, got nil") - } - if len(result.Must) != 3 { - t.Errorf("expected 3 Must conditions, got %d", len(result.Must)) - } - if len(result.Should) != 1 { - t.Errorf("expected 1 Should condition, got %d", len(result.Should)) - } - if len(result.MustNot) != 2 { - t.Errorf("expected 2 MustNot conditions, got %d", len(result.MustNot)) - } -} - -// Helper function for string contains check -func contains(s, substr string) bool { - return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) -} - -func containsHelper(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if 
s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/v1/qdrant/operations.go b/v1/qdrant/operations.go index fbf3e2b..4c6b107 100644 --- a/v1/qdrant/operations.go +++ b/v1/qdrant/operations.go @@ -6,452 +6,197 @@ import ( "log" "slices" + "github.com/Aleph-Alpha/std/v1/vectordb" qdrant "github.com/qdrant/go-client/qdrant" "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" ) -// EnsureCollection ────────────────────────────────────────────────────────────── -// EnsureCollection -// ────────────────────────────────────────────────────────────── -// -// EnsureCollection verifies if a given collection exists, and creates it if missing. -// -// It’s safe to call this multiple times — if the collection already exists, -// the function exits early. This pattern simplifies startup logic for embedding -// services that may bootstrap their own Qdrant collections. -func (c *QdrantClient) EnsureCollection(ctx context.Context, name string) error { - if name == "" { - return fmt.Errorf("collection name cannot be empty") - } +// Ensure Adapter implements vectordb.VectorDBService at compile time +var _ vectordb.VectorDBService = (*Adapter)(nil) - collections, err := c.api.ListCollections(ctx) - if err != nil { - return fmt.Errorf("[Qdrant] failed to list collections: %w", err) - } - - if slices.Contains(collections, name) { - log.Printf("[Qdrant] Collection '%s' already exists", name) - return nil - } +// ══════════════════════════════════════════════════════════════════════════════ +// Adapter - implements vectordb.VectorDBService interface +// ══════════════════════════════════════════════════════════════════════════════ - log.Printf("[Qdrant] Collection '%s' not found, creating it...", name) - - req := &qdrant.CreateCollection{ - CollectionName: name, - VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ - Size: 1536, // default dimension (model-dependent) - Distance: qdrant.Distance_Cosine, // cosine similarity - }), - } - - if
err := c.api.CreateCollection(ctx, req); err != nil { - return fmt.Errorf("[Qdrant] failed to create collection '%s': %w", name, err) - } - - log.Printf("[Qdrant] Created collection '%s' successfully", name) - return nil -} - -// Insert ────────────────────────────────────────────────────────────── -// Insert -// ────────────────────────────────────────────────────────────── +// Adapter implements vectordb.VectorDBService for Qdrant. +// It wraps a Qdrant client and converts between generic vectordb types +// and Qdrant-specific protobuf types. // -// Insert adds a single embedding to Qdrant. -// -// Internally, it reuses the BatchInsert logic to ensure consistent handling -// of payload serialization and error management. -func (c *QdrantClient) Insert(ctx context.Context, collectionName string, input EmbeddingInput) error { - return c.BatchInsert(ctx, collectionName, []EmbeddingInput{input}) +// This is the recommended way to use Qdrant - it provides a database-agnostic +// interface that allows switching between different vector databases. +type Adapter struct { + client *qdrant.Client } -// BatchInsert ────────────────────────────────────────────────────────────── -// BatchInsert -// ────────────────────────────────────────────────────────────── -// -// BatchInsert efficiently inserts multiple embeddings in batches -// to reduce network overhead. +// NewAdapter creates a new Qdrant adapter for the vectordb interface. +// Pass the underlying SDK client via QdrantClient.API(). // -// This method is safe to call for large datasets — it will automatically -// split inserts into smaller chunks (`defaultBatchSize`) and perform -// multiple upserts sequentially. +// Example: // -// Logs batch indices and collection name for debugging.
-func (c *QdrantClient) BatchInsert(ctx context.Context, collectionName string, inputs []EmbeddingInput) error { - if len(inputs) == 0 { - return nil - } +// qc, _ := qdrant.NewQdrantClient(params) +// adapter := qdrant.NewAdapter(qc.API()) +// var db vectordb.VectorDBService = adapter +func NewAdapter(client *qdrant.Client) *Adapter { + return &Adapter{client: client} +} - if collectionName == "" { - return fmt.Errorf("collection name cannot be empty") +// Search performs similarity search across one or more requests. +func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("at least one search request is required") } - // Convert all inputs into internal embeddings - embeddings := make([]Embedding, len(inputs)) - for i, in := range inputs { - embeddings[i] = NewEmbedding(in) - } + log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) - for start := 0; start < len(embeddings); start += defaultBatchSize { - end := start + defaultBatchSize - if end > len(embeddings) { - end = len(embeddings) + // Validate + for i, req := range requests { + if req.CollectionName == "" { + return nil, fmt.Errorf("request [%d]: collection name cannot be empty", i) } - batch := embeddings[start:end] - - if err := c.upsertBatch(ctx, batch, collectionName); err != nil { - return fmt.Errorf("[Qdrant] batch upsert failed at [%d:%d]: %w", start, end, err) + if len(req.Vector) == 0 { + return nil, fmt.Errorf("request [%d]: vector cannot be empty", i) + } + if req.TopK <= 0 { + return nil, fmt.Errorf("request [%d]: topK must be greater than 0", i) } - log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collectionName) } - return nil -} + results := make([][]vectordb.SearchResult, len(requests)) + g, ctx := errgroup.WithContext(ctx) + sem := semaphore.NewWeighted(maxConcurrentSearches) -//
────────────────────────────────────────────────────────────── -// upsertBatch -// ────────────────────────────────────────────────────────────── -// -// upsertBatch sends a single `Upsert` request for a slice of embeddings. -// -// Converts Embedding structs into Qdrant’s `PointStruct` objects and -// triggers a blocking insert (`Wait=true`) to ensure data persistence -// before returning. -func (c *QdrantClient) upsertBatch(ctx context.Context, batch []Embedding, collectionName string) error { - points := make([]*qdrant.PointStruct, 0, len(batch)) - for _, e := range batch { - points = append(points, &qdrant.PointStruct{ - Id: qdrant.NewID(e.ID), - Vectors: qdrant.NewVectors(e.Vector...), - Payload: qdrant.NewValueMap(e.Meta), + for i, req := range requests { + i, req := i, req + g.Go(func() error { + if err := sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) + } + defer sem.Release(1) + + res, err := searchInternal(ctx, a.client, req) + if err != nil { + return fmt.Errorf("request [%d]: search failed: %w", i, err) + } + results[i] = res + log.Printf("[Qdrant] Search request [%d] returned %d results", i, len(res)) + return nil }) } - wait := true - req := &qdrant.UpsertPoints{ - CollectionName: collectionName, - Points: points, - Wait: &wait, + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("search batch failed: %w", err) } + return results, nil +} - if _, err := c.api.Upsert(ctx, req); err != nil { - return fmt.Errorf("[Qdrant] upsert failed: %w", err) +// Insert adds embeddings to a collection using batch processing. 
+func (a *Adapter) Insert(ctx context.Context, collection string, inputs []vectordb.EmbeddingInput) error { + if len(inputs) == 0 { + return nil } - return nil + if collection == "" { + return fmt.Errorf("collection name cannot be empty") + } + return insertInternal(ctx, a.client, collection, inputs) } -// GetCollection ────────────────────────────────────────────────────────────── -// GetCollection -// ────────────────────────────────────────────────────────────── -// -// GetCollection retrieves detailed metadata about a specific collection -// from the connected Qdrant instance. -// -// It returns a high-level, decoupled `Collection` struct containing -// core details such as: -// • Collection name -// • Status (e.g., "Green", "Yellow") -// • Total vectors and points -// • Vector size (embedding dimension) -// • Distance metric (e.g., "Cosine", "Dot", "Euclid") -// -// This abstraction intentionally hides Qdrant SDK internals (`qdrant.CollectionInfo`) -// so that the application layer remains independent of Qdrant’s client library. -// -// Example: -// -// collection, err := client.GetCollection(ctx, "my_collection") -// if err != nil { -// log.Printf("Failed to fetch collection info: %v", err) -// return -// } -// log.Printf("Collection '%s': vectors=%d, points=%d, vector_size=%d, distance=%s", -// collection.Name, -// collection.Vectors, -// collection.Points, -// collection.VectorSize, -// collection.Distance, -// ) - -func (c *QdrantClient) GetCollection(ctx context.Context, name string) (*Collection, error) { - if c.api == nil { - return nil, fmt.Errorf("[Qdrant] client not initialized") - } +// Delete removes points by their IDs from a collection. +func (a *Adapter) Delete(ctx context.Context, collection string, ids []string) error { + return deleteInternal(ctx, a.client, collection, ids) +} + +// EnsureCollection creates a collection if it doesn't exist. 
+func (a *Adapter) EnsureCollection(ctx context.Context, name string, vectorSize uint64) error { + return ensureCollectionInternal(ctx, a.client, name, vectorSize) +} +// GetCollection retrieves metadata about a collection. +func (a *Adapter) GetCollection(ctx context.Context, name string) (*vectordb.Collection, error) { if name == "" { return nil, fmt.Errorf("collection name cannot be empty") } - info, err := c.api.GetCollectionInfo(ctx, name) + info, err := a.client.GetCollectionInfo(ctx, name) if err != nil { - return nil, fmt.Errorf("[Qdrant] failed to get collection '%s': %w", name, err) + return nil, fmt.Errorf("failed to get collection '%s': %w", name, err) } size, distance := extractVectorDetails(info) - collection := &Collection{ - Name: name, - Status: info.Status.String(), - Vectors: derefUint64(info.IndexedVectorsCount), - Points: derefUint64(info.PointsCount), - VectorSize: size, - Distance: distance, - } - - return collection, nil + return &vectordb.Collection{ + Name: name, + Status: info.Status.String(), + VectorSize: size, + Distance: distance, + VectorCount: derefUint64(info.IndexedVectorsCount), + PointCount: derefUint64(info.PointsCount), + }, nil } -// ListCollections ────────────────────────────────────────────────────────────── -// ListCollections -// ────────────────────────────────────────────────────────────── -// -// ListCollections retrieves all existing collections from Qdrant and returns -// their names as a string slice. This can be extended to preload metadata -// using GetCollection for each name if needed. 
-// -// Example: -// -// names, err := client.ListCollections(ctx) -// if err != nil { -// log.Fatalf("failed to list collections: %v", err) -// } -// log.Printf("Found collections: %v", names) -func (c *QdrantClient) ListCollections(ctx context.Context) ([]string, error) { - if c.api == nil { - return nil, fmt.Errorf("[Qdrant] client not initialized") - } - - names, err := c.api.ListCollections(ctx) - if err != nil { - return nil, fmt.Errorf("[Qdrant] failed to list collections: %w", err) - } - - log.Printf("[Qdrant] Found %d collections", len(names)) - return names, nil +// ListCollections returns names of all collections. +func (a *Adapter) ListCollections(ctx context.Context) ([]string, error) { + return listCollectionsInternal(ctx, a.client) } -// Search ────────────────────────────────────────────────────────────── -// Search -// ────────────────────────────────────────────────────────────── -// -// Search performs a similarity search in the configured collection. -// -// Parameters: -// - collectionName — the collection to search in -// - vector — the query embedding to search against. -// - topK — maximum number of nearest results to return. -// -// Returns: -// -// A slice of `SearchResultInterface` instances representing the -// most similar stored embeddings. 
-// func (c *QdrantClient) Search(ctx context.Context, collectionName string, vector []float32, topK int) ([]SearchResultInterface, error) { -// if err := validateSearchInput(collectionName, vector, topK); err != nil { -// return nil, err -// } - -// limit := uint64(topK) -// req := &qdrant.QueryPoints{ -// CollectionName: collectionName, -// Query: qdrant.NewQuery(vector...), -// Limit: &limit, -// WithPayload: qdrant.NewWithPayload(true), -// } - -// resp, err := c.api.Query(ctx, req) -// if err != nil { -// return nil, fmt.Errorf("[Qdrant] search failed: %w", err) -// } - -// results, err := c.parseSearchResults(resp) -// if err != nil { -// return nil, err -// } - -// log.Printf("[Qdrant] Search returned %d results", len(results)) -// return results, nil -// } - -// SearchWithFilter ────────────────────────────────────────────────────────────── -// SearchWithFilter -// ────────────────────────────────────────────────────────────── -// -// SearchWithFilter performs a similarity search with required filters (AND logic). -// Returns error if filters is nil or empty - use Search() for unfiltered searches. -// -// Parameters: -// - collectionName — the collection to search in -// - vector — the query embedding to search against -// - topK — maximum number of nearest results to return -// - filters — required key-value filters (all must match) -// -// Returns: -// -// A slice of SearchResultInterface instances matching the filter criteria. 
-// func (c *QdrantClient) SearchWithFilter(ctx context.Context, collectionName string, vector []float32, topK int, filters map[string]string) ([]SearchResultInterface, error) { -// if err := validateSearchInput(collectionName, vector, topK); err != nil { -// return nil, err -// } - -// if len(filters) == 0 { -// return nil, fmt.Errorf("filters cannot be empty, use Search() for unfiltered searches") -// } - -// limit := uint64(topK) -// req := &qdrant.QueryPoints{ -// CollectionName: collectionName, -// Query: qdrant.NewQuery(vector...), -// Limit: &limit, -// WithPayload: qdrant.NewWithPayload(true), -// Filter: buildFilter(filters), -// } - -// resp, err := c.api.Query(ctx, req) -// if err != nil { -// return nil, fmt.Errorf("[Qdrant] search failed: %w", err) -// } - -// results, err := c.parseSearchResults(resp) -// if err != nil { -// return nil, err -// } - -// log.Printf("[Qdrant] SearchWithFilter returned %d results", len(results)) -// return results, nil -// } - -// executeSearch performs a single search request against Qdrant -func (c *QdrantClient) executeSearch(ctx context.Context, searchReq SearchRequest) ([]SearchResultInterface, error) { - limit := uint64(searchReq.TopK) - req := &qdrant.QueryPoints{ - CollectionName: searchReq.CollectionName, - Query: qdrant.NewQuery(searchReq.Vector...), +// ══════════════════════════════════════════════════════════════════════════════ +// Internal Functions +// ══════════════════════════════════════════════════════════════════════════════ + +func searchInternal(ctx context.Context, client *qdrant.Client, req vectordb.SearchRequest) ([]vectordb.SearchResult, error) { + limit := uint64(req.TopK) + queryReq := &qdrant.QueryPoints{ + CollectionName: req.CollectionName, + Query: qdrant.NewQuery(req.Vector...), Limit: &limit, WithPayload: qdrant.NewWithPayload(true), - Filter: buildFilter(searchReq.Filters), + Filter: convertVectorDBFilterSet(req.Filters), } - resp, err := c.api.Query(ctx, req) + resp, err := 
client.Query(ctx, queryReq) if err != nil { - return nil, fmt.Errorf("search failed: %w", err) + return nil, fmt.Errorf("query failed: %w", err) } - - return c.parseSearchResults(resp) + return parseVectorDBSearchResults(resp, req.CollectionName) } -// Search ────────────────────────────────────────────────────────────── -// Search -// ────────────────────────────────────────────────────────────── -// -// Search performs multiple searches and returns results for each request. -// Each request can optionally include filters. -// -// Parameters: -// - requests — variadic SearchRequest structs, each containing: -// - CollectionName: the collection to search in -// - Vector: the query embedding -// - TopK: maximum number of results per request -// - Filters: optional key-value filters (AND logic) -// -// Returns: -// -// A slice of result slices — one []SearchResultInterface per request. -// -// Example: -// -// results, err := client.Search(ctx, -// SearchRequest{CollectionName: "docs", Vector: vec1, TopK: 10, Filters: map[string]string{"partition_id": "store-A"}}, -// SearchRequest{CollectionName: "docs", Vector: vec2, TopK: 5}, -// ) -// // results[0] = results for first request -// // results[1] = results for second request -func (c *QdrantClient) Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResultInterface, error) { - if len(requests) == 0 { - return nil, fmt.Errorf("at least one search request is required") - } - - log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) - - // Validate all requests first (fail fast) - for i, searchReq := range requests { - if err := validateSearchInput(searchReq.CollectionName, searchReq.Vector, searchReq.TopK); err != nil { - return nil, fmt.Errorf("request [%d]: %w", i, err) +func insertInternal(ctx context.Context, client *qdrant.Client, collection string, inputs []vectordb.EmbeddingInput) error { + for start := 0; start < len(inputs); start += defaultBatchSize { + end := start + 
defaultBatchSize + if end > len(inputs) { + end = len(inputs) + } + batch := inputs[start:end] + + points := make([]*qdrant.PointStruct, 0, len(batch)) + for _, e := range batch { + points = append(points, &qdrant.PointStruct{ + Id: qdrant.NewID(e.ID), + Vectors: qdrant.NewVectors(e.Vector...), + Payload: qdrant.NewValueMap(e.Payload), + }) } - } - - results := make([][]SearchResultInterface, len(requests)) - - // Create errgroup with context - g, ctx := errgroup.WithContext(ctx) - - // Semaphore to limit concurrency - sem := semaphore.NewWeighted(maxConcurrentSearches) - - for i, searchReq := range requests { - i, searchReq := i, searchReq // Capture loop variables - - g.Go(func() error { - // Acquire semaphore (blocks if at max concurrency) - if err := sem.Acquire(ctx, 1); err != nil { - return fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) - } - defer sem.Release(1) - - res, err := c.executeSearch(ctx, searchReq) - if err != nil { - return fmt.Errorf("request [%d]: search failed: %w", i, err) - } - - results[i] = res - log.Printf("[Qdrant] Search request [%d] returned %d results", i, len(res)) - return nil - }) - } - - if err := g.Wait(); err != nil { - return nil, fmt.Errorf("search batch failed: %w", err) - } - - return results, nil -} -// parseSearchResults converts Qdrant response to SearchResultInterface slice -func (c *QdrantClient) parseSearchResults(resp []*qdrant.ScoredPoint) ([]SearchResultInterface, error) { - results := make([]SearchResultInterface, 0, len(resp)) - for _, r := range resp { - var id string - switch v := r.Id.PointIdOptions.(type) { - case *qdrant.PointId_Num: - id = fmt.Sprintf("%d", v.Num) - case *qdrant.PointId_Uuid: - id = v.Uuid - default: - return nil, fmt.Errorf("[Qdrant] unexpected PointId type: %T", v) + wait := true + req := &qdrant.UpsertPoints{ + CollectionName: collection, + Points: points, + Wait: &wait, } - results = append(results, SearchResult{ - ID: id, - Score: r.Score, - Meta: r.Payload, - }) + 
if _, err := client.Upsert(ctx, req); err != nil { + return fmt.Errorf("[Qdrant] batch upsert failed at [%d:%d]: %w", start, end, err) + } + log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collection) } - return results, nil + return nil } -// Delete ────────────────────────────────────────────────────────────── -// Delete -// ────────────────────────────────────────────────────────────── -// -// Delete removes embeddings from a collection by their IDs. -// -// It constructs a `DeletePoints` request containing a list of `PointId`s, -// waits synchronously for completion, and logs the operation status. -func (c *QdrantClient) Delete(ctx context.Context, collectionName string, ids []string) error { +func deleteInternal(ctx context.Context, client *qdrant.Client, collection string, ids []string) error { if len(ids) == 0 { return nil } - - if collectionName == "" { + if collection == "" { return fmt.Errorf("collection name cannot be empty") } @@ -462,7 +207,7 @@ func (c *QdrantClient) Delete(ctx context.Context, collectionName string, ids [] wait := true req := &qdrant.DeletePoints{ - CollectionName: collectionName, + CollectionName: collection, Points: &qdrant.PointsSelector{ PointsSelectorOneOf: &qdrant.PointsSelector_Points{ Points: &qdrant.PointsIdsList{Ids: qdrantIDs}, @@ -471,12 +216,53 @@ func (c *QdrantClient) Delete(ctx context.Context, collectionName string, ids [] Wait: &wait, } - resp, err := c.api.Delete(ctx, req) + resp, err := client.Delete(ctx, req) if err != nil { return fmt.Errorf("[Qdrant] delete failed: %w", err) } - log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", - resp.Status.String(), collectionName) + log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", resp.Status.String(), collection) + return nil +} + +func ensureCollectionInternal(ctx context.Context, client *qdrant.Client, name string, vectorSize uint64) error { + if name == "" { + return fmt.Errorf("collection name cannot be 
empty") + } + + collections, err := client.ListCollections(ctx) + if err != nil { + return fmt.Errorf("[Qdrant] failed to list collections: %w", err) + } + + if slices.Contains(collections, name) { + log.Printf("[Qdrant] Collection '%s' already exists", name) + return nil + } + + log.Printf("[Qdrant] Collection '%s' not found, creating it...", name) + + req := &qdrant.CreateCollection{ + CollectionName: name, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + } + + if err := client.CreateCollection(ctx, req); err != nil { + return fmt.Errorf("[Qdrant] failed to create collection '%s': %w", name, err) + } + + log.Printf("[Qdrant] Created collection '%s' successfully", name) return nil } + +func listCollectionsInternal(ctx context.Context, client *qdrant.Client) ([]string, error) { + names, err := client.ListCollections(ctx) + if err != nil { + return nil, fmt.Errorf("[Qdrant] failed to list collections: %w", err) + } + log.Printf("[Qdrant] Found %d collections", len(names)) + return names, nil +} diff --git a/v1/qdrant/qdrant_integration_test.go b/v1/qdrant/qdrant_integration_test.go index 100e707..4e5fed3 100644 --- a/v1/qdrant/qdrant_integration_test.go +++ b/v1/qdrant/qdrant_integration_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/Aleph-Alpha/std/v1/vectordb" "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" "github.com/stretchr/testify/assert" @@ -189,43 +190,46 @@ func TestQdrantWithFXModule(t *testing.T) { err = qdrantClient.healthCheck() assert.NoError(t, err) + // Create adapter for operations + adapter := NewAdapter(qdrantClient.api) + // Test collection operations t.Run("EnsureCollection", func(t *testing.T) { // First call should create the collection - err := qdrantClient.EnsureCollection(ctx, "test_collection_1") + err := adapter.EnsureCollection(ctx, "test_collection_1", 1536) assert.NoError(t, err) // Second call should be 
idempotent - err = qdrantClient.EnsureCollection(ctx, "test_collection_1") + err = adapter.EnsureCollection(ctx, "test_collection_1", 1536) assert.NoError(t, err) // Empty collection name should fail - err = qdrantClient.EnsureCollection(ctx, "") + err = adapter.EnsureCollection(ctx, "", 1536) assert.Error(t, err) }) // Test basic CRUD operations t.Run("BasicCRUDOperations", func(t *testing.T) { collectionName := "test_crud" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) - // Insert single embedding (use numeric ID or UUID format) - embedding := EmbeddingInput{ - ID: "00000000-0000-0000-0000-000000000001", // UUID format + // Insert single embedding (use UUID format) + embedding := vectordb.EmbeddingInput{ + ID: "00000000-0000-0000-0000-000000000001", Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": "Test Document 1", "content": "This is a test document", }, } - err = qdrantClient.Insert(ctx, collectionName, embedding) + err = adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{embedding}) assert.NoError(t, err) // Search for the inserted embedding time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := qdrantClient.Search(ctx, SearchRequest{ + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 5, @@ -237,28 +241,28 @@ func TestQdrantWithFXModule(t *testing.T) { // Verify the result if len(results) > 0 { - assert.Equal(t, embedding.ID, results[0].GetID()) - assert.Greater(t, results[0].GetScore(), float32(0.9)) // Should be very similar + assert.Equal(t, embedding.ID, results[0].ID) + assert.Greater(t, results[0].Score, float32(0.9)) // Should be very similar } // Delete the embedding - err = qdrantClient.Delete(ctx, collectionName, []string{embedding.ID}) + err = adapter.Delete(ctx, collectionName, 
[]string{embedding.ID}) assert.NoError(t, err) }) // Test batch insert t.Run("BatchInsert", func(t *testing.T) { collectionName := "test_batch" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Create multiple embeddings (use UUID format) - embeddings := make([]EmbeddingInput, 10) + embeddings := make([]vectordb.EmbeddingInput, 10) for i := 0; i < 10; i++ { - embeddings[i] = EmbeddingInput{ + embeddings[i] = vectordb.EmbeddingInput{ ID: fmt.Sprintf("00000000-0000-0000-0000-%012d", i+1), Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": fmt.Sprintf("Document %d", i), "index": i, }, @@ -266,12 +270,12 @@ func TestQdrantWithFXModule(t *testing.T) { } // Batch insert - err = qdrantClient.BatchInsert(ctx, collectionName, embeddings) + err = adapter.Insert(ctx, collectionName, embeddings) assert.NoError(t, err) // Search and verify time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := qdrantClient.Search(ctx, SearchRequest{ + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -284,22 +288,22 @@ func TestQdrantWithFXModule(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = qdrantClient.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) // Test empty operations t.Run("EmptyOperations", func(t *testing.T) { collectionName := "test_empty" - err := qdrantClient.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Empty batch insert should be no-op - err = qdrantClient.BatchInsert(ctx, collectionName, []EmbeddingInput{}) + err = adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{}) assert.NoError(t, err) // Empty delete should be no-op - err = 
qdrantClient.Delete(ctx, collectionName, []string{}) + err = adapter.Delete(ctx, collectionName, []string{}) assert.NoError(t, err) }) @@ -307,8 +311,8 @@ func TestQdrantWithFXModule(t *testing.T) { require.NoError(t, app.Stop(ctx)) } -// TestQdrantClientOperations tests various client operations -func TestQdrantClientOperations(t *testing.T) { +// TestVectorDBAdapterOperations tests various adapter operations +func TestVectorDBAdapterOperations(t *testing.T) { // Skip if running in short mode if testing.Short() { t.Skip("Skipping integration test in short mode") @@ -339,20 +343,23 @@ func TestQdrantClientOperations(t *testing.T) { require.NotNil(t, client) defer client.Close() + // Create adapter + adapter := NewAdapter(client.api) + collectionName := "test_operations" // Ensure collection exists - err = client.EnsureCollection(ctx, collectionName) + err = adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) t.Run("GetCollectionByName", func(t *testing.T) { // Fetch collection info using GetCollection - col, err := client.GetCollection(ctx, collectionName) + col, err := adapter.GetCollection(ctx, collectionName) assert.NoError(t, err, "expected GetCollection to succeed") assert.NotNil(t, col, "expected non-nil collection info") // Validate expected metadata fields - assert.GreaterOrEqual(t, int(col.Vectors), 0, "vector count should be >= 0") - assert.GreaterOrEqual(t, int(col.Points), 0, "points count should be >= 0") + assert.GreaterOrEqual(t, int(col.VectorCount), 0, "vector count should be >= 0") + assert.GreaterOrEqual(t, int(col.PointCount), 0, "points count should be >= 0") // Validate vector config details (size and distance) assert.NotZero(t, col.VectorSize, "vector size should not be zero") @@ -362,8 +369,8 @@ func TestQdrantClientOperations(t *testing.T) { t.Logf("Collection '%s': status=%s, vectors=%d, points=%d, vectorSize=%d, distance=%s", col.Name, col.Status, - col.Vectors, - col.Points, + col.VectorCount, + col.PointCount, 
col.VectorSize, col.Distance, ) @@ -371,22 +378,22 @@ func TestQdrantClientOperations(t *testing.T) { t.Run("SearchReturnsTopK", func(t *testing.T) { // Insert multiple embeddings (use UUID format) - embeddings := make([]EmbeddingInput, 20) + embeddings := make([]vectordb.EmbeddingInput, 20) for i := 0; i < 20; i++ { - embeddings[i] = EmbeddingInput{ - ID: fmt.Sprintf("00000000-0000-0000-0001-%012d", i+1), - Vector: generateRandomVector(1536), - Meta: map[string]any{"index": i}, + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0001-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": i}, } } - err := client.BatchInsert(ctx, collectionName, embeddings) + err := adapter.Insert(ctx, collectionName, embeddings) require.NoError(t, err) time.Sleep(1 * time.Second) // Allow time for indexing // Search with topK = 5 - batchResults, err := client.Search(ctx, SearchRequest{ + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 5, @@ -395,7 +402,7 @@ func TestQdrantClientOperations(t *testing.T) { assert.LessOrEqual(t, len(batchResults[0]), 5) // Search with topK = 10 - batchResults, err = client.Search(ctx, SearchRequest{ + batchResults, err = adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -408,30 +415,30 @@ func TestQdrantClientOperations(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = client.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) t.Run("SearchWithMetadata", func(t *testing.T) { // Insert embedding with rich metadata (UUID format, simple types only) - embedding := EmbeddingInput{ + embedding := vectordb.EmbeddingInput{ ID: "00000000-0000-0000-0002-000000000001", Vector: generateRandomVector(1536), - Meta: map[string]any{ + Payload: map[string]any{ "title": "Test 
Title", "author": "Test Author", "timestamp": time.Now().Unix(), - "category": "test", // Use simple types instead of arrays + "category": "test", }, } - err := client.Insert(ctx, collectionName, embedding) + err := adapter.Insert(ctx, collectionName, []vectordb.EmbeddingInput{embedding}) require.NoError(t, err) time.Sleep(1 * time.Second) // Search and verify metadata - batchResults, err := client.Search(ctx, SearchRequest{ + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 1, @@ -440,39 +447,39 @@ func TestQdrantClientOperations(t *testing.T) { assert.Greater(t, len(batchResults[0]), 0) if len(batchResults[0]) > 0 { - meta := batchResults[0][0].GetMeta() - assert.NotNil(t, meta) + payload := batchResults[0][0].Payload + assert.NotNil(t, payload) } // Clean up - err = client.Delete(ctx, collectionName, []string{embedding.ID}) + err = adapter.Delete(ctx, collectionName, []string{embedding.ID}) assert.NoError(t, err) }) t.Run("LargeBatchInsert", func(t *testing.T) { collectionName := "test_large_batch" - err := client.EnsureCollection(ctx, collectionName) + err := adapter.EnsureCollection(ctx, collectionName, 1536) require.NoError(t, err) // Create a large batch (more than defaultBatchSize, use UUID format) largeCount := 500 - embeddings := make([]EmbeddingInput, largeCount) + embeddings := make([]vectordb.EmbeddingInput, largeCount) for i := 0; i < largeCount; i++ { - embeddings[i] = EmbeddingInput{ - ID: fmt.Sprintf("00000000-0000-0000-0003-%012d", i+1), - Vector: generateRandomVector(1536), - Meta: map[string]any{"index": i}, + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0003-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": i}, } } // Should handle batching automatically - err = client.BatchInsert(ctx, collectionName, embeddings) + err = adapter.Insert(ctx, collectionName, embeddings) assert.NoError(t, err) 
time.Sleep(2 * time.Second) // Verify some embeddings exist - batchResults, err := client.Search(ctx, SearchRequest{ + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -485,7 +492,7 @@ func TestQdrantClientOperations(t *testing.T) { for i, emb := range embeddings { ids[i] = emb.ID } - err = client.Delete(ctx, collectionName, ids) + err = adapter.Delete(ctx, collectionName, ids) assert.NoError(t, err) }) } @@ -522,6 +529,8 @@ func TestQdrantErrorHandling(t *testing.T) { require.NotNil(t, client) defer client.Close() + adapter := NewAdapter(client.api) + t.Run("InvalidEndpoint", func(t *testing.T) { invalidCfg := &Config{ Endpoint: "invalid-host:9999", @@ -534,14 +543,14 @@ func TestQdrantErrorHandling(t *testing.T) { }) t.Run("EmptyCollectionName", func(t *testing.T) { - err := client.EnsureCollection(ctx, "") + err := adapter.EnsureCollection(ctx, "", 1536) assert.Error(t, err) assert.Contains(t, err.Error(), "collection name cannot be empty") }) t.Run("SearchOnNonExistentCollection", func(t *testing.T) { vector := generateRandomVector(1536) - _, err := client.Search(ctx, SearchRequest{ + _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: "non_existent_collection", Vector: vector, TopK: 5, @@ -586,9 +595,10 @@ func TestQdrantLifecycleAndHealthCheck(t *testing.T) { err = client.healthCheck() require.NoError(t, err, "Qdrant health check failed") - // Ensure collection exists + // Create adapter and ensure collection exists + adapter := NewAdapter(client.api) collectionName := "test_collection" - err = client.EnsureCollection(context.Background(), collectionName) + err = adapter.EnsureCollection(context.Background(), collectionName, 1536) require.NoError(t, err, "failed to ensure collection") // Close client diff --git a/v1/qdrant/utils.go b/v1/qdrant/utils.go index 9909631..6815a6e 100644 --- a/v1/qdrant/utils.go +++ b/v1/qdrant/utils.go @@ -3,6 +3,7 @@ 
package qdrant import ( "fmt" + "github.com/Aleph-Alpha/std/v1/vectordb" qdrant "github.com/qdrant/go-client/qdrant" ) @@ -100,7 +101,7 @@ type SearchRequest struct { CollectionName string Vector []float32 TopK int - Filters *FilterSet // Optional: key-value filters + Filters *vectordb.FilterSet // Optional: key-value filters } // validateSearchInput validates common search parameters diff --git a/v1/vectordb/doc.go b/v1/vectordb/doc.go new file mode 100644 index 0000000..fb91483 --- /dev/null +++ b/v1/vectordb/doc.go @@ -0,0 +1,138 @@ +// Package vectordb provides a database-agnostic abstraction for vector similarity search. +// +// # Overview +// +// This package defines a common interface [VectorDBService] that can be implemented +// by different vector database adapters (Qdrant, Weaviate, Pinecone, etc.), allowing +// applications to switch between databases without changing application code. +// +// # Architecture +// +// ┌─────────────────────────────────────────────────────────────┐ +// │ Application Layer │ +// │ (uses vectordb.VectorDBService - no DB-specific imports) │ +// └──────────────────────────┬──────────────────────────────────┘ +// │ +// ▼ +// ┌─────────────────────────────────────────────────────────────┐ +// │ vectordb.VectorDBService │ +// │ (common interface + DB-agnostic types) │ +// └──────────────────────────┬──────────────────────────────────┘ +// │ +// ┌──────────────────┼──────────────────┐ +// ▼ ▼ ▼ +// ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +// │ qdrant.Adapter│ │weaviate.Adapter│ │pinecone.Adapter│ +// │ (implements) │ │ (implements) │ │ (implements) │ +// └───────────────┘ └───────────────┘ └───────────────┘ +// +// # Benefits +// +// - Single Source of Truth: Filter types, search interfaces, and result types defined once. +// - Easy to Add New DBs: Just add a new adapter; consuming projects don't change. +// - Consistent API: All projects using std get the same interface. 
+// - Testability: Mock the interface once, works for all DBs. +// +// # Usage +// +// In your application, depend only on the vectordb interface: +// +// import "github.com/Aleph-Alpha/std/v1/vectordb" +// +// type SearchService struct { +// db vectordb.VectorDBService +// } +// +// func NewSearchService(db vectordb.VectorDBService) *SearchService { +// return &SearchService{db: db} +// } +// +// func (s *SearchService) Search(ctx context.Context, query string, vector []float32) ([]vectordb.SearchResult, error) { +// results, err := s.db.Search(ctx, vectordb.SearchRequest{ +// CollectionName: "documents", +// Vector: vector, +// TopK: 10, +// Filters: &vectordb.FilterSet{ +// Must: &vectordb.ConditionSet{ +// Conditions: []vectordb.FilterCondition{ +// vectordb.NewMatch("status", "published"), +// }, +// }, +// }, +// }) +// if err != nil { +// return nil, err +// } +// return results[0], nil +// } +// +// # Wire Up with Qdrant +// +// In your main or DI setup: +// +// import ( +// "github.com/Aleph-Alpha/std/v1/vectordb" +// "github.com/Aleph-Alpha/std/v1/qdrant" +// ) +// +// func main() { +// // Create Qdrant client (with health checks, config, etc.) +// qc, _ := qdrant.NewQdrantClient(qdrant.QdrantParams{ +// Config: &qdrant.Config{Endpoint: "localhost", Port: 6334}, +// }) +// +// // Create adapter for DB-agnostic usage +// db := qdrant.NewAdapter(qc.API()) +// +// // Use in application +// svc := NewSearchService(db) +// // ... 
+// } +// +// # Package Layout +// +// vectordb/ +// ├── interface.go # VectorDBService interface +// ├── types.go # SearchRequest, SearchResult, EmbeddingInput, Collection +// ├── filters.go # FilterSet, FilterCondition, convenience constructors +// └── doc.go # This file +// +// qdrant/ # Qdrant package (includes adapter) +// ├── client.go # QdrantClient wrapper +// ├── adapter.go # Adapter - implements VectorDBService +// ├── converter.go # vectordb types → qdrant types +// ├── operations.go # Direct Qdrant operations +// ├── utils.go # SearchRequest and input validation helpers +// └── ... +// +// Future adapters would live in their own packages: +// +// weaviate/ # (future) Weaviate adapter +// pinecone/ # (future) Pinecone adapter +// +// # Filter Types +// +// The package provides DB-agnostic filter conditions: +// +// | Type | Description | SQL Equivalent | +// |-----------------------|------------------------------|-----------------------------------| +// | MatchCondition | Exact value match | WHERE field = value | +// | MatchAnyCondition | Value in set | WHERE field IN (...) | +// | MatchExceptCondition | Value not in set | WHERE field NOT IN (...) | +// | NumericRangeCondition | Numeric range | WHERE field >= min AND field <= max | +// | TimeRangeCondition | Datetime range | WHERE created_at BETWEEN ... | +// | IsNullCondition | Field is null | WHERE field IS NULL | +// | IsEmptyCondition | Field is empty/null/missing | WHERE field IS NULL OR field = '' | +// +// Use convenience constructors for cleaner code: +// +// // Internal field (top-level in payload) +// vectordb.NewMatch("status", "published") +// +// // User-defined field (stored under "custom." 
prefix) +// vectordb.NewUserMatch("category", "research") +// +// // Range conditions +// vectordb.NewNumericRange("price", &min, &max) +// vectordb.NewTimeRange("created_at", &startTime, &endTime) +package vectordb diff --git a/v1/vectordb/filters.go b/v1/vectordb/filters.go new file mode 100644 index 0000000..0e156ad --- /dev/null +++ b/v1/vectordb/filters.go @@ -0,0 +1,240 @@ +package vectordb + +import ( + "encoding/json" + "time" +) + +// FieldType indicates whether a field is internal (system-managed) +// or user-defined (stored under a prefix like "custom."). +type FieldType int + +const ( + // InternalField - system-managed fields stored at top-level + InternalField FieldType = iota + // UserField - user-defined fields stored under a prefix (e.g., "custom.") + UserField +) + +// FilterCondition is the interface all filter conditions must implement. +// Each database adapter converts these to its native filter format. +type FilterCondition interface { + // isFilterCondition is a marker method to ensure type safety + isFilterCondition() +} + +// FilterSet supports Must (AND), Should (OR), and MustNot (NOT) clauses. +// Use with SearchRequest.Filters to filter search results. +// +// Example: +// +// filters := &FilterSet{ +// Must: &ConditionSet{ +// Conditions: []FilterCondition{ +// &MatchCondition{Field: "city", Value: "London"}, +// }, +// }, +// } +type FilterSet struct { + // Must: All conditions must match (AND) + Must *ConditionSet `json:"must,omitempty"` + // Should: At least one condition must match (OR) + Should *ConditionSet `json:"should,omitempty"` + // MustNot: None of the conditions should match (NOT) + MustNot *ConditionSet `json:"mustNot,omitempty"` +} + +// ConditionSet holds a group of conditions for a single clause. 
+type ConditionSet struct { + Conditions []FilterCondition `json:"conditions,omitempty"` +} + +// ── Match Conditions ───────────────────────────────────────────────────────── + +// MatchCondition represents an exact match filter (WHERE field = value). +// Supports string, bool, and int64 values. +type MatchCondition struct { + Field string `json:"field"` + Value any `json:"equalTo"` + FieldType FieldType `json:"-"` +} + +func (c *MatchCondition) isFilterCondition() {} + +// MatchAnyCondition matches if value is one of the given values (IN operator). +// SQL equivalent: WHERE field IN (value1, value2, ...) +type MatchAnyCondition struct { + Field string `json:"field"` + Values []any `json:"anyOf"` + FieldType FieldType `json:"-"` +} + +func (c *MatchAnyCondition) isFilterCondition() {} + +// MatchExceptCondition matches if value is NOT one of the given values (NOT IN). +// SQL equivalent: WHERE field NOT IN (value1, value2, ...) +type MatchExceptCondition struct { + Field string `json:"field"` + Values []any `json:"noneOf"` + FieldType FieldType `json:"-"` +} + +func (c *MatchExceptCondition) isFilterCondition() {} + +// ── Range Conditions ───────────────────────────────────────────────────────── + +// NumericRangeCondition filters by numeric range. +// SQL equivalent: WHERE field >= min AND field <= max +type NumericRangeCondition struct { + Field string `json:"field"` + GreaterThan *float64 `json:"greaterThan,omitempty"` + GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` + LessThan *float64 `json:"lessThan,omitempty"` + LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` + FieldType FieldType `json:"-"` +} + +func (c *NumericRangeCondition) isFilterCondition() {} + +// TimeRangeCondition filters by datetime range. 
+// SQL equivalent: WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' +type TimeRangeCondition struct { + Field string `json:"field"` + After *time.Time `json:"after,omitempty"` + AtOrAfter *time.Time `json:"atOrAfter,omitempty"` + Before *time.Time `json:"before,omitempty"` + AtOrBefore *time.Time `json:"atOrBefore,omitempty"` + FieldType FieldType `json:"-"` +} + +func (c *TimeRangeCondition) isFilterCondition() {} + +// ── Null/Empty Conditions ──────────────────────────────────────────────────── + +// IsNullCondition checks if a field has a NULL value. +// SQL equivalent: WHERE field IS NULL +type IsNullCondition struct { + Field string `json:"field"` + FieldType FieldType `json:"-"` +} + +func (c *IsNullCondition) isFilterCondition() {} + +// IsEmptyCondition checks if a field is empty (doesn't exist, null, or []). +// SQL equivalent: WHERE field IS NULL OR field = "" OR field = [] +type IsEmptyCondition struct { + Field string `json:"field"` + FieldType FieldType `json:"-"` +} + +func (c *IsEmptyCondition) isFilterCondition() {} + +// ── Convenience Constructors ───────────────────────────────────────────────── + +// NewMatch creates a match condition for internal fields. +func NewMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: InternalField} +} + +// NewUserMatch creates a match condition for user-defined fields. +func NewUserMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: UserField} +} + +// NewMatchAny creates an IN condition for internal fields. +func NewMatchAny(field string, values ...any) *MatchAnyCondition { + return &MatchAnyCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchAny creates an IN condition for user-defined fields. 
+func NewUserMatchAny(field string, values ...any) *MatchAnyCondition { + return &MatchAnyCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewMatchExcept creates a NOT IN condition for internal fields. +func NewMatchExcept(field string, values ...any) *MatchExceptCondition { + return &MatchExceptCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchExcept creates a NOT IN condition for user-defined fields. +func NewUserMatchExcept(field string, values ...any) *MatchExceptCondition { + return &MatchExceptCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewNumericRange creates a numeric range condition for internal fields. +func NewNumericRange(field string, gte, lte *float64) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + GreaterThanOrEqualTo: gte, + LessThanOrEqualTo: lte, + FieldType: InternalField, + } +} + +// NewUserNumericRange creates a numeric range condition for user-defined fields. +func NewUserNumericRange(field string, gte, lte *float64) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + GreaterThanOrEqualTo: gte, + LessThanOrEqualTo: lte, + FieldType: UserField, + } +} + +// NewTimeRange creates a time range condition for internal fields. +func NewTimeRange(field string, atOrAfter, before *time.Time) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + AtOrAfter: atOrAfter, + Before: before, + FieldType: InternalField, + } +} + +// NewUserTimeRange creates a time range condition for user-defined fields. +func NewUserTimeRange(field string, atOrAfter, before *time.Time) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + AtOrAfter: atOrAfter, + Before: before, + FieldType: UserField, + } +} + +// NewIsNull creates an IS NULL condition for internal fields. 
+func NewIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsNull creates an IS NULL condition for user-defined fields. +func NewUserIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: UserField} +} + +// NewIsEmpty creates an IS EMPTY condition for internal fields. +func NewIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsEmpty creates an IS EMPTY condition for user-defined fields. +func NewUserIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: UserField} +} + +// ── JSON Serialization ─────────────────────────────────────────────────────── + +// MarshalJSON implements custom JSON marshaling for ConditionSet. +// This is needed because FilterCondition is an interface. +func (cs *ConditionSet) MarshalJSON() ([]byte, error) { + return json.Marshal(cs.Conditions) +} + +// UnmarshalJSON parses data into raw JSON elements but does not yet decode them into conditions, so Conditions is always left empty. +func (cs *ConditionSet) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + cs.Conditions = make([]FilterCondition, 0, len(raw)) + return nil +} diff --git a/v1/vectordb/interface.go b/v1/vectordb/interface.go new file mode 100644 index 0000000..8886b1d --- /dev/null +++ b/v1/vectordb/interface.go @@ -0,0 +1,49 @@ +package vectordb + +import "context" + +// VectorDBService is the common interface for all vector databases. +// It provides a database-agnostic abstraction for vector similarity search, +// allowing applications to switch between different vector databases +// (Qdrant, Weaviate, Pinecone, etc.) without changing application code. 
+// +// Example usage: +// +// func NewSearchService(db vectordb.VectorDBService) *SearchService { +// return &SearchService{db: db} +// } +// +// // Works with any implementation: +// // - vectordb.NewQdrantAdapter(qdrantClient) +// // - vectordb.NewWeaviateAdapter(weaviateClient) +type VectorDBService interface { + // Search performs similarity search across one or more requests. + // Each request can target a different collection with different filters. + // Returns a slice of result slices—one []SearchResult per request. + // + // Example: + // results, err := db.Search(ctx, + // SearchRequest{CollectionName: "docs", Vector: vec1, TopK: 10}, + // SearchRequest{CollectionName: "docs", Vector: vec2, TopK: 5, Filters: filters}, + // ) + // // results[0] = results for first query + // // results[1] = results for second query + Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResult, error) + + // Insert adds embeddings to a collection. + // Uses batch processing internally for efficiency. + Insert(ctx context.Context, collection string, inputs []EmbeddingInput) error + + // Delete removes points by their IDs from a collection. + Delete(ctx context.Context, collection string, ids []string) error + + // EnsureCollection creates a collection if it doesn't exist. + // Safe to call multiple times—no-op if collection already exists. + EnsureCollection(ctx context.Context, name string, vectorSize uint64) error + + // GetCollection retrieves metadata about a collection. + GetCollection(ctx context.Context, name string) (*Collection, error) + + // ListCollections returns names of all collections. + ListCollections(ctx context.Context) ([]string, error) +} diff --git a/v1/vectordb/types.go b/v1/vectordb/types.go new file mode 100644 index 0000000..9439407 --- /dev/null +++ b/v1/vectordb/types.go @@ -0,0 +1,69 @@ +package vectordb + +// SearchRequest represents a single similarity search query. 
+// Use with VectorDBService.Search() for single or batch queries. +type SearchRequest struct { + // CollectionName is the target collection to search in + CollectionName string `json:"collectionName"` + + // Vector is the query embedding to find similar vectors for + Vector []float32 `json:"vector"` + + // TopK is the maximum number of results to return + TopK int `json:"maxResults"` + + // Filters is optional metadata filtering (AND/OR/NOT logic) + Filters *FilterSet `json:"filters,omitempty"` +} + +// SearchResult represents a single search result with its similarity score. +// This is database-agnostic—payload is converted to map[string]any. +type SearchResult struct { + // ID is the unique identifier of the matched point + ID string `json:"id"` + + // Score is the similarity score (higher = more similar for cosine) + Score float32 `json:"score"` + + // Payload contains the metadata stored with the vector + Payload map[string]any `json:"payload"` + + // Vector is the stored embedding (only populated if requested) + Vector []float32 `json:"vector,omitempty"` + + // CollectionName identifies which collection this result came from + CollectionName string `json:"collectionName,omitempty"` +} + +// EmbeddingInput is the input for inserting vectors into a collection. +type EmbeddingInput struct { + // ID is the unique identifier for this embedding + ID string `json:"id"` + + // Vector is the dense embedding representation + Vector []float32 `json:"vector"` + + // Payload is optional metadata to store with the vector + Payload map[string]any `json:"payload,omitempty"` +} + +// Collection contains metadata about a vector collection. 
+type Collection struct { + // Name is the unique identifier of the collection + Name string `json:"name"` + + // Status indicates the operational state (e.g., "Green", "Yellow") + Status string `json:"status"` + + // VectorSize is the dimension of vectors in this collection + VectorSize int `json:"vectorSize"` + + // Distance is the similarity metric (e.g., "Cosine", "Dot", "Euclid") + Distance string `json:"distance"` + + // VectorCount is the number of indexed vectors + VectorCount uint64 `json:"vectorCount"` + + // PointCount is the number of stored points/documents + PointCount uint64 `json:"pointCount"` +} From 0a3ae4cd8b1ea43c4e8135d69621708dacfc4f1e Mon Sep 17 00:00:00 2001 From: Shinu Joseph Date: Mon, 15 Dec 2025 14:35:33 +0100 Subject: [PATCH 2/4] chore: refactor qdrant to decouple vectordb types from Qdrant --- v1/qdrant/client.go | 14 + v1/qdrant/configs.go | 6 + v1/qdrant/converter.go | 135 ++++++---- v1/qdrant/doc.go | 356 ++++++++++++------------ v1/qdrant/operations.go | 79 +++--- v1/qdrant/qdrant_integration_test.go | 390 +++++++++++++++++++++++++++ v1/qdrant/utils.go | 98 ------- v1/vectordb/doc.go | 41 +-- v1/vectordb/filters.go | 226 +++++++--------- v1/vectordb/interface.go | 6 +- v1/vectordb/types.go | 2 +- v1/vectordb/utils.go | 233 ++++++++++++++++ 12 files changed, 1070 insertions(+), 516 deletions(-) create mode 100644 v1/vectordb/utils.go diff --git a/v1/qdrant/client.go b/v1/qdrant/client.go index 8d50596..817cd3f 100644 --- a/v1/qdrant/client.go +++ b/v1/qdrant/client.go @@ -97,6 +97,10 @@ func NewQdrantClient(p QdrantParams) (*QdrantClient, error) { // // It should be lightweight and fast — typically used during startup or readiness probes. 
func (c *QdrantClient) healthCheck() error { + if !c.started { + return fmt.Errorf("[Qdrant] client not started") + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -114,6 +118,12 @@ func (c *QdrantClient) healthCheck() error { return nil } +// Client returns the underlying Qdrant SDK client. +// This is useful for direct access to low-level operations. +func (c *QdrantClient) Client() *qdrant.Client { + return c.api +} + // Close ────────────────────────────────────────────────────────────── // Close // ────────────────────────────────────────────────────────────── @@ -123,6 +133,10 @@ func (c *QdrantClient) healthCheck() error { // Since the official Qdrant Go SDK doesn't maintain persistent connections, // this is currently a no-op. It exists for lifecycle symmetry and future safety. func (c *QdrantClient) Close() error { + if !c.started { + return nil + } + log.Println("[Qdrant] closing client (no-op)") return nil } diff --git a/v1/qdrant/configs.go b/v1/qdrant/configs.go index cffee45..11a8e8f 100644 --- a/v1/qdrant/configs.go +++ b/v1/qdrant/configs.go @@ -31,6 +31,12 @@ type Config struct { // Optional authentication token for secured deployments. ApiKey string `yaml:"api_key" env:"QDRANT_API_KEY"` + // ──────────────────────────────────────────────────────────────────────────── + // Connection Settings + // ──────────────────────────────────────────────────────────────────────────── + // Note: The following fields are reserved for future use. + // They are not currently passed to the Qdrant SDK client. + // Maximum request duration before timing out. 
Timeout time.Duration `yaml:"timeout" env:"QDRANT_TIMEOUT"` diff --git a/v1/qdrant/converter.go b/v1/qdrant/converter.go index cd4ebad..0e901e0 100644 --- a/v1/qdrant/converter.go +++ b/v1/qdrant/converter.go @@ -135,30 +135,28 @@ func convertVectorDBMatchAnyCondition(c *vectordb.MatchAnyCondition) []*qdrant.C key := resolveVectorDBFieldKey(c.Field, c.FieldType) // Detect type from first value - if len(c.Values) > 0 { - switch c.Values[0].(type) { - case string: - strs := make([]string, len(c.Values)) - for i, v := range c.Values { - if s, ok := v.(string); ok { - strs[i] = s - } + switch c.Values[0].(type) { + case string: + strs := make([]string, len(c.Values)) + for i, v := range c.Values { + if s, ok := v.(string); ok { + strs[i] = s } - return []*qdrant.Condition{qdrant.NewMatchKeywords(key, strs...)} - case int, int64, float64: - ints := make([]int64, len(c.Values)) - for i, v := range c.Values { - switch n := v.(type) { - case int: - ints[i] = int64(n) - case int64: - ints[i] = n - case float64: - ints[i] = int64(n) - } + } + return []*qdrant.Condition{qdrant.NewMatchKeywords(key, strs...)} + case int, int64, float64: + ints := make([]int64, len(c.Values)) + for i, v := range c.Values { + switch n := v.(type) { + case int: + ints[i] = int64(n) + case int64: + ints[i] = n + case float64: + ints[i] = int64(n) } - return []*qdrant.Condition{qdrant.NewMatchInts(key, ints...)} } + return []*qdrant.Condition{qdrant.NewMatchInts(key, ints...)} } return nil } @@ -170,30 +168,28 @@ func convertVectorDBMatchExceptCondition(c *vectordb.MatchExceptCondition) []*qd key := resolveVectorDBFieldKey(c.Field, c.FieldType) // Detect type from first value - if len(c.Values) > 0 { - switch c.Values[0].(type) { - case string: - strs := make([]string, len(c.Values)) - for i, v := range c.Values { - if s, ok := v.(string); ok { - strs[i] = s - } + switch c.Values[0].(type) { + case string: + strs := make([]string, len(c.Values)) + for i, v := range c.Values { + if s, ok := 
v.(string); ok { + strs[i] = s } - return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, strs...)} - case int, int64, float64: - ints := make([]int64, len(c.Values)) - for i, v := range c.Values { - switch n := v.(type) { - case int: - ints[i] = int64(n) - case int64: - ints[i] = n - case float64: - ints[i] = int64(n) - } + } + return []*qdrant.Condition{qdrant.NewMatchExceptKeywords(key, strs...)} + case int, int64, float64: + ints := make([]int64, len(c.Values)) + for i, v := range c.Values { + switch n := v.(type) { + case int: + ints[i] = int64(n) + case int64: + ints[i] = n + case float64: + ints[i] = int64(n) } - return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, ints...)} } + return []*qdrant.Condition{qdrant.NewMatchExceptInts(key, ints...)} } return nil } @@ -201,10 +197,10 @@ func convertVectorDBMatchExceptCondition(c *vectordb.MatchExceptCondition) []*qd func convertVectorDBNumericRangeCondition(c *vectordb.NumericRangeCondition) []*qdrant.Condition { key := resolveVectorDBFieldKey(c.Field, c.FieldType) rangeFilter := &qdrant.Range{ - Gt: c.GreaterThan, - Gte: c.GreaterThanOrEqualTo, - Lt: c.LessThan, - Lte: c.LessThanOrEqualTo, + Gt: c.Range.Gt, + Gte: c.Range.Gte, + Lt: c.Range.Lt, + Lte: c.Range.Lte, } if rangeFilter.Gt == nil && rangeFilter.Gte == nil && @@ -218,10 +214,10 @@ func convertVectorDBNumericRangeCondition(c *vectordb.NumericRangeCondition) []* func convertVectorDBTimeRangeCondition(c *vectordb.TimeRangeCondition) []*qdrant.Condition { key := resolveVectorDBFieldKey(c.Field, c.FieldType) dateRange := &qdrant.DatetimeRange{ - Gt: toVectorDBTimestamp(c.After), - Gte: toVectorDBTimestamp(c.AtOrAfter), - Lt: toVectorDBTimestamp(c.Before), - Lte: toVectorDBTimestamp(c.AtOrBefore), + Gt: toVectorDBTimestamp(c.Range.Gt), + Gte: toVectorDBTimestamp(c.Range.Gte), + Lt: toVectorDBTimestamp(c.Range.Lt), + Lte: toVectorDBTimestamp(c.Range.Lte), } if dateRange.Gt == nil && dateRange.Gte == nil && @@ -265,7 +261,7 @@ func 
toVectorDBTimestamp(t *time.Time) *timestamppb.Timestamp { // ── Result Conversion ──────────────────────────────────────────────────────── // parseVectorDBSearchResults converts Qdrant response to vectordb.SearchResult slice. -func parseVectorDBSearchResults(resp []*qdrant.ScoredPoint, collectionName string) ([]vectordb.SearchResult, error) { +func parseVectorDBSearchResults(resp []*qdrant.ScoredPoint) ([]vectordb.SearchResult, error) { results := make([]vectordb.SearchResult, 0, len(resp)) for _, r := range resp { id, err := extractVectorDBPointID(r.Id) @@ -274,10 +270,9 @@ func parseVectorDBSearchResults(resp []*qdrant.ScoredPoint, collectionName strin } results = append(results, vectordb.SearchResult{ - ID: id, - Score: r.Score, - Payload: convertVectorDBPayload(r.Payload), - CollectionName: collectionName, + ID: id, + Score: r.Score, + Payload: convertVectorDBPayload(r.Payload), }) } return results, nil @@ -344,3 +339,33 @@ func extractVectorDBValue(v *qdrant.Value) any { return nil } } + +// convertVectorDBFilterSets converts an array of FilterSets to a single Qdrant filter. +// Multiple filter sets are combined with AND logic (all must match). 
+func convertVectorDBFilterSets(filters []*vectordb.FilterSet) *qdrant.Filter { + if len(filters) == 0 { + return nil + } + + // Single filter set - convert directly + if len(filters) == 1 { + return convertVectorDBFilterSet(filters[0]) + } + + // Multiple filter sets - combine with AND + var allConditions []*qdrant.Condition + for _, fs := range filters { + converted := convertVectorDBFilterSet(fs) + if converted != nil { + allConditions = append(allConditions, &qdrant.Condition{ + ConditionOneOf: &qdrant.Condition_Filter{Filter: converted}, + }) + } + } + + if len(allConditions) == 0 { + return nil + } + + return &qdrant.Filter{Must: allConditions} +} diff --git a/v1/qdrant/doc.go b/v1/qdrant/doc.go index 71d93bb..b371ed8 100644 --- a/v1/qdrant/doc.go +++ b/v1/qdrant/doc.go @@ -5,22 +5,21 @@ // collection management, embedding insertion, similarity search, and deletion. It integrates // seamlessly with the fx dependency injection framework and supports builder-style configuration. // -// Core Features: +// # Core Features // // - Managed Qdrant client lifecycle with Fx integration // - Config struct supporting environment and YAML loading // - Automatic health checks on client initialization // - Safe, batched insertion of embeddings with configurable batch size -// - Vector similarity search with abstracted SearchResult interface +// - Database-agnostic interface via vectordb.Service // - Type-safe collection creation and existence checks // - Support for payload metadata and optional vector retrieval -// - Extensible abstraction layer for alternate vector stores (e.g., Pinecone, Postgres) -// - VectorDBAdapter implementing vectordb.VectorDBService for DB-agnostic usage +// - Extensible abstraction layer for alternate vector stores (e.g., Pinecone, Weaviate) // // # VectorDB Interface // -// This package includes [VectorDBAdapter] which implements the database-agnostic -// [vectordb.VectorDBService] interface. 
Use this for new projects to enable +// This package includes [Adapter] which implements the database-agnostic +// [vectordb.Service] interface. Use this for new projects to enable // easy switching between vector databases: // // import ( @@ -29,92 +28,101 @@ // ) // // // Create your existing QdrantClient -// qc, _ := qdrant.NewQdrantClient(params) +// qc, _ := qdrant.NewQdrantClient(qdrant.QdrantParams{ +// Config: &qdrant.Config{ +// Endpoint: "localhost", +// Port: 6334, +// }, +// }) // // // Create adapter for DB-agnostic usage -// var db vectordb.VectorDBService = qdrant.NewVectorDBAdapter(qc.API()) +// var db vectordb.Service = qdrant.NewAdapter(qc.Client()) // // This allows switching between vector databases (Qdrant, Weaviate, Pinecone) without // changing application code. // // # Basic Usage // -// import "github.com/Aleph-Alpha/std/v1/qdrant" +// import ( +// "github.com/Aleph-Alpha/std/v1/qdrant" +// "github.com/Aleph-Alpha/std/v1/vectordb" +// ) // // // Create a new client // client, err := qdrant.NewQdrantClient(qdrant.QdrantParams{ -// Config: &qdrant.Config{ -// Endpoint: "localhost:6334", -// ApiKey: "", -// }, +// Config: &qdrant.Config{ +// Endpoint: "localhost", +// Port: 6334, +// }, // }) // if err != nil { -// log.Fatal(err) +// log.Fatal(err) // } // +// // Create adapter +// adapter := qdrant.NewAdapter(client.Client()) +// // collectionName := "documents" // -// // Insert single embedding -// input := qdrant.EmbeddingInput{ -// ID: "doc_1", -// Vector: []float32{0.12, 0.43, 0.85}, -// Meta: map[string]any{"title": "My Document"}, -// } -// if err := client.Insert(ctx, collectionName, input); err != nil { -// log.Fatal(err) +// // Ensure collection exists +// if err := adapter.EnsureCollection(ctx, collectionName, 1536); err != nil { +// log.Fatal(err) // } // -// // Batch insert embeddings -// batch := []qdrant.EmbeddingInput{input1, input2, input3} -// if err := client.BatchInsert(ctx, collectionName, batch); err != nil { -// 
log.Fatal(err) +// // Insert embeddings +// inputs := []vectordb.EmbeddingInput{ +// { +// ID: "doc_1", +// Vector: []float32{0.12, 0.43, 0.85, ...}, +// Payload: map[string]any{"title": "My Document"}, +// }, +// } +// if err := adapter.Insert(ctx, collectionName, inputs); err != nil { +// log.Fatal(err) // } // // // Perform similarity search -// results, err := client.Search(ctx, qdrant.SearchRequest{ -// CollectionName: collectionName, -// Vector: queryVector, -// TopK: 5, +// results, err := adapter.Search(ctx, vectordb.SearchRequest{ +// CollectionName: collectionName, +// Vector: queryVector, +// TopK: 5, // }) // for _, res := range results[0] { -// fmt.Printf("ID=%s Score=%.4f\n", res.GetID(), res.GetScore()) +// fmt.Printf("ID=%s Score=%.4f\n", res.ID, res.Score) // } // -// FX Module Integration: +// # FX Module Integration // // The package exposes an Fx module for automatic dependency injection: // // app := fx.New( -// qdrant.FXModule, -// // other modules... +// qdrant.FXModule, +// // other modules... // ) // app.Run() // -// Abstractions: -// -// The package defines a lightweight SearchResultInterface that encapsulates -// search results via methods such as GetID(), GetScore(), GetMeta(), and GetCollectionName(). -// The underlying concrete type remains SearchResult, allowing both strong typing internally -// and loose coupling externally. 
+// # Search Results // -// Example: +// Search results are returned as [vectordb.SearchResult] structs with public fields: // -// type SearchResultInterface interface { -// GetID() string -// GetScore() float32 -// GetMeta() map[string]*qdrant.Value -// GetCollectionName() string +// type SearchResult struct { +// ID string // Unique identifier of the matched point +// Score float32 // Similarity score +// Payload map[string]any // Metadata stored with the vector +// Vector []float32 // Stored embedding (if requested) +// CollectionName string // Source collection name // } // -// type SearchResult struct { /* implements SearchResultInterface */ } +// Access fields directly (no getter methods needed): // -// // Function signature: -// func (c *QdrantClient) Search(ctx context.Context, vector []float32, topK int) ([]SearchResultInterface, error) +// for _, result := range results[0] { +// fmt.Println(result.ID, result.Score, result.Payload["title"]) +// } // // # Filtering // -// The package provides a comprehensive, type-safe filtering system for vector searches. -// Filters support boolean logic (AND, OR, NOT) and various condition types. +// Filters are defined in the [vectordb] package and support boolean logic (AND, OR, NOT). +// The qdrant adapter converts these to native Qdrant filters automatically. 
// // Filter Structure: // @@ -124,17 +132,13 @@ // MustNot *ConditionSet // NOT - none of the conditions should match // } // -// Condition Types: -// -// - TextCondition: Exact string match -// - BoolCondition: Exact boolean match -// - IntCondition: Exact integer match -// - TextAnyCondition: String IN operator (match any of values) -// - IntAnyCondition: Integer IN operator -// - TextExceptCondition: String NOT IN operator -// - IntExceptCondition: Integer NOT IN operator -// - TimeRangeCondition: DateTime range filter (gte, lte, gt, lt) -// - NumericRangeCondition: Numeric range filter +// Condition Types (all in vectordb package): +// +// - MatchCondition: Exact match (string, bool, int64) +// - MatchAnyCondition: IN operator (match any of values) +// - MatchExceptCondition: NOT IN operator +// - NumericRangeCondition: Numeric range filter (gt, gte, lt, lte) +// - TimeRangeCondition: DateTime range filter // - IsNullCondition: Check if field is null // - IsEmptyCondition: Check if field is empty, null, or missing // @@ -143,176 +147,172 @@ // The package distinguishes between system-managed and user-defined metadata: // // const ( -// InternalField FieldType = iota // Top-level: "search_store_id" -// UserField // Nested: "custom.document_id" +// InternalField FieldType = iota // Top-level: "status" +// UserField // Prefixed: "custom.document_id" // ) // // User fields are automatically prefixed with "custom." when querying Qdrant. 
// -// Basic Filter Example: +// # Filter Examples Using Convenience Constructors // -// // Filter: city = "London" AND active = true -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "city", Value: "London"}, -// qdrant.BoolCondition{Key: "active", Value: true}, -// }, -// }, -// } +// The vectordb package provides convenience constructors for clean filter creation: +// +// Basic Filter (Must - AND logic): // -// results, err := client.Search(ctx, qdrant.SearchRequest{ +// // Filter: city = "London" AND active = true +// results, err := adapter.Search(ctx, vectordb.SearchRequest{ // CollectionName: "documents", // Vector: queryVector, // TopK: 10, -// Filters: filters, +// Filters: []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewMatch("city", "London"), +// vectordb.NewMatch("active", true), +// ), +// ), +// }, // }) // // OR Conditions (Should): // // // Filter: city = "London" OR city = "Berlin" -// filters := &qdrant.FilterSet{ -// Should: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "city", Value: "London"}, -// qdrant.TextCondition{Key: "city", Value: "Berlin"}, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Should( +// vectordb.NewMatch("city", "London"), +// vectordb.NewMatch("city", "Berlin"), +// ), +// ), // } // // IN Operator (MatchAny): // // // Filter: city IN ["London", "Berlin", "Paris"] -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextAnyCondition{ -// Key: "city", -// Values: []string{"London", "Berlin", "Paris"}, -// }, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewMatchAny("city", "London", "Berlin", "Paris"), +// ), +// ), +// } +// +// Numeric Range Filter: +// +// // Filter: price >= 
100 AND price < 500 +// min, max := float64(100), float64(500) +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewNumericRange("price", vectordb.NumericRange{ +// Gte: &min, +// Lt: &max, +// }), +// ), +// ), // } // // Time Range Filter: // +// // Filter: created_at >= yesterday AND created_at < now // now := time.Now() // yesterday := now.Add(-24 * time.Hour) -// -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TimeRangeCondition{ -// Key: "created_at", -// Value: qdrant.TimeRange{ -// Gte: &yesterday, -// Lt: &now, -// }, -// }, -// }, -// }, +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewTimeRange("created_at", vectordb.TimeRange{ +// Gte: &yesterday, +// Lt: &now, +// }), +// ), +// ), // } // // Complex Filter (Combined Clauses): // -// // Filter: (status = "published") AND (category = "tech" OR "science") AND NOT (deleted = true) -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "status", Value: "published"}, -// }, -// }, -// Should: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{Key: "category", Value: "tech", FieldType: qdrant.UserField}, -// qdrant.TextCondition{Key: "category", Value: "science", FieldType: qdrant.UserField}, -// }, -// }, -// MustNot: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.BoolCondition{Key: "deleted", Value: true}, -// }, -// }, +// // Filter: status = "published" AND (tag = "ml" OR tag = "ai") AND NOT deleted = true +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must(vectordb.NewMatch("status", "published")), +// vectordb.Should( +// vectordb.NewMatch("tag", "ml"), +// vectordb.NewMatch("tag", "ai"), +// ), +// vectordb.MustNot(vectordb.NewMatch("deleted", true)), +// ), 
// } // -// UUID Filtering: +// User-Defined Fields: // -// UUIDs are filtered as strings using TextCondition: +// For fields stored under a custom prefix, use the User* constructors: // -// filters := &qdrant.FilterSet{ -// Must: &qdrant.ConditionSet{ -// Conditions: []qdrant.FilterCondition{ -// qdrant.TextCondition{ -// Key: "document_id", -// Value: "f47ac10b-58cc-4372-a567-0e02b2c3d479", -// FieldType: qdrant.UserField, -// }, -// }, -// }, +// // Filter on user-defined field: custom.document_id = "doc-123" +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet( +// vectordb.Must( +// vectordb.NewUserMatch("document_id", "doc-123"), +// ), +// ), // } // -// Payload Structure Helper: -// -// The BuildPayload function creates properly structured payloads that separate -// internal and user fields: -// -// payload := qdrant.BuildPayload( -// map[string]any{"search_store_id": "store-123"}, // Internal (top-level) -// map[string]any{"document_id": "doc-456"}, // User (under "custom.") -// ) +// Multiple FilterSets (AND between sets): // -// // Result: -// // { -// // "search_store_id": "store-123", -// // "custom": { -// // "document_id": "doc-456" -// // } -// // } +// When you provide multiple FilterSets, they are combined with AND logic: // -// This structure ensures user-defined filter indexes are created at the correct -// path (custom.field_name). 
+// // Filter: (color = "red") AND (size < 20) +// lt := float64(20) +// filters := []*vectordb.FilterSet{ +// vectordb.NewFilterSet(vectordb.Must(vectordb.NewMatch("color", "red"))), +// vectordb.NewFilterSet(vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Lt: &lt}))), +// } // -// Configuration: +// # Configuration // // Qdrant can be configured via environment variables or YAML: // -// QDRANT_ENDPOINT=http://localhost:6334 +// QDRANT_ENDPOINT=localhost +// QDRANT_PORT=6334 // QDRANT_API_KEY=your-api-key // -// Performance Considerations: +// # Performance Considerations // -// The BatchInsert method automatically splits large embedding inserts into smaller -// upserts (default batch size = 500). This minimizes memory usage and avoids timeouts +// The Insert method automatically splits large embedding batches into smaller +// upserts (default batch size = 100). This minimizes memory usage and avoids timeouts // when ingesting large datasets. // -// Thread Safety: +// # Thread Safety // -// All exported methods on QdrantClient are safe for concurrent use by multiple goroutines. +// All exported methods on the Adapter are safe for concurrent use by multiple goroutines. // -// Testing: +// # Testing // -// For testing and mocking, application code should depend on the public interface types -// (e.g., SearchResultInterface, EmbeddingInput) instead of concrete Qdrant structs. -// This allows replacing the QdrantClient with in-memory or mock implementations in tests. 
+// For testing and mocking, depend on the [vectordb.Service] interface: // -// Example Mock: +// type MockVectorDB struct{} // -// type MockResult struct { -// id string -// score float32 -// meta map[string]any +// func (m *MockVectorDB) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, error) { +// return [][]vectordb.SearchResult{ +// {{ID: "doc-1", Score: 0.95, Payload: map[string]any{"title": "Test"}}}, +// }, nil // } -// func (m MockResult) GetID() string { return m.id } -// func (m MockResult) GetScore() float32 { return m.score } -// func (m MockResult) GetMeta() map[string]any { return m.meta } // -// Package Layout: +// // Use in tests: +// var db vectordb.Service = &MockVectorDB{} +// +// # Package Layout // // qdrant/ -// ├── client.go // Qdrant client implementation -// ├── operations.go // CRUD operations (Insert, Search, Delete, etc.) -// ├── filters.go // Type-safe filtering system -// ├── utils.go // Shared types and interfaces -// ├── configs.go // Configuration struct and builder methods +// ├── client.go // Qdrant client wrapper and lifecycle +// ├── operations.go // Service implementation (Adapter) +// ├── converter.go // vectordb ↔ Qdrant type conversion +// ├── utils.go // Qdrant-specific helper functions +// ├── configs.go // Configuration struct // └── fx_module.go // Fx dependency injection module +// +// # Related Packages +// +// - [vectordb]: Database-agnostic types and interfaces +// - [vectordb.FilterSet]: Filter structures for search queries +// - [vectordb.SearchResult]: Search result type +// - [vectordb.EmbeddingInput]: Input type for inserting vectors package qdrant diff --git a/v1/qdrant/operations.go b/v1/qdrant/operations.go index 4c6b107..0a22f5e 100644 --- a/v1/qdrant/operations.go +++ b/v1/qdrant/operations.go @@ -13,13 +13,13 @@ import ( ) // Ensure VectorDBAdapter implements VectorDBService at compile time -var _ vectordb.VectorDBService = (*Adapter)(nil) +var _ vectordb.Service 
= (*Adapter)(nil) // ══════════════════════════════════════════════════════════════════════════════ -// Adapter - implements vectordb.VectorDBService interface +// Adapter - implements vectordb.Service interface // ══════════════════════════════════════════════════════════════════════════════ -// VectorDBAdapter implements vectordb.VectorDBService for Qdrant. +// Adapter implements vectordb.Service for Qdrant. // It wraps a Qdrant client and converts between generic vectordb types // and Qdrant-specific protobuf types. // @@ -29,14 +29,14 @@ type Adapter struct { client *qdrant.Client } -// NewVectorDBAdapter creates a new Qdrant adapter for the vectordb interface. +// NewAdapter creates a new Qdrant adapter for the vectordb interface. // Pass the underlying SDK client via QdrantClient.API(). // // Example: // // qc, _ := qdrant.NewQdrantClient(params) -// adapter := qdrant.NewVectorDBAdapter(qc.API()) -// var db vectordb.VectorDBService = adapter +// adapter := qdrant.NewAdapter(qc.Client()) +// var db vectordb.Service = adapter func NewAdapter(client *qdrant.Client) *Adapter { return &Adapter{client: client} } @@ -49,32 +49,31 @@ func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) - // Validate - for i, req := range requests { - if req.CollectionName == "" { - return nil, fmt.Errorf("request [%d]: collection name cannot be empty", i) - } - if len(req.Vector) == 0 { - return nil, fmt.Errorf("request [%d]: vector cannot be empty", i) - } - if req.TopK <= 0 { - return nil, fmt.Errorf("request [%d]: topK must be greater than 0", i) + // Validate all requests first + for i, searchReq := range requests { + if err := validateSearchInput(searchReq.CollectionName, searchReq.Vector, searchReq.TopK); err != nil { + return nil, fmt.Errorf("request [%d]: %w", i, err) } } results := make([][]vectordb.SearchResult, len(requests)) + + // Create errgroup with context g, ctx 
:= errgroup.WithContext(ctx) + + // Create semaphore to limit concurrent searches sem := semaphore.NewWeighted(maxConcurrentSearches) - for i, req := range requests { - i, req := i, req + for i, searchReq := range requests { + i, searchReq := i, searchReq // Capture loop variables g.Go(func() error { + // Acquire semaphore (blocks if at max concurrency) if err := sem.Acquire(ctx, 1); err != nil { return fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) } defer sem.Release(1) - res, err := searchInternal(ctx, a.client, req) + res, err := searchInternal(ctx, a.client, searchReq) if err != nil { return fmt.Errorf("request [%d]: search failed: %w", i, err) } @@ -91,19 +90,19 @@ func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest } // Insert adds embeddings to a collection using batch processing. -func (a *Adapter) Insert(ctx context.Context, collection string, inputs []vectordb.EmbeddingInput) error { +func (a *Adapter) Insert(ctx context.Context, collectionName string, inputs []vectordb.EmbeddingInput) error { if len(inputs) == 0 { return nil } - if collection == "" { + if collectionName == "" { return fmt.Errorf("collection name cannot be empty") } - return insertInternal(ctx, a.client, collection, inputs) + return insertInternal(ctx, a.client, collectionName, inputs) } // Delete removes points by their IDs from a collection. -func (a *Adapter) Delete(ctx context.Context, collection string, ids []string) error { - return deleteInternal(ctx, a.client, collection, ids) +func (a *Adapter) Delete(ctx context.Context, collectionName string, ids []string) error { + return deleteInternal(ctx, a.client, collectionName, ids) } // EnsureCollection creates a collection if it doesn't exist. 
@@ -124,14 +123,16 @@ func (a *Adapter) GetCollection(ctx context.Context, name string) (*vectordb.Col size, distance := extractVectorDetails(info) - return &vectordb.Collection{ + collection := &vectordb.Collection{ Name: name, Status: info.Status.String(), VectorSize: size, Distance: distance, VectorCount: derefUint64(info.IndexedVectorsCount), PointCount: derefUint64(info.PointsCount), - }, nil + } + + return collection, nil } // ListCollections returns names of all collections. @@ -143,24 +144,24 @@ func (a *Adapter) ListCollections(ctx context.Context) ([]string, error) { // Internal Functions // ══════════════════════════════════════════════════════════════════════════════ -func searchInternal(ctx context.Context, client *qdrant.Client, req vectordb.SearchRequest) ([]vectordb.SearchResult, error) { - limit := uint64(req.TopK) +func searchInternal(ctx context.Context, client *qdrant.Client, searchReq vectordb.SearchRequest) ([]vectordb.SearchResult, error) { + limit := uint64(searchReq.TopK) queryReq := &qdrant.QueryPoints{ - CollectionName: req.CollectionName, - Query: qdrant.NewQuery(req.Vector...), + CollectionName: searchReq.CollectionName, + Query: qdrant.NewQuery(searchReq.Vector...), Limit: &limit, WithPayload: qdrant.NewWithPayload(true), - Filter: convertVectorDBFilterSet(req.Filters), + Filter: convertVectorDBFilterSets(searchReq.Filters), } resp, err := client.Query(ctx, queryReq) if err != nil { return nil, fmt.Errorf("query failed: %w", err) } - return parseVectorDBSearchResults(resp, req.CollectionName) + return parseVectorDBSearchResults(resp) } -func insertInternal(ctx context.Context, client *qdrant.Client, collection string, inputs []vectordb.EmbeddingInput) error { +func insertInternal(ctx context.Context, client *qdrant.Client, collectionName string, inputs []vectordb.EmbeddingInput) error { for start := 0; start < len(inputs); start += defaultBatchSize { end := start + defaultBatchSize if end > len(inputs) { @@ -179,7 +180,7 @@ func 
insertInternal(ctx context.Context, client *qdrant.Client, collection strin wait := true req := &qdrant.UpsertPoints{ - CollectionName: collection, + CollectionName: collectionName, Points: points, Wait: &wait, } @@ -187,16 +188,16 @@ func insertInternal(ctx context.Context, client *qdrant.Client, collection strin if _, err := client.Upsert(ctx, req); err != nil { return fmt.Errorf("[Qdrant] batch upsert failed at [%d:%d]: %w", start, end, err) } - log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collection) + log.Printf("[Qdrant] Inserted batch [%d:%d] (collection=%s)", start, end, collectionName) } return nil } -func deleteInternal(ctx context.Context, client *qdrant.Client, collection string, ids []string) error { +func deleteInternal(ctx context.Context, client *qdrant.Client, collectionName string, ids []string) error { if len(ids) == 0 { return nil } - if collection == "" { + if collectionName == "" { return fmt.Errorf("collection name cannot be empty") } @@ -207,7 +208,7 @@ func deleteInternal(ctx context.Context, client *qdrant.Client, collection strin wait := true req := &qdrant.DeletePoints{ - CollectionName: collection, + CollectionName: collectionName, Points: &qdrant.PointsSelector{ PointsSelectorOneOf: &qdrant.PointsSelector_Points{ Points: &qdrant.PointsIdsList{Ids: qdrantIDs}, @@ -221,7 +222,7 @@ func deleteInternal(ctx context.Context, client *qdrant.Client, collection strin return fmt.Errorf("[Qdrant] delete failed: %w", err) } - log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", resp.Status.String(), collection) + log.Printf("[Qdrant] Delete completed (status=%s, collection=%s)", resp.Status.String(), collectionName) return nil } diff --git a/v1/qdrant/qdrant_integration_test.go b/v1/qdrant/qdrant_integration_test.go index 4e5fed3..4ce996f 100644 --- a/v1/qdrant/qdrant_integration_test.go +++ b/v1/qdrant/qdrant_integration_test.go @@ -608,6 +608,396 @@ func TestQdrantLifecycleAndHealthCheck(t *testing.T) { 
t.Log("Qdrant client lifecycle test passed successfully") } +// TestFilterOperations tests various filter scenarios +func TestFilterOperations(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + // Convert port to uint + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, err := NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + t.Run("SearchWithMustFilter", func(t *testing.T) { + collectionName := "test_must_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + // Insert embeddings with different metadata + embeddings := []vectordb.EmbeddingInput{ + { + ID: "00000000-0000-0000-0004-000000000001", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "red", "size": int64(10)}, + }, + { + ID: "00000000-0000-0000-0004-000000000002", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "blue", "size": int64(20)}, + }, + { + ID: "00000000-0000-0000-0004-000000000003", + Vector: generateRandomVector(1536), + Payload: map[string]any{"color": "red", "size": int64(30)}, + }, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Test Must filter (color == "red") + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + 
vectordb.NewFilterSet( + vectordb.Must(vectordb.NewMatch("color", "red")), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // Only red items + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithNumericRangeFilter", func(t *testing.T) { + collectionName := "test_numeric_range" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0010-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"size": int64(10)}}, + {ID: "00000000-0000-0000-0010-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"size": int64(20)}}, + {ID: "00000000-0000-0000-0010-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"size": int64(30)}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Test numeric range filter (size >= 20) + gte := float64(20) + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Gte: &gte})), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // size 20 and 30 + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMustNotFilter", func(t *testing.T) { + collectionName := "test_must_not" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0005-000000000001", Vector: 
generateRandomVector(1536), Payload: map[string]any{"status": "published"}}, + {ID: "00000000-0000-0000-0005-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"status": "draft"}}, + {ID: "00000000-0000-0000-0005-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"status": "archived"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Exclude archived items + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.MustNot(vectordb.NewMatch("status", "archived")), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // published and draft only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMultipleFilterSets", func(t *testing.T) { + collectionName := "test_multi_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0006-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red", "size": int64(10)}}, + {ID: "00000000-0000-0000-0006-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red", "size": int64(50)}}, + {ID: "00000000-0000-0000-0006-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "blue", "size": int64(10)}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // color == red AND size < 20 + lt := float64(20) + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, 
+ TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet(vectordb.Must(vectordb.NewMatch("color", "red"))), + vectordb.NewFilterSet(vectordb.Must(vectordb.NewNumericRange("size", vectordb.NumericRange{Lt: &lt}))), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 1, len(batchResults[0])) // Only red with size < 20 + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithMatchAnyFilter", func(t *testing.T) { + collectionName := "test_match_any" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := []vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0008-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "ml"}}, + {ID: "00000000-0000-0000-0008-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "nlp"}}, + {ID: "00000000-0000-0000-0008-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"tag": "cv"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Match any of ["ml", "nlp"] + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Must(vectordb.NewMatchAny("tag", "ml", "nlp")), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // ml and nlp only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) + + t.Run("SearchWithShouldFilter", func(t *testing.T) { + collectionName := "test_should_filter" + err := adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := 
[]vectordb.EmbeddingInput{ + {ID: "00000000-0000-0000-0009-000000000001", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "red"}}, + {ID: "00000000-0000-0000-0009-000000000002", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "blue"}}, + {ID: "00000000-0000-0000-0009-000000000003", Vector: generateRandomVector(1536), Payload: map[string]any{"color": "green"}}, + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Should match red OR blue + batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + CollectionName: collectionName, + Vector: embeddings[0].Vector, + TopK: 10, + Filters: []*vectordb.FilterSet{ + vectordb.NewFilterSet( + vectordb.Should( + vectordb.NewMatch("color", "red"), + vectordb.NewMatch("color", "blue"), + ), + ), + }, + }) + assert.NoError(t, err) + assert.Equal(t, 2, len(batchResults[0])) // red and blue only + + // Clean up + ids := []string{embeddings[0].ID, embeddings[1].ID, embeddings[2].ID} + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) + }) +} + +// TestBatchSearch tests batch search with multiple queries +func TestBatchSearch(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, err := NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + collectionName := 
"test_batch_search" + err = adapter.EnsureCollection(ctx, collectionName, 1536) + require.NoError(t, err) + + embeddings := make([]vectordb.EmbeddingInput, 5) + for i := 0; i < 5; i++ { + embeddings[i] = vectordb.EmbeddingInput{ + ID: fmt.Sprintf("00000000-0000-0000-0007-%012d", i+1), + Vector: generateRandomVector(1536), + Payload: map[string]any{"index": int64(i)}, + } + } + + err = adapter.Insert(ctx, collectionName, embeddings) + require.NoError(t, err) + time.Sleep(1 * time.Second) + + // Batch search with multiple queries + batchResults, err := adapter.Search(ctx, + vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 3}, + vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[1].Vector, TopK: 3}, + vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[2].Vector, TopK: 3}, + ) + assert.NoError(t, err) + assert.Equal(t, 3, len(batchResults)) // 3 result sets + + // Each result set should have results + for i, results := range batchResults { + assert.Greater(t, len(results), 0, "result set %d should have results", i) + } + + // Clean up + ids := make([]string, len(embeddings)) + for i, emb := range embeddings { + ids[i] = emb.ID + } + err = adapter.Delete(ctx, collectionName, ids) + assert.NoError(t, err) +} + +// TestListCollections tests collection listing +func TestListCollections(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + containerInstance, err := setupQdrantContainer(ctx) + require.NoError(t, err) + defer func() { + if err := containerInstance.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }() + + portNum, err := strconv.Atoi(containerInstance.Port) + require.NoError(t, err) + + cfg := &Config{ + Endpoint: containerInstance.Host, + Port: portNum, + CheckCompatibility: false, + Timeout: 10 * time.Second, + } + + client, 
err := NewQdrantClient(QdrantParams{Config: cfg}) + require.NoError(t, err) + defer client.Close() + + adapter := NewAdapter(client.api) + + // Create a few collections + err = adapter.EnsureCollection(ctx, "list_test_1", 1536) + require.NoError(t, err) + err = adapter.EnsureCollection(ctx, "list_test_2", 1536) + require.NoError(t, err) + + // List collections + collections, err := adapter.ListCollections(ctx) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(collections), 2) + + // Verify our collections are in the list + found1, found2 := false, false + for _, name := range collections { + if name == "list_test_1" { + found1 = true + } + if name == "list_test_2" { + found2 = true + } + } + assert.True(t, found1, "list_test_1 should be in collections") + assert.True(t, found2, "list_test_2 should be in collections") +} + // Helper function to generate random vectors for testing func generateRandomVector(size int) []float32 { vector := make([]float32, size) diff --git a/v1/qdrant/utils.go b/v1/qdrant/utils.go index 6815a6e..964f0a0 100644 --- a/v1/qdrant/utils.go +++ b/v1/qdrant/utils.go @@ -3,107 +3,9 @@ package qdrant import ( "fmt" - "github.com/Aleph-Alpha/std/v1/vectordb" qdrant "github.com/qdrant/go-client/qdrant" ) -// EmbeddingInput is the type the application provides to insert embeddings. -// Keeps the app decoupled from internal Qdrant SDK structs. -type EmbeddingInput struct { - ID string // Unique identifier for the embedding (e.g., document ID) - Vector []float32 // Dense vector representation of the embedding - Meta map[string]any // Optional metadata associated with the embedding -} - -// Embedding represents a dense embedding vector. -type Embedding struct { - ID string // Unique identifier (same as Qdrant point ID) - Vector []float32 // Vector representation of the embedding - Meta map[string]any // Optional metadata associated with the embedding -} - -// SearchResult holds results from a similarity search. 
-type SearchResult struct { - ID string - Score float32 - Meta map[string]*qdrant.Value - Vector []float32 - Collection string -} - -// GetID returns the result's unique identifier. -func (r SearchResult) GetID() string { return r.ID } - -// GetScore returns the similarity score associated with the result. -func (r SearchResult) GetScore() float32 { return r.Score } - -// GetMeta returns the metadata stored with the vector. -func (r SearchResult) GetMeta() map[string]*qdrant.Value { return r.Meta } - -// GetVector returns the dense embedding vector if available. -func (r SearchResult) GetVector() []float32 { return r.Vector } - -// HasVector reports whether the result contains a non-empty vector payload. -func (r SearchResult) HasVector() bool { return len(r.Vector) > 0 } - -// GetCollectionName returns the name of the collection from which the result originated. -func (r SearchResult) GetCollectionName() string { return r.Collection } - -// NewEmbedding converts a high-level EmbeddingInput into the internal Embedding type. -// Having this builder allows for future validation or normalization logic. -// For now, it performs a shallow copy. -func NewEmbedding(input EmbeddingInput) Embedding { - return Embedding(input) -} - -// SearchResultInterface is the public interface for search results. -// It provides a consistent way to access search results regardless of the underlying implementation. 
-type SearchResultInterface interface { - GetID() string // Unique result identifier - GetScore() float32 // Similarity score - GetMeta() map[string]*qdrant.Value // Metadata associated with the result - GetVector() []float32 // Optional embedding vector - HasVector() bool // Whether a vector payload is present - GetCollectionName() string // Name of the Qdrant collection -} - -// Collection ────────────────────────────────────────────────────────────── -// Collection -// ────────────────────────────────────────────────────────────── -// -// Collection represents a high-level, decoupled view of a Qdrant collection. -// -// It provides essential metadata about a vector collection without exposing -// Qdrant SDK types, allowing the application layer to remain independent -// of the underlying database implementation. -// -// Fields: -// - Name — The unique name of the collection. -// - Status — Current operational state (e.g., "Green", "Yellow"). -// - VectorSize — The dimension of stored vectors (e.g., 1536). -// - Distance — The similarity metric used ("Cosine", "Dot", "Euclid"). -// - Vectors — Total number of stored vectors in the collection. -// - Points — Total number of indexed points/documents in the collection. -// -// This struct serves as an abstraction layer between Qdrant's low-level -// protobuf models and the higher-level application logic. 
-type Collection struct { - Name string - Status string - VectorSize int - Distance string - Vectors uint64 - Points uint64 -} - -// SearchRequest represents a single search request for batch operations -type SearchRequest struct { - CollectionName string - Vector []float32 - TopK int - Filters *vectordb.FilterSet // Optional: key-value filters -} - // validateSearchInput validates common search parameters func validateSearchInput(collectionName string, vector []float32, topK int) error { if collectionName == "" { diff --git a/v1/vectordb/doc.go b/v1/vectordb/doc.go index fb91483..7432dbb 100644 --- a/v1/vectordb/doc.go +++ b/v1/vectordb/doc.go @@ -2,7 +2,7 @@ // // # Overview // -// This package defines a common interface [VectorDBService] that can be implemented +// This package defines a common interface [Service] that can be implemented // by different vector database adapters (Qdrant, Weaviate, Pinecone, etc.), allowing // applications to switch between databases without changing application code. 
// @@ -10,12 +10,12 @@ // // ┌─────────────────────────────────────────────────────────────┐ // │ Application Layer │ -// │ (uses vectordb.VectorDBService - no DB-specific imports) │ +// │ (uses vectordb.Service - no DB-specific imports) │ // └──────────────────────────┬──────────────────────────────────┘ // │ // ▼ // ┌─────────────────────────────────────────────────────────────┐ -// │ vectordb.VectorDBService │ +// │ vectordb.Service │ // │ (common interface + DB-agnostic types) │ // └──────────────────────────┬──────────────────────────────────┘ // │ @@ -40,10 +40,10 @@ // import "github.com/Aleph-Alpha/std/v1/vectordb" // // type SearchService struct { -// db vectordb.VectorDBService +// db vectordb.Service // } // -// func NewSearchService(db vectordb.VectorDBService) *SearchService { +// func NewSearchService(db vectordb.Service) *SearchService { // return &SearchService{db: db} // } // @@ -52,10 +52,12 @@ // CollectionName: "documents", // Vector: vector, // TopK: 10, -// Filters: &vectordb.FilterSet{ -// Must: &vectordb.ConditionSet{ -// Conditions: []vectordb.FilterCondition{ -// vectordb.NewMatch("status", "published"), +// Filters: []*vectordb.FilterSet{ +// { +// Must: &vectordb.ConditionSet{ +// Conditions: []vectordb.FilterCondition{ +// vectordb.NewMatch("status", "published"), +// }, // }, // }, // }, @@ -68,7 +70,7 @@ // // # Wire Up with Qdrant // -// In your main or DI setup: +// In your main setup: // // import ( // "github.com/Aleph-Alpha/std/v1/vectordb" @@ -82,7 +84,7 @@ // }) // // // Create adapter for DB-agnostic usage -// db := qdrant.NewVectorDBAdapter(qc.API()) +// db := qdrant.NewAdapter(qc.Client()) // // // Use in application // svc := NewSearchService(db) @@ -92,17 +94,16 @@ // # Package Layout // // vectordb/ -// ├── interface.go # VectorDBService interface +// ├── interface.go # Service interface // ├── types.go # SearchRequest, SearchResult, EmbeddingInput, Collection -// ├── filters.go # FilterSet, FilterCondition, convenience 
constructors +// ├── filters.go # FilterSet, FilterCondition, and condition types +// ├── utils.go # Convenience constructors (New*) and JSON helpers // └── doc.go # This file // // qdrant/ # Qdrant package (includes adapter) // ├── client.go # QdrantClient wrapper -// ├── adapter.go # VectorDBAdapter - implements VectorDBService -// ├── vectordb_converter.go # vectordb types → qdrant types -// ├── operations.go # Direct Qdrant operations -// ├── filters.go # Qdrant-specific filters +// ├── operations.go # Adapter - implements Service +// ├── converter.go # vectordb types → qdrant types // └── ... // // Future adapters would live in their own packages: @@ -132,7 +133,7 @@ // // User-defined field (stored under "custom." prefix) // vectordb.NewUserMatch("category", "research") // -// // Range conditions -// vectordb.NewNumericRange("price", &min, &max) -// vectordb.NewTimeRange("created_at", &startTime, &endTime) +// // Range conditions with NumericRange/TimeRange structs +// vectordb.NewNumericRange("price", vectordb.NumericRange{Gte: &min, Lt: &max}) +// vectordb.NewTimeRange("created_at", vectordb.TimeRange{AtOrAfter: &start, Before: &end}) package vectordb diff --git a/v1/vectordb/filters.go b/v1/vectordb/filters.go index 0e156ad..2e2f9b0 100644 --- a/v1/vectordb/filters.go +++ b/v1/vectordb/filters.go @@ -81,34 +81,126 @@ type MatchExceptCondition struct { func (c *MatchExceptCondition) isFilterCondition() {} +// ── Range Types ────────────────────────────────────────────────────────────── + +// NumericRange defines bounds for numeric filtering. +// Used with NewNumericRange for cleaner constructor calls. 
+type NumericRange struct { + Gt *float64 `json:"greaterThan,omitempty"` // GreaterThan (exclusive) + Gte *float64 `json:"greaterThanOrEqualTo,omitempty"` // GreaterThanOrEqualTo (inclusive) + Lt *float64 `json:"lessThan,omitempty"` // LessThan (exclusive) + Lte *float64 `json:"lessThanOrEqualTo,omitempty"` // LessThanOrEqualTo (inclusive) +} + +// TimeRange defines bounds for time filtering. +// Used with NewTimeRange for cleaner constructor calls. +type TimeRange struct { + Gt *time.Time `json:"after,omitempty"` // After (exclusive) + Gte *time.Time `json:"atOrAfter,omitempty"` // AtOrAfter (inclusive) + Lt *time.Time `json:"before,omitempty"` // Before (exclusive) + Lte *time.Time `json:"atOrBefore,omitempty"` // AtOrBefore (inclusive) +} + // ── Range Conditions ───────────────────────────────────────────────────────── // NumericRangeCondition filters by numeric range. // SQL equivalent: WHERE field >= min AND field <= max type NumericRangeCondition struct { - Field string `json:"field"` - GreaterThan *float64 `json:"greaterThan,omitempty"` - GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` - LessThan *float64 `json:"lessThan,omitempty"` - LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` - FieldType FieldType `json:"-"` + Field string `json:"field"` + Range NumericRange `json:"-"` + FieldType FieldType `json:"-"` } func (c *NumericRangeCondition) isFilterCondition() {} +func (c *NumericRangeCondition) MarshalJSON() ([]byte, error) { + type Alias struct { + Field string `json:"field"` + GreaterThan *float64 `json:"greaterThan,omitempty"` + GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` + LessThan *float64 `json:"lessThan,omitempty"` + LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` + } + return json.Marshal(Alias{ + Field: c.Field, + GreaterThan: c.Range.Gt, + GreaterThanOrEqualTo: c.Range.Gte, + LessThan: c.Range.Lt, + LessThanOrEqualTo: c.Range.Lte, + }) +} + +func (c 
*NumericRangeCondition) UnmarshalJSON(data []byte) error { + type Alias struct { + Field string `json:"field"` + GreaterThan *float64 `json:"greaterThan,omitempty"` + GreaterThanOrEqualTo *float64 `json:"greaterThanOrEqualTo,omitempty"` + LessThan *float64 `json:"lessThan,omitempty"` + LessThanOrEqualTo *float64 `json:"lessThanOrEqualTo,omitempty"` + } + var alias Alias + if err := json.Unmarshal(data, &alias); err != nil { + return err + } + c.Field = alias.Field + c.Range = NumericRange{ + Gt: alias.GreaterThan, + Gte: alias.GreaterThanOrEqualTo, + Lt: alias.LessThan, + Lte: alias.LessThanOrEqualTo, + } + return nil +} + // TimeRangeCondition filters by datetime range. // SQL equivalent: WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' type TimeRangeCondition struct { - Field string `json:"field"` - After *time.Time `json:"after,omitempty"` - AtOrAfter *time.Time `json:"atOrAfter,omitempty"` - Before *time.Time `json:"before,omitempty"` - AtOrBefore *time.Time `json:"atOrBefore,omitempty"` - FieldType FieldType `json:"-"` + Field string `json:"field"` + Range TimeRange `json:"-"` + FieldType FieldType `json:"-"` } func (c *TimeRangeCondition) isFilterCondition() {} +func (c TimeRangeCondition) MarshalJSON() ([]byte, error) { + type Alias struct { + Field string `json:"field"` + After *time.Time `json:"after,omitempty"` + AtOrAfter *time.Time `json:"atOrAfter,omitempty"` + Before *time.Time `json:"before,omitempty"` + AtOrBefore *time.Time `json:"atOrBefore,omitempty"` + } + return json.Marshal(Alias{ + Field: c.Field, + After: c.Range.Gt, + AtOrAfter: c.Range.Gte, + Before: c.Range.Lt, + AtOrBefore: c.Range.Lte, + }) +} + +func (c *TimeRangeCondition) UnmarshalJSON(data []byte) error { + type Alias struct { + Field string `json:"field"` + After *time.Time `json:"after,omitempty"` + AtOrAfter *time.Time `json:"atOrAfter,omitempty"` + Before *time.Time `json:"before,omitempty"` + AtOrBefore *time.Time `json:"atOrBefore,omitempty"` + } + var alias 
Alias + if err := json.Unmarshal(data, &alias); err != nil { + return err + } + c.Field = alias.Field + c.Range = TimeRange{ + Gt: alias.After, + Gte: alias.AtOrAfter, + Lt: alias.Before, + Lte: alias.AtOrBefore, + } + return nil +} + // ── Null/Empty Conditions ──────────────────────────────────────────────────── // IsNullCondition checks if a field has a NULL value. @@ -128,113 +220,3 @@ type IsEmptyCondition struct { } func (c *IsEmptyCondition) isFilterCondition() {} - -// ── Convenience Constructors ───────────────────────────────────────────────── - -// NewMatch creates a match condition for internal fields. -func NewMatch(field string, value any) *MatchCondition { - return &MatchCondition{Field: field, Value: value, FieldType: InternalField} -} - -// NewUserMatch creates a match condition for user-defined fields. -func NewUserMatch(field string, value any) *MatchCondition { - return &MatchCondition{Field: field, Value: value, FieldType: UserField} -} - -// NewMatchAny creates an IN condition for internal fields. -func NewMatchAny(field string, values ...any) *MatchAnyCondition { - return &MatchAnyCondition{Field: field, Values: values, FieldType: InternalField} -} - -// NewUserMatchAny creates an IN condition for user-defined fields. -func NewUserMatchAny(field string, values ...any) *MatchAnyCondition { - return &MatchAnyCondition{Field: field, Values: values, FieldType: UserField} -} - -// NewMatchExcept creates a NOT IN condition for internal fields. -func NewMatchExcept(field string, values ...any) *MatchExceptCondition { - return &MatchExceptCondition{Field: field, Values: values, FieldType: InternalField} -} - -// NewUserMatchExcept creates a NOT IN condition for user-defined fields. -func NewUserMatchExcept(field string, values ...any) *MatchExceptCondition { - return &MatchExceptCondition{Field: field, Values: values, FieldType: UserField} -} - -// NewNumericRange creates a numeric range condition for internal fields. 
-func NewNumericRange(field string, gte, lte *float64) *NumericRangeCondition { - return &NumericRangeCondition{ - Field: field, - GreaterThanOrEqualTo: gte, - LessThanOrEqualTo: lte, - FieldType: InternalField, - } -} - -// NewUserNumericRange creates a numeric range condition for user-defined fields. -func NewUserNumericRange(field string, gte, lte *float64) *NumericRangeCondition { - return &NumericRangeCondition{ - Field: field, - GreaterThanOrEqualTo: gte, - LessThanOrEqualTo: lte, - FieldType: UserField, - } -} - -// NewTimeRange creates a time range condition for internal fields. -func NewTimeRange(field string, atOrAfter, before *time.Time) *TimeRangeCondition { - return &TimeRangeCondition{ - Field: field, - AtOrAfter: atOrAfter, - Before: before, - FieldType: InternalField, - } -} - -// NewUserTimeRange creates a time range condition for user-defined fields. -func NewUserTimeRange(field string, atOrAfter, before *time.Time) *TimeRangeCondition { - return &TimeRangeCondition{ - Field: field, - AtOrAfter: atOrAfter, - Before: before, - FieldType: UserField, - } -} - -// NewIsNull creates an IS NULL condition for internal fields. -func NewIsNull(field string) *IsNullCondition { - return &IsNullCondition{Field: field, FieldType: InternalField} -} - -// NewUserIsNull creates an IS NULL condition for user-defined fields. -func NewUserIsNull(field string) *IsNullCondition { - return &IsNullCondition{Field: field, FieldType: UserField} -} - -// NewIsEmpty creates an IS EMPTY condition for internal fields. -func NewIsEmpty(field string) *IsEmptyCondition { - return &IsEmptyCondition{Field: field, FieldType: InternalField} -} - -// NewUserIsEmpty creates an IS EMPTY condition for user-defined fields. 
-func NewUserIsEmpty(field string) *IsEmptyCondition { - return &IsEmptyCondition{Field: field, FieldType: UserField} -} - -// ── JSON Serialization ─────────────────────────────────────────────────────── - -// MarshalJSON implements custom JSON marshaling for ConditionSet. -// This is needed because FilterCondition is an interface. -func (cs *ConditionSet) MarshalJSON() ([]byte, error) { - return json.Marshal(cs.Conditions) -} - -// UnmarshalJSON implements custom JSON unmarshaling for ConditionSet. -func (cs *ConditionSet) UnmarshalJSON(data []byte) error { - var raw []json.RawMessage - if err := json.Unmarshal(data, &raw); err != nil { - return err - } - cs.Conditions = make([]FilterCondition, 0, len(raw)) - return nil -} diff --git a/v1/vectordb/interface.go b/v1/vectordb/interface.go index 8886b1d..238c4d6 100644 --- a/v1/vectordb/interface.go +++ b/v1/vectordb/interface.go @@ -9,14 +9,14 @@ import "context" // // Example usage: // -// func NewSearchService(db vectordb.VectorDBService) *SearchService { +// func NewSearchService(db vectordb.Service) *SearchService { // return &SearchService{db: db} // } // // // Works with any implementation: // // - vectordb.NewQdrantAdapter(qdrantClient) // // - vectordb.NewWeaviateAdapter(weaviateClient) -type VectorDBService interface { +type Service interface { // Search performs similarity search across one or more requests. // Each request can target a different collection with different filters. // Returns a slice of result slices—one []SearchResult per request. @@ -32,7 +32,7 @@ type VectorDBService interface { // Insert adds embeddings to a collection. // Uses batch processing internally for efficiency. - Insert(ctx context.Context, collection string, inputs []EmbeddingInput) error + Insert(ctx context.Context, collectionName string, inputs []EmbeddingInput) error // Delete removes points by their IDs from a collection. 
Delete(ctx context.Context, collection string, ids []string) error diff --git a/v1/vectordb/types.go b/v1/vectordb/types.go index 9439407..36d5f94 100644 --- a/v1/vectordb/types.go +++ b/v1/vectordb/types.go @@ -13,7 +13,7 @@ type SearchRequest struct { TopK int `json:"maxResults"` // Filters is optional metadata filtering (AND/OR/NOT logic) - Filters *FilterSet `json:"filters,omitempty"` + Filters []*FilterSet `json:"filters,omitempty"` } // SearchResult represents a single search result with its similarity score. diff --git a/v1/vectordb/utils.go b/v1/vectordb/utils.go new file mode 100644 index 0000000..1a55485 --- /dev/null +++ b/v1/vectordb/utils.go @@ -0,0 +1,233 @@ +package vectordb + +import ( + "encoding/json" + "fmt" +) + +// ── FilterSet Constructors ─────────────────────────────────────────────────── + +// NewFilterSet creates a FilterSet with the given clauses. +// Use with Must(), Should(), and MustNot() helpers. +// +// Example: +// +// vectordb.NewFilterSet( +// vectordb.Must(vectordb.NewMatch("status", "published")), +// vectordb.Should(vectordb.NewMatch("tag", "ml"), vectordb.NewMatch("tag", "ai")), +// ) +func NewFilterSet(clauses ...func(*FilterSet)) *FilterSet { + fs := &FilterSet{} + for _, clause := range clauses { + clause(fs) + } + return fs +} + +// Must creates a Must clause (AND logic) with the given conditions. +// All conditions must match for a document to be included. +func Must(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.Must = &ConditionSet{Conditions: conditions} + } +} + +// Should creates a Should clause (OR logic) with the given conditions. +// At least one condition must match for a document to be included. +func Should(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.Should = &ConditionSet{Conditions: conditions} + } +} + +// MustNot creates a MustNot clause (NOT logic) with the given conditions. 
+// Documents matching any of these conditions are excluded. +func MustNot(conditions ...FilterCondition) func(*FilterSet) { + return func(fs *FilterSet) { + fs.MustNot = &ConditionSet{Conditions: conditions} + } +} + +// ── Condition Constructors ─────────────────────────────────────────────────── + +// NewMatch creates a match condition for internal fields. +func NewMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: InternalField} +} + +// NewUserMatch creates a match condition for user-defined fields. +func NewUserMatch(field string, value any) *MatchCondition { + return &MatchCondition{Field: field, Value: value, FieldType: UserField} +} + +// NewMatchAny creates an IN condition for internal fields. +func NewMatchAny(field string, values ...any) *MatchAnyCondition { + return &MatchAnyCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchAny creates an IN condition for user-defined fields. +func NewUserMatchAny(field string, values ...any) *MatchAnyCondition { + return &MatchAnyCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewMatchExcept creates a NOT IN condition for internal fields. +func NewMatchExcept(field string, values ...any) *MatchExceptCondition { + return &MatchExceptCondition{Field: field, Values: values, FieldType: InternalField} +} + +// NewUserMatchExcept creates a NOT IN condition for user-defined fields. +func NewUserMatchExcept(field string, values ...any) *MatchExceptCondition { + return &MatchExceptCondition{Field: field, Values: values, FieldType: UserField} +} + +// NewNumericRange creates a numeric range condition for internal fields. +func NewNumericRange(field string, r NumericRange) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + Range: r, + FieldType: InternalField, + } +} + +// NewUserNumericRange creates a numeric range condition for user-defined fields. 
+func NewUserNumericRange(field string, r NumericRange) *NumericRangeCondition { + return &NumericRangeCondition{ + Field: field, + Range: r, + FieldType: UserField, + } +} + +// NewTimeRange creates a time range condition for internal fields. +func NewTimeRange(field string, t TimeRange) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + Range: t, + FieldType: InternalField, + } +} + +// NewUserTimeRange creates a time range condition for user-defined fields. +func NewUserTimeRange(field string, t TimeRange) *TimeRangeCondition { + return &TimeRangeCondition{ + Field: field, + Range: t, + FieldType: UserField, + } +} + +// NewIsNull creates an IS NULL condition for internal fields. +func NewIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsNull creates an IS NULL condition for user-defined fields. +func NewUserIsNull(field string) *IsNullCondition { + return &IsNullCondition{Field: field, FieldType: UserField} +} + +// NewIsEmpty creates an IS EMPTY condition for internal fields. +func NewIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: InternalField} +} + +// NewUserIsEmpty creates an IS EMPTY condition for user-defined fields. +func NewUserIsEmpty(field string) *IsEmptyCondition { + return &IsEmptyCondition{Field: field, FieldType: UserField} +} + +// ── JSON Serialization ─────────────────────────────────────────────────────── + +// MarshalJSON implements custom JSON marshaling for ConditionSet. +// This is needed because FilterCondition is an interface. +func (cs *ConditionSet) MarshalJSON() ([]byte, error) { + return json.Marshal(cs.Conditions) +} + +// UnmarshalJSON implements custom JSON unmarshaling for ConditionSet. +// It detects the condition type based on JSON keys and deserializes +// into the appropriate concrete type (MatchCondition, NumericRangeCondition, etc.) 
+func (cs *ConditionSet) UnmarshalJSON(data []byte) error { + var raw []json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + cs.Conditions = make([]FilterCondition, 0, len(raw)) + + for _, r := range raw { + cond, err := parseCondition(r) + if err != nil { + return err + } + cs.Conditions = append(cs.Conditions, cond) + } + + return nil +} + +// parseCondition detects and parses a single FilterCondition from JSON. +// It examines the JSON keys to determine the condition type: +// - "equalTo" → MatchCondition +// - "anyOf" → MatchAnyCondition +// - "noneOf" → MatchExceptCondition +// - "greaterThan", "lessThan", etc. → NumericRangeCondition +// - "after", "before", etc. → TimeRangeCondition +// - "isNotNull" → IsNullCondition +// - "isEmpty" → IsEmptyCondition +func parseCondition(data []byte) (FilterCondition, error) { + // Extract field names to determine condition type + var fields map[string]json.RawMessage + if err := json.Unmarshal(data, &fields); err != nil { + return nil, err + } + + switch { + case hasKey(fields, "equalTo"): + var c MatchCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "anyOf"): + var c MatchAnyCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "noneOf"): + var c MatchExceptCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "greaterThan"), hasKey(fields, "greaterThanOrEqualTo"), + hasKey(fields, "lessThan"), hasKey(fields, "lessThanOrEqualTo"): + var c NumericRangeCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil + + case hasKey(fields, "after"), hasKey(fields, "atOrAfter"), + hasKey(fields, "before"), hasKey(fields, "atOrBefore"): + var c TimeRangeCondition + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return 
&c, nil + + default: + return nil, fmt.Errorf("unknown filter condition type: %s", string(data)) + } +} + +// hasKey checks if a JSON object contains a specific key. +// Used by parseCondition to detect filter condition types. +func hasKey(m map[string]json.RawMessage, key string) bool { + _, ok := m[key] + return ok +} From 95d6abc96c031bb5caa19aada5196c99c2432787 Mon Sep 17 00:00:00 2001 From: Shinu Joseph Date: Mon, 15 Dec 2025 16:04:16 +0100 Subject: [PATCH 3/4] chore: minor improvements --- v1/qdrant/doc.go | 2 +- v1/qdrant/operations.go | 4 ++-- v1/vectordb/interface.go | 2 +- v1/vectordb/types.go | 2 +- v1/vectordb/utils.go | 47 ++++++++++++++++++++++++++++++++++++++-- 5 files changed, 50 insertions(+), 7 deletions(-) diff --git a/v1/qdrant/doc.go b/v1/qdrant/doc.go index b371ed8..9dc3348 100644 --- a/v1/qdrant/doc.go +++ b/v1/qdrant/doc.go @@ -60,7 +60,7 @@ // } // // // Create adapter -// adapter := qdrant.NewAdapter(client.API()) +// adapter := qdrant.NewAdapter(client.Client()) // // collectionName := "documents" // diff --git a/v1/qdrant/operations.go b/v1/qdrant/operations.go index 0a22f5e..0378a17 100644 --- a/v1/qdrant/operations.go +++ b/v1/qdrant/operations.go @@ -12,7 +12,7 @@ import ( "golang.org/x/sync/semaphore" ) -// Ensure VectorDBAdapter implements VectorDBService at compile time +// Ensure Adapter implements Service at compile time var _ vectordb.Service = (*Adapter)(nil) // ══════════════════════════════════════════════════════════════════════════════ @@ -30,7 +30,7 @@ type Adapter struct { } // NewAdapter creates a new Qdrant adapter for the vectordb interface. -// Pass the underlying SDK client via QdrantClient.API(). +// Pass the underlying SDK client via QdrantClient.Client(). 
// // Example: // diff --git a/v1/vectordb/interface.go b/v1/vectordb/interface.go index 238c4d6..de0d9db 100644 --- a/v1/vectordb/interface.go +++ b/v1/vectordb/interface.go @@ -2,7 +2,7 @@ package vectordb import "context" -// VectorDBService is the common interface for all vector databases. +// Service is the common interface for all vector databases. // It provides a database-agnostic abstraction for vector similarity search, // allowing applications to switch between different vector databases // (Qdrant, Weaviate, Pinecone, etc.) without changing application code. diff --git a/v1/vectordb/types.go b/v1/vectordb/types.go index 36d5f94..d0ef969 100644 --- a/v1/vectordb/types.go +++ b/v1/vectordb/types.go @@ -1,7 +1,7 @@ package vectordb // SearchRequest represents a single similarity search query. -// Use with VectorDBService.Search() for single or batch queries. +// Use with Service.Search() for single or batch queries. type SearchRequest struct { // CollectionName is the target collection to search in CollectionName string `json:"collectionName"` diff --git a/v1/vectordb/utils.go b/v1/vectordb/utils.go index 1a55485..e1328ba 100644 --- a/v1/vectordb/utils.go +++ b/v1/vectordb/utils.go @@ -62,21 +62,25 @@ func NewUserMatch(field string, value any) *MatchCondition { // NewMatchAny creates an IN condition for internal fields. func NewMatchAny(field string, values ...any) *MatchAnyCondition { + validateHomogeneousTypes(values) return &MatchAnyCondition{Field: field, Values: values, FieldType: InternalField} } // NewUserMatchAny creates an IN condition for user-defined fields. func NewUserMatchAny(field string, values ...any) *MatchAnyCondition { + validateHomogeneousTypes(values) return &MatchAnyCondition{Field: field, Values: values, FieldType: UserField} } // NewMatchExcept creates a NOT IN condition for internal fields. 
func NewMatchExcept(field string, values ...any) *MatchExceptCondition { + validateHomogeneousTypes(values) return &MatchExceptCondition{Field: field, Values: values, FieldType: InternalField} } // NewUserMatchExcept creates a NOT IN condition for user-defined fields. func NewUserMatchExcept(field string, values ...any) *MatchExceptCondition { + validateHomogeneousTypes(values) return &MatchExceptCondition{Field: field, Values: values, FieldType: UserField} } @@ -173,8 +177,6 @@ func (cs *ConditionSet) UnmarshalJSON(data []byte) error { // - "noneOf" → MatchExceptCondition // - "greaterThan", "lessThan", etc. → NumericRangeCondition // - "after", "before", etc. → TimeRangeCondition -// - "isNotNull" → IsNullCondition -// - "isEmpty" → IsEmptyCondition func parseCondition(data []byte) (FilterCondition, error) { // Extract field names to determine condition type var fields map[string]json.RawMessage @@ -231,3 +233,44 @@ func hasKey(m map[string]json.RawMessage, key string) bool { _, ok := m[key] return ok } + +// validateHomogeneousTypes ensures all values are of the same type category. +// Panics if mixed types are detected - this catches programming errors early. 
+// +// TODO: Consider whether panic is appropriate here, or if we should: +// - Return an error instead (for runtime data validation) +// - Add a separate NewMatchAnyChecked() that returns error +// - Keep panic for constructor calls, error for JSON unmarshaling +func validateHomogeneousTypes(values []any) { + if len(values) <= 1 { + return + } + + expectedType := getType(values[0]) + if expectedType == "" { + panic(fmt.Sprintf("vectordb: unsupported value type: %T", values[0])) + } + + // Validate all values match expected type + for i, v := range values[1:] { + actualType := getType(v) + if actualType == "" { + panic(fmt.Sprintf("vectordb: unsupported value type at index %d: %T", i+1, v)) + } + if actualType != expectedType { + panic(fmt.Sprintf("vectordb: mixed types not allowed in MatchAny/MatchExcept: expected %s but got %s at index %d", expectedType, actualType, i+1)) + } + } +} + +func getType(value any) string { + switch value.(type) { + case string: + return "string" + case int, int64, float64: + return "numeric" + case bool: + return "boolean" + } + return "" +} From 840ca0ff05dd248d42d873c514ba72f667c590d9 Mon Sep 17 00:00:00 2001 From: Shinu Joseph Date: Mon, 22 Dec 2025 11:18:02 +0100 Subject: [PATCH 4/4] chore: fix search to return partial results --- docs/v1/qdrant.md | 2 +- v1/qdrant/doc.go | 4 +-- v1/qdrant/operations.go | 37 +++++++++++++++++----------- v1/qdrant/qdrant_integration_test.go | 33 ++++++++++++++----------- v1/vectordb/doc.go | 21 ++++++++-------- v1/vectordb/filters.go | 16 ++++++------ v1/vectordb/interface.go | 25 +++++++++++++------ 7 files changed, 80 insertions(+), 58 deletions(-) diff --git a/docs/v1/qdrant.md b/docs/v1/qdrant.md index 0ece18e..26bb5ba 100644 --- a/docs/v1/qdrant.md +++ b/docs/v1/qdrant.md @@ -19,7 +19,7 @@ Core Features: - Vector similarity search with abstracted SearchResult interface - Type\-safe collection creation and existence checks - Support for payload metadata and optional vector retrieval -- 
Extensible abstraction layer for alternate vector stores \(e.g., Pinecone, Postgres\) +- Extensible abstraction layer for alternate vector stores \(e.g. pgVector\) Basic Usage: diff --git a/v1/qdrant/doc.go b/v1/qdrant/doc.go index 9dc3348..78cf98a 100644 --- a/v1/qdrant/doc.go +++ b/v1/qdrant/doc.go @@ -14,7 +14,7 @@ // - Database-agnostic interface via vectordb.Service // - Type-safe collection creation and existence checks // - Support for payload metadata and optional vector retrieval -// - Extensible abstraction layer for alternate vector stores (e.g., Pinecone, Weaviate) +// - Extensible abstraction layer for alternate vector stores (e.g. pgVector) // // # VectorDB Interface // @@ -38,7 +38,7 @@ // // Create adapter for DB-agnostic usage // var db vectordb.Service = qdrant.NewAdapter(qc.Client()) // -// This allows switching between vector databases (Qdrant, Weaviate, Pinecone) without +// This allows switching between vector databases (Qdrant, pgVector) without // changing application code. // // # Basic Usage diff --git a/v1/qdrant/operations.go b/v1/qdrant/operations.go index 0378a17..7a51989 100644 --- a/v1/qdrant/operations.go +++ b/v1/qdrant/operations.go @@ -5,10 +5,10 @@ import ( "fmt" "log" "slices" + "sync" "github.com/Aleph-Alpha/std/v1/vectordb" qdrant "github.com/qdrant/go-client/qdrant" - "golang.org/x/sync/errgroup" "golang.org/x/sync/semaphore" ) @@ -42,9 +42,9 @@ func NewAdapter(client *qdrant.Client) *Adapter { } // Search performs similarity search across one or more requests. 
-func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, error) { +func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest) ([][]vectordb.SearchResult, []error, error) { if len(requests) == 0 { - return nil, fmt.Errorf("at least one search request is required") + return nil, nil, fmt.Errorf("at least one search request is required") } log.Printf("[Qdrant] Starting search batch with %d requests", len(requests)) @@ -52,41 +52,50 @@ func (a *Adapter) Search(ctx context.Context, requests ...vectordb.SearchRequest // Validate all requests first for i, searchReq := range requests { if err := validateSearchInput(searchReq.CollectionName, searchReq.Vector, searchReq.TopK); err != nil { - return nil, fmt.Errorf("request [%d]: %w", i, err) + return nil, nil, fmt.Errorf("request [%d]: %w", i, err) } } results := make([][]vectordb.SearchResult, len(requests)) + errs := make([]error, len(requests)) - // Create errgroup with context - g, ctx := errgroup.WithContext(ctx) + // Use WaitGroup for partial results + var wg sync.WaitGroup // Create semaphore to limit concurrent searches sem := semaphore.NewWeighted(maxConcurrentSearches) for i, searchReq := range requests { i, searchReq := i, searchReq // Capture loop variables - g.Go(func() error { + wg.Add(1) + go func() { + defer wg.Done() // Acquire semaphore (blocks if at max concurrency) if err := sem.Acquire(ctx, 1); err != nil { - return fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) + errs[i] = fmt.Errorf("request [%d]: failed to acquire semaphore: %w", i, err) + results[i] = []vectordb.SearchResult{} + return } defer sem.Release(1) res, err := searchInternal(ctx, a.client, searchReq) if err != nil { - return fmt.Errorf("request [%d]: search failed: %w", i, err) + errs[i] = fmt.Errorf("request [%d]: search failed: %w", i, err) + results[i] = []vectordb.SearchResult{} + return } results[i] = res log.Printf("[Qdrant] Search 
request [%d] returned %d results", i, len(res)) - return nil - }) + }() } - if err := g.Wait(); err != nil { - return nil, fmt.Errorf("search batch failed: %w", err) + wg.Wait() + + // Check for systemic failure (context cancelled) + if ctx.Err() != nil { + return results, errs, fmt.Errorf("search batch interrupted: %w", ctx.Err()) } - return results, nil + return results, errs, nil } // Insert adds embeddings to a collection using batch processing. diff --git a/v1/qdrant/qdrant_integration_test.go b/v1/qdrant/qdrant_integration_test.go index 4ce996f..69bf199 100644 --- a/v1/qdrant/qdrant_integration_test.go +++ b/v1/qdrant/qdrant_integration_test.go @@ -229,12 +229,13 @@ func TestQdrantWithFXModule(t *testing.T) { // Search for the inserted embedding time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, errs, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 5, }) assert.NoError(t, err) + assert.Nil(t, errs[0]) assert.Greater(t, len(batchResults), 0) results := batchResults[0] assert.Greater(t, len(results), 0) @@ -275,7 +276,7 @@ func TestQdrantWithFXModule(t *testing.T) { // Search and verify time.Sleep(1 * time.Second) // Allow time for indexing - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -393,7 +394,7 @@ func TestVectorDBAdapterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Allow time for indexing // Search with topK = 5 - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 5, @@ -402,7 +403,7 @@ func TestVectorDBAdapterOperations(t *testing.T) { assert.LessOrEqual(t, len(batchResults[0]), 
5) // Search with topK = 10 - batchResults, err = adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err = adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -438,7 +439,7 @@ func TestVectorDBAdapterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Search and verify metadata - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embedding.Vector, TopK: 1, @@ -479,7 +480,7 @@ func TestVectorDBAdapterOperations(t *testing.T) { time.Sleep(2 * time.Second) // Verify some embeddings exist - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -550,12 +551,14 @@ func TestQdrantErrorHandling(t *testing.T) { t.Run("SearchOnNonExistentCollection", func(t *testing.T) { vector := generateRandomVector(1536) - _, err := adapter.Search(ctx, vectordb.SearchRequest{ + _, errs, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: "non_existent_collection", Vector: vector, TopK: 5, }) - assert.Error(t, err) + // With partial results, systemic error is nil but per-request error exists + assert.NoError(t, err) + assert.NotNil(t, errs[0]) }) } @@ -670,7 +673,7 @@ func TestFilterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Test Must filter (color == "red") - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -706,7 +709,7 @@ func TestFilterOperations(t *testing.T) { // Test numeric range filter (size >= 20) gte := float64(20) - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, 
vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -741,7 +744,7 @@ func TestFilterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Exclude archived items - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -777,7 +780,7 @@ func TestFilterOperations(t *testing.T) { // color == red AND size < 20 lt := float64(20) - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -811,7 +814,7 @@ func TestFilterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Match any of ["ml", "nlp"] - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -846,7 +849,7 @@ func TestFilterOperations(t *testing.T) { time.Sleep(1 * time.Second) // Should match red OR blue - batchResults, err := adapter.Search(ctx, vectordb.SearchRequest{ + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{ CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 10, @@ -919,7 +922,7 @@ func TestBatchSearch(t *testing.T) { time.Sleep(1 * time.Second) // Batch search with multiple queries - batchResults, err := adapter.Search(ctx, + batchResults, _, err := adapter.Search(ctx, vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[0].Vector, TopK: 3}, vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[1].Vector, TopK: 3}, vectordb.SearchRequest{CollectionName: collectionName, Vector: embeddings[2].Vector, TopK: 3}, diff --git a/v1/vectordb/doc.go b/v1/vectordb/doc.go index 7432dbb..756c31f 100644 --- 
a/v1/vectordb/doc.go +++ b/v1/vectordb/doc.go @@ -3,28 +3,28 @@ // # Overview // // This package defines a common interface [Service] that can be implemented -// by different vector database adapters (Qdrant, Weaviate, Pinecone, etc.), allowing +// by different vector database adapters (Qdrant, pgVector, etc.), allowing // applications to switch between databases without changing application code. // // # Architecture // // ┌─────────────────────────────────────────────────────────────┐ // │ Application Layer │ -// │ (uses vectordb.Service - no DB-specific imports) │ +// │ (uses vectordb.Service - no DB-specific imports) │ // └──────────────────────────┬──────────────────────────────────┘ // │ // ▼ // ┌─────────────────────────────────────────────────────────────┐ -// │ vectordb.Service │ +// │ vectordb.Service │ // │ (common interface + DB-agnostic types) │ // └──────────────────────────┬──────────────────────────────────┘ // │ -// ┌──────────────────┼──────────────────┐ -// ▼ ▼ ▼ -// ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ -// │ qdrant.Adapter│ │weaviate.Adapter│ │pinecone.Adapter│ -// │ (implements) │ │ (implements) │ │ (implements) │ -// └───────────────┘ └───────────────┘ └───────────────┘ +// ┌────────────┴────────────┐ +// ▼ ▼ +// ┌───────────────┐ ┌────────────────┐ +// │ qdrant.Adapter│ │pgvector.Adapter│ +// │ (implements) │ │ (planned) │ +// └───────────────┘ └────────────────┘ // // # Benefits // @@ -108,8 +108,7 @@ // // Future adapters would live in their own packages: // -// weaviate/ # (future) Weaviate adapter -// pinecone/ # (future) Pinecone adapter +// pgvector/ # (planned) PostgreSQL pgvector adapter // // # Filter Types // diff --git a/v1/vectordb/filters.go b/v1/vectordb/filters.go index 2e2f9b0..9fab712 100644 --- a/v1/vectordb/filters.go +++ b/v1/vectordb/filters.go @@ -20,7 +20,7 @@ const ( // Each database adapter converts these to its native filter format.
type FilterCondition interface { // isFilterCondition is a marker method to ensure type safety - isFilterCondition() + IsFilterCondition() } // FilterSet supports Must (AND), Should (OR), and MustNot (NOT) clauses. @@ -59,7 +59,7 @@ type MatchCondition struct { FieldType FieldType `json:"-"` } -func (c *MatchCondition) isFilterCondition() {} +func (c *MatchCondition) IsFilterCondition() {} // MatchAnyCondition matches if value is one of the given values (IN operator). // SQL equivalent: WHERE field IN (value1, value2, ...) @@ -69,7 +69,7 @@ type MatchAnyCondition struct { FieldType FieldType `json:"-"` } -func (c *MatchAnyCondition) isFilterCondition() {} +func (c *MatchAnyCondition) IsFilterCondition() {} // MatchExceptCondition matches if value is NOT one of the given values (NOT IN). // SQL equivalent: WHERE field NOT IN (value1, value2, ...) @@ -79,7 +79,7 @@ type MatchExceptCondition struct { FieldType FieldType `json:"-"` } -func (c *MatchExceptCondition) isFilterCondition() {} +func (c *MatchExceptCondition) IsFilterCondition() {} // ── Range Types ────────────────────────────────────────────────────────────── @@ -111,7 +111,7 @@ type NumericRangeCondition struct { FieldType FieldType `json:"-"` } -func (c *NumericRangeCondition) isFilterCondition() {} +func (c *NumericRangeCondition) IsFilterCondition() {} func (c *NumericRangeCondition) MarshalJSON() ([]byte, error) { type Alias struct { @@ -160,7 +160,7 @@ type TimeRangeCondition struct { FieldType FieldType `json:"-"` } -func (c *TimeRangeCondition) isFilterCondition() {} +func (c *TimeRangeCondition) IsFilterCondition() {} func (c TimeRangeCondition) MarshalJSON() ([]byte, error) { type Alias struct { @@ -210,7 +210,7 @@ type IsNullCondition struct { FieldType FieldType `json:"-"` } -func (c *IsNullCondition) isFilterCondition() {} +func (c *IsNullCondition) IsFilterCondition() {} // IsEmptyCondition checks if a field is empty (doesn't exist, null, or []). 
// SQL equivalent: WHERE field IS NULL OR field = ” OR field = [] @@ -219,4 +219,4 @@ type IsEmptyCondition struct { FieldType FieldType `json:"-"` } -func (c *IsEmptyCondition) isFilterCondition() {} +func (c *IsEmptyCondition) IsFilterCondition() {} diff --git a/v1/vectordb/interface.go b/v1/vectordb/interface.go index de0d9db..b461d90 100644 --- a/v1/vectordb/interface.go +++ b/v1/vectordb/interface.go @@ -5,7 +5,7 @@ import "context" // Service is the common interface for all vector databases. // It provides a database-agnostic abstraction for vector similarity search, // allowing applications to switch between different vector databases -// (Qdrant, Weaviate, Pinecone, etc.) without changing application code. +// (Qdrant, pgVector, etc.) without changing application code. // // Example usage: // @@ -15,20 +15,31 @@ import "context" // // // Works with any implementation: // // - vectordb.NewQdrantAdapter(qdrantClient) -// // - vectordb.NewWeaviateAdapter(weaviateClient) +// // - vectordb.NewPgVectorAdapter(pgVectorClient) type Service interface { // Search performs similarity search across one or more requests. // Each request can target a different collection with different filters. - // Returns a slice of result slices—one []SearchResult per request. + // Returns: + // - results: slice of result slices—one []SearchResult per request + // - errs: per-request errors (errs[i] corresponds to requests[i]) + // - err: systemic error (context cancelled, etc.) 
// // Example: - // results, err := db.Search(ctx, + // results, errs, err := db.Search(ctx, // SearchRequest{CollectionName: "docs", Vector: vec1, TopK: 10}, // SearchRequest{CollectionName: "docs", Vector: vec2, TopK: 5, Filters: filters}, // ) - // // results[0] = results for first query - // // results[1] = results for second query - Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResult, error) + // if err != nil { + // return err // systemic failure + // } + // for i, res := range results { + // if errs[i] != nil { + // log.Printf("request %d failed: %v", i, errs[i]) + // continue + // } + // // use res... + // } + Search(ctx context.Context, requests ...SearchRequest) ([][]SearchResult, []error, error) // Insert adds embeddings to a collection. // Uses batch processing internally for efficiency.