Skip to content

Commit

Permalink
Array comparison substitution (#676)
Browse files Browse the repository at this point in the history
Refs #457, #522.
  • Loading branch information
ribaraka committed Jun 7, 2022
1 parent f288ca8 commit 88af5f4
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 75 deletions.
12 changes: 12 additions & 0 deletions integration/query_comparison_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ func TestQueryComparisonImplicit(t *testing.T) {
filter: bson.D{{"value", bson.A{bson.A{int32(42), "foo"}, nil}}},
expectedIDs: []any{"array-embedded"},
},
"LongArrayEmbedded": {
filter: bson.D{{"value", bson.A{bson.A{int32(42), "foo"}, nil, "foo"}}},
expectedIDs: []any{},
},
"ArraySlice": {
filter: bson.D{{"value", bson.A{int32(42), "foo"}}},
expectedIDs: []any{"array-embedded"},
Expand Down Expand Up @@ -256,6 +260,10 @@ func TestQueryComparisonEq(t *testing.T) {
filter: bson.D{{"value", bson.D{{"$eq", bson.A{bson.A{int32(42), "foo"}, nil}}}}},
expectedIDs: []any{"array-embedded"},
},
"LongArrayEmbedded": {
filter: bson.D{{"value", bson.D{{"$eq", bson.A{bson.A{int32(42), "foo"}, nil, "foo"}}}}},
expectedIDs: []any{},
},
"ArraySlice": {
filter: bson.D{{"value", bson.D{{"$eq", bson.A{int32(42), "foo"}}}}},
expectedIDs: []any{"array-embedded"},
Expand Down Expand Up @@ -1393,6 +1401,10 @@ func TestQueryComparisonNe(t *testing.T) {
value: bson.A{bson.A{int32(42), "foo"}, nil},
unexpectedID: "array-embedded",
},
"LongArrayEmbedded": {
value: bson.A{bson.A{int32(42), "foo"}, nil, "foo"},
unexpectedID: "",
},
"ArrayShuffledValues": {
value: bson.A{"foo", nil, int32(42)},
unexpectedID: "",
Expand Down
28 changes: 1 addition & 27 deletions internal/handlers/common/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any)
if err != nil {
return false, nil // no error - the field is just not present
}
if docValue, ok := docValue.(*types.Array); ok {
return matchArrays(filterValue, docValue), nil
}
return false, nil
return types.Compare(docValue, filterValue) == types.Equal, nil

case types.Regex:
// {field: /regex/}
Expand Down Expand Up @@ -246,11 +243,6 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document
return matchDocuments(exprValue, fieldValue), nil
}
return false, nil
case *types.Array:
if fieldValue, ok := fieldValue.(*types.Array); ok {
return matchArrays(exprValue, fieldValue), nil
}
return false, nil
default:
if types.Compare(fieldValue, exprValue) != types.Equal {
return false, nil
Expand All @@ -265,16 +257,8 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document
return !matchDocuments(exprValue, fieldValue), nil
}
return false, nil

case *types.Array:
if fieldValue, ok := fieldValue.(*types.Array); ok {
return !matchArrays(exprValue, fieldValue), nil
}
return false, nil

case types.Regex:
return false, NewErrorMsg(ErrBadValue, "Can't have regex as arg to $ne.")

default:
if types.Compare(fieldValue, exprValue) == types.Equal {
return false, nil
Expand Down Expand Up @@ -335,11 +319,6 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document
}

switch arrValue := must.NotFail(arr.Get(i)).(type) {
case *types.Array:
fieldValue, ok := fieldValue.(*types.Array)
if ok && matchArrays(arrValue, fieldValue) {
found = true
}
case *types.Document:
for _, key := range arrValue.Keys() {
if strings.HasPrefix(key, "$") {
Expand Down Expand Up @@ -383,11 +362,6 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document
}

switch arrValue := must.NotFail(arr.Get(i)).(type) {
case *types.Array:
fieldValue, ok := fieldValue.(*types.Array)
if ok && matchArrays(arrValue, fieldValue) {
found = true
}
case *types.Document:
for _, key := range arrValue.Keys() {
if strings.HasPrefix(key, "$") {
Expand Down
30 changes: 0 additions & 30 deletions internal/handlers/common/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ import (

"golang.org/x/exp/slices"

"github.com/FerretDB/FerretDB/internal/fjson"
"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/must"
)

// matchDocuments returns true if 2 documents are equal.
Expand All @@ -42,31 +40,3 @@ func matchDocuments(a, b *types.Document) bool {
}
return reflect.DeepEqual(a.Map(), b.Map())
}

// matchArrays returns true if a filter array equals exactly the specified array or
// array contains an element that equals the array.
//
// TODO move into types.Compare.
func matchArrays(filterArr, docArr *types.Array) bool {
if filterArr == nil {
log.Panicf("%v is nil", filterArr)
}
if docArr == nil {
log.Panicf("%v is nil", docArr)
}

if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(docArr))) {
return true
}

for i := 0; i < docArr.Len(); i++ {
arrValue := must.NotFail(docArr.Get(i))
if arrValue, ok := arrValue.(*types.Array); ok {
if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(arrValue))) {
return true
}
}
}

return false
}
133 changes: 115 additions & 18 deletions internal/types/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package types

import (
"bytes"
"fmt"
"math"
"math/big"
"time"
Expand Down Expand Up @@ -46,42 +45,47 @@ const (
// For that reason, it typically should not be used in tests.
//
// Compare and contrast with test helpers in testutil package.
func Compare(v1, v2 any) CompareResult {
if v1 == nil {
panic("compare: v1 is nil")
func Compare(docValue, filterValue any) CompareResult {
if docValue == nil {
panic("compare: docValue is nil")
}
if v2 == nil {
panic("compare: v2 is nil")
if filterValue == nil {
panic("compare: filterValue is nil")
}

switch v1 := v1.(type) {
switch docValue := docValue.(type) {
case *Document:
// TODO: implement document comparing
return Incomparable

case *Array:
for i := 0; i < v1.Len(); i++ {
v := must.NotFail(v1.Get(i))
switch v.(type) {
if filterArr, ok := filterValue.(*Array); ok {
return compareArrays(filterArr, docValue)
}

for i := 0; i < docValue.Len(); i++ {
docValue := must.NotFail(docValue.Get(i))
switch docValue.(type) {
case *Document, *Array:
continue
}

if res := compareScalars(v, v2); res != Incomparable {
if res := compareScalars(docValue, filterValue); res != Incomparable {
return res
}
}
return Incomparable

default:
return compareScalars(v1, v2)
return compareScalars(docValue, filterValue)
}
}

// compareScalars compares BSON scalar values.
func compareScalars(v1, v2 any) CompareResult {
compareEnsureScalar(v1)
compareEnsureScalar(v2)
if !isScalar(v1) || !isScalar(v2) {
return Incomparable
}

switch v1 := v1.(type) {
case float64:
Expand Down Expand Up @@ -198,18 +202,18 @@ func compareScalars(v1, v2 any) CompareResult {
panic("not reached")
}

// compareEnsureScalar panics if v is not a BSON scalar value.
func compareEnsureScalar(v any) {
// isScalar check if v is a BSON scalar value.
func isScalar(v any) bool {
if v == nil {
panic("v is nil")
}

switch v.(type) {
case float64, string, Binary, ObjectID, bool, time.Time, NullType, Regex, int32, Timestamp, int64:
return
return true
}

panic(fmt.Sprintf("non-scalar type %T", v))
return false
}

// compareInvert swaps Less and Greater, keeping Equal and Incomparable.
Expand Down Expand Up @@ -254,3 +258,96 @@ func compareNumbers(a float64, b int64) CompareResult {

return CompareResult(bigA.Cmp(bigB))
}

// compareArrays compares indices of a filter array according to indices of a document array;
// returns Equal when a document array contains another array(subarray) that equals filter array.
func compareArrays(filterArr, docArr *Array) CompareResult {
if docArr.Len() == 0 && filterArr.Len() == 0 {
return Equal
}
if filterArr.Len() == 0 {
return Incomparable
}

entireArrayResult, subArrayEquality := Incomparable, Incomparable

for i := 0; i < docArr.Len(); i++ {
arrValue := must.NotFail(docArr.Get(i))
switch arrValue := arrValue.(type) {
case *Array:
filterArrValue := must.NotFail(filterArr.Get(i))
switch filterArrValue := filterArrValue.(type) {
case *Array:
res := compareArrays(filterArrValue, arrValue)
res = handleSubArrayComparingResult(&res, &entireArrayResult, &subArrayEquality)
continue

default:
res := compareArrays(filterArr, arrValue)
res = handleSubArrayComparingResult(&res, &entireArrayResult, &subArrayEquality)
}
continue

// TODO: case Document
// case *Document

default:
if i+1 > filterArr.Len() {
if entireArrayResult == Equal {
entireArrayResult = Incomparable
}
continue // looking for next element is array that might fit filter query
}

filterValue := must.NotFail(filterArr.Get(i))
switch filterValue := filterValue.(type) {
case *Array, *Document:
if entireArrayResult == Equal {
entireArrayResult = Incomparable
}
continue
default:
res := CompareOrder(arrValue, filterValue, Ascending)
if entireArrayResult == Incomparable && i == 0 { // set first non-Incomparable result
entireArrayResult = res
}

if entireArrayResult != res {
entireArrayResult = Incomparable
continue
}

// both arrays are not equal if there are still elements in filter array
if filterArr.Len() > i+1 &&
docArr.Len() == i+1 &&
entireArrayResult == Equal {
entireArrayResult = Incomparable
}
}
}
}

if subArrayEquality == Equal && !(filterArr.Len() > docArr.Len()) {
return subArrayEquality
}

return entireArrayResult
}

// handleSubArrayComparingResult determines on the first iteration what result the comparison will follow (e.g. gt, ls, eq)
// detects inconsistency in iterations; detects equality for subbaray.
func handleSubArrayComparingResult(resultFromComparing, entireArrayResult, subArrayEquality *CompareResult) CompareResult {
if *resultFromComparing == Incomparable {
return Incomparable
}

if *resultFromComparing == Equal {
*subArrayEquality = Equal
}

if entireArrayResult != resultFromComparing {
return Incomparable
}

return *resultFromComparing
}

0 comments on commit 88af5f4

Please sign in to comment.