From ad8acc08a70f64ff4e877e594653a9146b738b0e Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Tue, 26 Apr 2022 10:17:44 +0300 Subject: [PATCH] Move existing comparision code to `types` (#531) With minimal code changes. Refs #481. --- internal/handlers/common/compare.go | 317 -------------------- internal/handlers/common/error.go | 20 +- internal/handlers/common/filter.go | 25 +- internal/handlers/common/match.go | 68 +++++ internal/handlers/common/projection.go | 73 ++--- internal/handlers/common/sort.go | 16 +- internal/handlers/common/sorttype_string.go | 29 ++ internal/types/compare.go | 247 +++++++++++++++ internal/types/compareresult_string.go | 36 +++ 9 files changed, 452 insertions(+), 379 deletions(-) delete mode 100644 internal/handlers/common/compare.go create mode 100644 internal/handlers/common/match.go create mode 100644 internal/handlers/common/sorttype_string.go create mode 100644 internal/types/compare.go create mode 100644 internal/types/compareresult_string.go diff --git a/internal/handlers/common/compare.go b/internal/handlers/common/compare.go deleted file mode 100644 index 1462d3318157..000000000000 --- a/internal/handlers/common/compare.go +++ /dev/null @@ -1,317 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package common - -import ( - "bytes" - "fmt" - "log" - "math" - "math/big" - "reflect" - "time" - - "golang.org/x/exp/constraints" - "golang.org/x/exp/slices" - - "github.com/FerretDB/FerretDB/internal/fjson" - "github.com/FerretDB/FerretDB/internal/types" - "github.com/FerretDB/FerretDB/internal/util/must" -) - -// compareResult represents the result of a comparison. -type compareResult int - -const ( - equal compareResult = iota - less - greater - notEqual // but not less or greater; for example, two NaNs -) - -// compareScalars compares two scalar values. -func compareScalars(a, b any) compareResult { - if a == nil { - panic("a is nil") - } - if b == nil { - panic("b is nil") - } - - switch a := a.(type) { - case float64: - switch b := b.(type) { - case float64: - if math.IsNaN(a) && math.IsNaN(b) { - return equal - } - return compareOrdered(a, b) - case int32: - return compareNumbers(a, int64(b)) - case int64: - return compareNumbers(a, b) - default: - return notEqual - } - - case string: - b, ok := b.(string) - if ok { - return compareOrdered(a, b) - } - return notEqual - - case types.Binary: - b, ok := b.(types.Binary) - if !ok { - return notEqual - } - al, bl := len(a.B), len(b.B) - if al != bl { - return compareOrdered(al, bl) - } - if a.Subtype != b.Subtype { - return compareOrdered(a.Subtype, b.Subtype) - } - switch bytes.Compare(a.B, b.B) { - case 0: - return equal - case -1: - return less - case 1: - return greater - default: - panic("unreachable") - } - - case types.ObjectID: - b, ok := b.(types.ObjectID) - if !ok { - return notEqual - } - switch bytes.Compare(a[:], b[:]) { - case 0: - return equal - case -1: - return less - case 1: - return greater - default: - panic("unreachable") - } - - case bool: - b, ok := b.(bool) - if !ok { - return notEqual - } - if a == b { - return equal - } - if b { - return less - } - return greater - - case time.Time: - b, ok := b.(time.Time) - if !ok { - return notEqual - } - return compareOrdered(a.UnixMilli(), b.UnixMilli()) - - case types.NullType: - _, ok := b.(types.NullType) - if ok { - return equal - } - return notEqual - - case types.Regex: - b, ok := b.(types.Regex) - if ok && a == b { - return equal - } - return notEqual - - case int32: - switch b := b.(type) { - case float64: - return filterCompareInvert(compareNumbers(b, int64(a))) - case int32: - return compareOrdered(a, b) - case int64: - return compareOrdered(int64(a), b) - default: - return notEqual - } - - case types.Timestamp: - b, ok := b.(types.Timestamp) - if ok { - return compareOrdered(a, b) - } - return notEqual - - case int64: - switch b := b.(type) { - case float64: - return filterCompareInvert(compareNumbers(b, a)) - case int32: - return compareOrdered(a, int64(b)) - case int64: - return compareOrdered(a, b) - default: - return notEqual - } - - default: - panic(fmt.Sprintf("unhandled type %T", a)) - } -} - -// compare compares the filter to the value of the document, whether it is a composite type or a scalar type. -func compare(docValue, filter any) compareResult { - if docValue == nil { - panic("docValue is nil") - } - if filter == nil { - panic("filter is nil") - } - - switch docValue := docValue.(type) { - case *types.Document: - // TODO: implement document comparing - return notEqual - - case *types.Array: - for i := 0; i < docValue.Len(); i++ { - arrValue := must.NotFail(docValue.Get(i)) - switch arrValue.(type) { - case *types.Document, *types.Array: - continue - } - - switch compareScalars(arrValue, filter) { - case equal: - return equal - case greater: - return greater - case less: - return less - case notEqual: - continue - } - } - return notEqual - - default: - return compareScalars(docValue, filter) - } -} - -// filterCompareInvert swaps less and greater, keeping equal and notEqual. -func filterCompareInvert(res compareResult) compareResult { - switch res { - case equal: - return equal - case less: - return greater - case greater: - return less - case notEqual: - return notEqual - default: - panic("unreachable") - } -} - -// compareOrdered compares two values of the same type using ==, <, > operators. -func compareOrdered[T constraints.Ordered](a, b T) compareResult { - if a == b { - return equal - } - if a < b { - return less - } - if a > b { - return greater - } - return notEqual -} - -// compareNumbers compares two numbers. -func compareNumbers(a float64, b int64) compareResult { - if math.IsNaN(a) { - return notEqual - } - - // TODO figure out correct precision - bigFloat := new(big.Float).SetFloat64(a).SetPrec(100000) - bigFloatFromInt := new(big.Float).SetInt64(b).SetPrec(100000) - - switch bigFloat.Cmp(bigFloatFromInt) { - case -1: - return less - case 0: - return equal - case 1: - return greater - default: - panic("not reached") - } -} - -// matchDocuments returns true if 2 documents are equal. -func matchDocuments(a, b *types.Document) bool { - if a == nil { - log.Panicf("%v is nil", a) - } - if b == nil { - log.Panicf("%v is nil", b) - } - - keys := a.Keys() - if !slices.Equal(keys, b.Keys()) { - return false - } - return reflect.DeepEqual(a.Map(), b.Map()) -} - -// matchArrays returns true if a filter array equals exactly the specified array or -// array contains an element that equals the array. -func matchArrays(filterArr, docArr *types.Array) bool { - if filterArr == nil { - log.Panicf("%v is nil", filterArr) - } - if docArr == nil { - log.Panicf("%v is nil", docArr) - } - - if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(docArr))) { - return true - } - - for i := 0; i < docArr.Len(); i++ { - arrValue := must.NotFail(docArr.Get(i)) - if arrValue, ok := arrValue.(*types.Array); ok { - if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(arrValue))) { - return true - } - } - } - - return false -} diff --git a/internal/handlers/common/error.go b/internal/handlers/common/error.go index 4c463c8c8dc5..d67b7b60b60f 100644 --- a/internal/handlers/common/error.go +++ b/internal/handlers/common/error.go @@ -122,14 +122,26 @@ func ProtocolError(err error) (*Error, bool) { func formatBitwiseOperatorErr(err error, operator string, maskValue any) error { switch err { case errNotWholeNumber: - return NewErrorMsg(ErrFailedToParse, fmt.Sprintf("Expected an integer: %s: %#v", operator, maskValue)) + return NewErrorMsg( + ErrFailedToParse, + fmt.Sprintf("Expected an integer: %s: %#v", operator, maskValue), + ) case errNegativeNumber: if _, ok := maskValue.(float64); ok { - return NewErrorMsg(ErrFailedToParse, fmt.Sprintf(`Expected a positive number in: %s: %.1f`, operator, maskValue)) + return NewErrorMsg( + ErrFailedToParse, + fmt.Sprintf(`Expected a positive number in: %s: %.1f`, operator, maskValue), + ) } - return NewErrorMsg(ErrFailedToParse, fmt.Sprintf(`Expected a positive number in: %s: %v`, operator, maskValue)) + return NewErrorMsg( + ErrFailedToParse, + fmt.Sprintf(`Expected a positive number in: %s: %v`, operator, maskValue), + ) case errNotBinaryMask: - return NewErrorMsg(ErrBadValue, fmt.Sprintf(`value takes an Array, a number, or a BinData but received: %s: %#v`, operator, maskValue)) + return NewErrorMsg( + ErrBadValue, + fmt.Sprintf(`value takes an Array, a number, or a BinData but received: %s: %#v`, operator, maskValue), + ) default: return err } diff --git a/internal/handlers/common/filter.go b/internal/handlers/common/filter.go index fee3f0b5f18b..11166e4af791 100644 --- a/internal/handlers/common/filter.go +++ b/internal/handlers/common/filter.go @@ -105,12 +105,7 @@ func filterDocumentPair(doc *types.Document, filterKey string, filterValue any) return false, nil // no error - the field is just not present } - switch docValue := docValue.(type) { - case *types.Document, *types.Array: - return compare(docValue, filterValue) == equal, nil - } - - return compareScalars(docValue, filterValue) == equal, nil + return types.Compare(docValue, filterValue) == types.Equal, nil } } @@ -242,7 +237,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document } return false, nil default: - if compare(fieldValue, exprValue) != equal { + if types.Compare(fieldValue, exprValue) != types.Equal { return false, nil } } @@ -266,7 +261,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document return false, NewErrorMsg(ErrBadValue, "Can't have regex as arg to $ne.") default: - if compare(fieldValue, exprValue) == equal { + if types.Compare(fieldValue, exprValue) == types.Equal { return false, nil } } @@ -277,7 +272,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document msg := fmt.Sprintf(`Can't have RegEx as arg to predicate over field '%s'.`, filterKey) return false, NewErrorMsg(ErrBadValue, msg) } - if compare(fieldValue, exprValue) != greater { + if types.Compare(fieldValue, exprValue) != types.Greater { return false, nil } @@ -287,7 +282,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document msg := fmt.Sprintf(`Can't have RegEx as arg to predicate over field '%s'.`, filterKey) return false, NewErrorMsg(ErrBadValue, msg) } - if c := compare(fieldValue, exprValue); c != greater && c != equal { + if c := types.Compare(fieldValue, exprValue); c != types.Greater && c != types.Equal { return false, nil } @@ -297,7 +292,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document msg := fmt.Sprintf(`Can't have RegEx as arg to predicate over field '%s'.`, filterKey) return false, NewErrorMsg(ErrBadValue, msg) } - if c := compare(fieldValue, exprValue); c != less { + if c := types.Compare(fieldValue, exprValue); c != types.Less { return false, nil } @@ -307,7 +302,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document msg := fmt.Sprintf(`Can't have RegEx as arg to predicate over field '%s'.`, filterKey) return false, NewErrorMsg(ErrBadValue, msg) } - if c := compare(fieldValue, exprValue); c != less && c != equal { + if c := types.Compare(fieldValue, exprValue); c != types.Less && c != types.Equal { return false, nil } @@ -350,7 +345,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document continue default: - if compare(fieldValue, arrValue) == equal { + if types.Compare(fieldValue, arrValue) == types.Equal { found = true break } @@ -399,7 +394,7 @@ func filterFieldExpr(doc *types.Document, filterKey string, expr *types.Document continue default: - if compare(fieldValue, arrValue) == equal { + if types.Compare(fieldValue, arrValue) == types.Equal { found = true break } @@ -524,7 +519,7 @@ func filterFieldRegex(fieldValue any, regex types.Regex) (bool, error) { } case types.Regex: - return compareScalars(fieldValue, regex) == equal, nil + return types.Compare(fieldValue, regex) == types.Equal, nil } return false, nil diff --git a/internal/handlers/common/match.go b/internal/handlers/common/match.go new file mode 100644 index 000000000000..97eec4405288 --- /dev/null +++ b/internal/handlers/common/match.go @@ -0,0 +1,68 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "log" + "reflect" + + "golang.org/x/exp/slices" + + "github.com/FerretDB/FerretDB/internal/fjson" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" +) + +// matchDocuments returns true if 2 documents are equal. +func matchDocuments(a, b *types.Document) bool { + if a == nil { + log.Panicf("%v is nil", a) + } + if b == nil { + log.Panicf("%v is nil", b) + } + + keys := a.Keys() + if !slices.Equal(keys, b.Keys()) { + return false + } + return reflect.DeepEqual(a.Map(), b.Map()) +} + +// matchArrays returns true if a filter array equals exactly the specified array or +// array contains an element that equals the array. +func matchArrays(filterArr, docArr *types.Array) bool { + if filterArr == nil { + log.Panicf("%v is nil", filterArr) + } + if docArr == nil { + log.Panicf("%v is nil", docArr) + } + + if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(docArr))) { + return true + } + + for i := 0; i < docArr.Len(); i++ { + arrValue := must.NotFail(docArr.Get(i)) + if arrValue, ok := arrValue.(*types.Array); ok { + if string(must.NotFail(fjson.Marshal(filterArr))) == string(must.NotFail(fjson.Marshal(arrValue))) { + return true + } + } + } + + return false +} diff --git a/internal/handlers/common/projection.go b/internal/handlers/common/projection.go index 77d43b9d5fed..f7936cf4124c 100644 --- a/internal/handlers/common/projection.go +++ b/internal/handlers/common/projection.go @@ -35,27 +35,24 @@ func isProjectionInclusion(projection *types.Document) (inclusion bool, err erro } v := must.NotFail(projection.Get(k)) switch v := v.(type) { - case bool: - if v { - if exclusion { - err = NewError(ErrProjectionInEx, - fmt.Errorf("Cannot do inclusion on field %s in exclusion projection", k), - ) + case *types.Document: + for _, projectionType := range v.Keys() { + supportedProjectionTypes := []string{"$elemMatch"} + if !slices.Contains(supportedProjectionTypes, projectionType) { + err = lazyerrors.Errorf("projecion of %s is not supported", projectionType) return } - inclusion = true - } else { - if inclusion { - err = NewError(ErrProjectionExIn, - fmt.Errorf("Cannot do exclusion on field %s in inclusion projection", k), - ) - return + + switch projectionType { + case "$elemMatch": + inclusion = true + default: + panic(projectionType + " not supported") } - exclusion = true } - case int32, int64, float64: - if compareScalars(v, int32(0)) == equal { + case float64, int32, int64: + if types.Compare(v, int32(0)) == types.Equal { if inclusion { err = NewError(ErrProjectionExIn, fmt.Errorf("Cannot do exclusion on field %s in inclusion projection", k), @@ -73,21 +70,25 @@ func isProjectionInclusion(projection *types.Document) (inclusion bool, err erro inclusion = true } - case *types.Document: - for _, projectionType := range v.Keys() { - supportedProjectionTypes := []string{"$elemMatch"} - if !slices.Contains(supportedProjectionTypes, projectionType) { - err = lazyerrors.Errorf("projecion of %s is not supported", projectionType) + case bool: + if v { + if exclusion { + err = NewError(ErrProjectionInEx, + fmt.Errorf("Cannot do inclusion on field %s in exclusion projection", k), + ) return } - - switch projectionType { - case "$elemMatch": - inclusion = true - default: - panic(projectionType + " not supported") + inclusion = true + } else { + if inclusion { + err = NewError(ErrProjectionExIn, + fmt.Errorf("Cannot do exclusion on field %s in inclusion projection", k), + ) + return } + exclusion = true } + default: err = lazyerrors.Errorf("unsupported operation %s %v (%T)", k, v, v) return @@ -133,19 +134,19 @@ func projectDocument(inclusion bool, doc *types.Document, projection *types.Docu } switch projectionVal := projectionVal.(type) { // found in the projection - case bool: // field: bool - if !projectionVal { - doc.Remove(k1) + case *types.Document: // field: { $elemMatch: { field2: value }} + if err := applyComplexProjection(k1, doc, projectionVal); err != nil { + return err } - case int32, int64, float64: // field: number - if compareScalars(projectionVal, int32(0)) == equal { + case float64, int32, int64: // field: number + if types.Compare(projectionVal, int32(0)) == types.Equal { doc.Remove(k1) } - case *types.Document: // field: { $elemMatch: { field2: value }} - if err := applyComplexProjection(k1, doc, projectionVal); err != nil { - return err + case bool: // field: bool + if !projectionVal { + doc.Remove(k1) } default: @@ -216,7 +217,7 @@ func filterFieldArrayElemMatch(k1 string, doc, conditions *types.Document, docVa doc.RemoveByPath(k1, strconv.Itoa(j)) continue } - if compareScalars(docVal, elemMatchFieldCondition) == equal { + if types.Compare(docVal, elemMatchFieldCondition) == types.Equal { // elemMatch to return first matching, all others are to be removed found = j break diff --git a/internal/handlers/common/sort.go b/internal/handlers/common/sort.go index aaa4f8cd2870..eb653e6fd98c 100644 --- a/internal/handlers/common/sort.go +++ b/internal/handlers/common/sort.go @@ -23,12 +23,14 @@ import ( "github.com/FerretDB/FerretDB/internal/util/must" ) +//go:generate ../../../bin/stringer -linecomment -type sortType + // sortType represents sort type for $sort aggregation. -type sortType int +type sortType int8 const ( - ascending sortType = iota - descending + ascending sortType = 1 // asc + descending sortType = -1 // desc ) // SortDocuments sorts given documents in place according to the given sorting conditions. @@ -73,24 +75,24 @@ func lessFunc(sortKey string, sortType sortType) func(a, b *types.Document) bool return false } - result := compare(aField, bField) + result := types.Compare(aField, bField) switch result { - case less: + case types.Less: switch sortType { case ascending: return true case descending: return false } - case greater: + case types.Greater: switch sortType { case ascending: return false case descending: return true } - case notEqual, equal: + case types.NotEqual, types.Equal: return false } diff --git a/internal/handlers/common/sorttype_string.go b/internal/handlers/common/sorttype_string.go new file mode 100644 index 000000000000..e613d146e8e5 --- /dev/null +++ b/internal/handlers/common/sorttype_string.go @@ -0,0 +1,29 @@ +// Code generated by "stringer -linecomment -type sortType"; DO NOT EDIT. + +package common + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[ascending-1] + _ = x[descending - -1] +} + +const ( + _sortType_name_0 = "desc" + _sortType_name_1 = "asc" +) + +func (i sortType) String() string { + switch { + case i == -1: + return _sortType_name_0 + case i == 1: + return _sortType_name_1 + default: + return "sortType(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/internal/types/compare.go b/internal/types/compare.go new file mode 100644 index 000000000000..49b01918756e --- /dev/null +++ b/internal/types/compare.go @@ -0,0 +1,247 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "bytes" + "fmt" + "math" + "math/big" + "time" + + "golang.org/x/exp/constraints" + + "github.com/FerretDB/FerretDB/internal/util/must" +) + +//go:generate ../../bin/stringer -linecomment -type CompareResult + +// CompareResult represents the result of a comparison. +type CompareResult int8 + +const ( + Equal CompareResult = 0 // == + Less CompareResult = -1 // < + Greater CompareResult = 1 // > + NotEqual CompareResult = 127 // != +) + +// Compare compares BSON values. +func Compare(a, b any) CompareResult { + if a == nil { + panic("a is nil") + } + if b == nil { + panic("b is nil") + } + + switch a := a.(type) { + case *Document: + // TODO: implement document comparing + return NotEqual + + case *Array: + for i := 0; i < a.Len(); i++ { + v := must.NotFail(a.Get(i)) + switch v.(type) { + case *Document, *Array: + continue + } + + if res := compareScalars(v, b); res != NotEqual { + return res + } + } + return NotEqual + + default: + return compareScalars(a, b) + } +} + +// compareScalars compares BSON scalar values. +func compareScalars(a, b any) CompareResult { + compareEnsureScalar(a) + compareEnsureScalar(b) + + switch a := a.(type) { + case float64: + switch b := b.(type) { + case float64: + if math.IsNaN(a) && math.IsNaN(b) { + return Equal + } + return compareOrdered(a, b) + case int32: + return compareNumbers(a, int64(b)) + case int64: + return compareNumbers(a, b) + default: + return NotEqual + } + + case string: + b, ok := b.(string) + if ok { + return compareOrdered(a, b) + } + return NotEqual + + case Binary: + b, ok := b.(Binary) + if !ok { + return NotEqual + } + al, bl := len(a.B), len(b.B) + if al != bl { + return compareOrdered(al, bl) + } + if a.Subtype != b.Subtype { + return compareOrdered(a.Subtype, b.Subtype) + } + return CompareResult(bytes.Compare(a.B, b.B)) + + case ObjectID: + b, ok := b.(ObjectID) + if !ok { + return NotEqual + } + return CompareResult(bytes.Compare(a[:], b[:])) + + case bool: + b, ok := b.(bool) + if !ok { + return NotEqual + } + if a == b { + return Equal + } + if b { + return Less + } + return Greater + + case time.Time: + b, ok := b.(time.Time) + if !ok { + return NotEqual + } + return compareOrdered(a.UnixMilli(), b.UnixMilli()) + + case NullType: + _, ok := b.(NullType) + if ok { + return Equal + } + return NotEqual + + case Regex: + b, ok := b.(Regex) + if ok && a == b { + return Equal + } + return NotEqual + + case int32: + switch b := b.(type) { + case float64: + return compareInvert(compareNumbers(b, int64(a))) + case int32: + return compareOrdered(a, b) + case int64: + return compareOrdered(int64(a), b) + default: + return NotEqual + } + + case Timestamp: + b, ok := b.(Timestamp) + if ok { + return compareOrdered(a, b) + } + return NotEqual + + case int64: + switch b := b.(type) { + case float64: + return compareInvert(compareNumbers(b, a)) + case int32: + return compareOrdered(a, int64(b)) + case int64: + return compareOrdered(a, b) + default: + return NotEqual + } + } + + panic("not reached") +} + +// compareEnsureScalar panics if v is not a BSON scalar value. +func compareEnsureScalar(v any) { + if v == nil { + panic("v is nil") + } + + switch v.(type) { + case float64, string, Binary, ObjectID, bool, time.Time, NullType, Regex, int32, Timestamp, int64: + return + } + + panic(fmt.Sprintf("unhandled type %T", v)) +} + +// compareInvert swaps Less and Greater, keeping Equal and NotEqual. +func compareInvert(res CompareResult) CompareResult { + switch res { + case Equal: + return Equal + case Less: + return Greater + case Greater: + return Less + case NotEqual: + return NotEqual + } + + panic("not reached") +} + +// compareOrdered compares BSON values of the same type using ==, <, > operators. +func compareOrdered[T constraints.Ordered](a, b T) CompareResult { + switch { + case a == b: + return Equal + case a < b: + return Less + case a > b: + return Greater + default: + return NotEqual + } +} + +// compareNumbers compares BSON numbers. +func compareNumbers(a float64, b int64) CompareResult { + if math.IsNaN(a) { + return NotEqual + } + + // TODO figure out correct precision + bigA := new(big.Float).SetFloat64(a).SetPrec(100000) + bigB := new(big.Float).SetInt64(b).SetPrec(100000) + + return CompareResult(bigA.Cmp(bigB)) +} diff --git a/internal/types/compareresult_string.go b/internal/types/compareresult_string.go new file mode 100644 index 000000000000..6df736a34de6 --- /dev/null +++ b/internal/types/compareresult_string.go @@ -0,0 +1,36 @@ +// Code generated by "stringer -linecomment -type CompareResult"; DO NOT EDIT. + +package types + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Equal-0] + _ = x[Less - -1] + _ = x[Greater-1] + _ = x[NotEqual-127] +} + +const ( + _CompareResult_name_0 = "<==>" + _CompareResult_name_1 = "!=" +) + +var ( + _CompareResult_index_0 = [...]uint8{0, 1, 3, 4} +) + +func (i CompareResult) String() string { + switch { + case -1 <= i && i <= 1: + i -= -1 + return _CompareResult_name_0[_CompareResult_index_0[i]:_CompareResult_index_0[i+1]] + case i == 127: + return _CompareResult_name_1 + default: + return "CompareResult(" + strconv.FormatInt(int64(i), 10) + ")" + } +}