From 5ff7190927a25e30f619279b5e5078b98c7cbb9a Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Tue, 20 Feb 2024 09:06:08 +0100 Subject: [PATCH 01/13] wip --- .golangci.yml | 2 + cmd/envtool/internal/testmatch/api.go | 40 +++ cmd/envtool/internal/testmatch/api_test.go | 65 ++++ cmd/envtool/internal/testmatch/match.go | 334 +++++++++++++++++++ cmd/envtool/internal/testmatch/match_test.go | 277 +++++++++++++++ cmd/envtool/tests.go | 119 +++---- cmd/envtool/tests_test.go | 206 +----------- 7 files changed, 763 insertions(+), 280 deletions(-) create mode 100644 cmd/envtool/internal/testmatch/api.go create mode 100644 cmd/envtool/internal/testmatch/api_test.go create mode 100644 cmd/envtool/internal/testmatch/match.go create mode 100644 cmd/envtool/internal/testmatch/match_test.go diff --git a/.golangci.yml b/.golangci.yml index 4ca6047af25f..7c7fbc84be45 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,6 +3,8 @@ run: timeout: 3m + skip-dirs: + - cmd/envtool/internal/testmatch # due to files match.go and match_test.go from the Go standard library linters-settings: # asciicheck diff --git a/cmd/envtool/internal/testmatch/api.go b/cmd/envtool/internal/testmatch/api.go new file mode 100644 index 000000000000..350221277944 --- /dev/null +++ b/cmd/envtool/internal/testmatch/api.go @@ -0,0 +1,40 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testmatch + +import "regexp" + +type Matcher struct { + m *matcher +} + +// New matcher. +func New(run, skip string) *Matcher { + return &Matcher{ + m: newMatcher(regexp.MatchString, run, "-test.run", skip), + } +} + +// Match top-level test function. +func (m *Matcher) Match(testFunction string) bool { + _, ok, _ := m.m.fullName(&common{}, testFunction) + return ok +} + +// common is used internally by the matcher. +type common struct { + name string // name of the test + level int // level of the test +} diff --git a/cmd/envtool/internal/testmatch/api_test.go b/cmd/envtool/internal/testmatch/api_test.go new file mode 100644 index 000000000000..2d7cb9e4dc1f --- /dev/null +++ b/cmd/envtool/internal/testmatch/api_test.go @@ -0,0 +1,65 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests based on match_test.go +// This file is a modification of +// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/src/testing/match_test.go +// Copyright 2015 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file at +// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/LICENSE + +package testmatch + +import ( + "testing" +) + +func TestMatcherAPI(t *testing.T) { + testCases := []struct { + pattern string + skip string + name string + ok bool + }{ + // Behavior without subtests. + {"", "", "TestFoo", true}, + {"TestFoo", "", "TestFoo", true}, + {"TestFoo/", "", "TestFoo", true}, + {"TestFoo/bar/baz", "", "TestFoo", true}, + {"TestFoo", "", "TestBar", false}, + {"TestFoo/", "", "TestBar", false}, + {"TestFoo/bar/baz", "", "TestBar/bar/baz", false}, + {"", "TestBar", "TestFoo", true}, + {"", "TestBar", "TestBar", false}, + + // Skipping a non-existent test doesn't change anything. + {"", "TestFoo/skipped", "TestFoo", true}, + {"TestFoo", "TestFoo/skipped", "TestFoo", true}, + {"TestFoo/", "TestFoo/skipped", "TestFoo", true}, + {"TestFoo/bar/baz", "TestFoo/skipped", "TestFoo", true}, + {"TestFoo", "TestFoo/skipped", "TestBar", false}, + {"TestFoo/", "TestFoo/skipped", "TestBar", false}, + {"TestFoo/bar/baz", "TestFoo/skipped", "TestBar/bar/baz", false}, + } + + for _, tc := range testCases { + m := New(tc.pattern, tc.skip) + + if ok := m.Match(tc.name); ok != tc.ok { + t.Errorf("for pattern %q, Match(%q) = %v; want ok %v", + tc.pattern, tc.name, ok, tc.ok) + } + } +} diff --git a/cmd/envtool/internal/testmatch/match.go b/cmd/envtool/internal/testmatch/match.go new file mode 100644 index 000000000000..005fe3d6ce98 --- /dev/null +++ b/cmd/envtool/internal/testmatch/match.go @@ -0,0 +1,334 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is originally from +// https://go.googlesource.com/go/+/3bc28402fae2a1646e4d2756344b5eb34994d25f/src/testing/match.go +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file at +// https://go.googlesource.com/go/+/3bc28402fae2a1646e4d2756344b5eb34994d25f/LICENSE + +package testmatch + +import ( + "fmt" + "os" + "strconv" + "strings" + "sync" +) + +// matcher sanitizes, uniques, and filters names of subtests and subbenchmarks. +type matcher struct { + filter filterMatch + skip filterMatch + matchFunc func(pat, str string) (bool, error) + + mu sync.Mutex + + // subNames is used to deduplicate subtest names. + // Each key is the subtest name joined to the deduplicated name of the parent test. + // Each value is the count of the number of occurrences of the given subtest name + // already seen. + subNames map[string]int32 +} + +type filterMatch interface { + // matches checks the name against the receiver's pattern strings using the + // given match function. + matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) + + // verify checks that the receiver's pattern strings are valid filters by + // calling the given match function. 
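+	// For simpleMatch, verify also rewrites the pattern elements in place
+	// (see rewrite) before validating them.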
+ verify(name string, matchString func(pat, str string) (bool, error)) error +} + +// simpleMatch matches a test name if all of the pattern strings match in +// sequence. +type simpleMatch []string + +// alternationMatch matches a test name if one of the alternations match. +type alternationMatch []filterMatch + +// TODO: fix test_main to avoid race and improve caching, also allowing to +// eliminate this Mutex. +var matchMutex sync.Mutex + +func allMatcher() *matcher { + return newMatcher(nil, "", "", "") +} + +func newMatcher(matchString func(pat, str string) (bool, error), patterns, name, skips string) *matcher { + var filter, skip filterMatch + if patterns == "" { + filter = simpleMatch{} // always partial true + } else { + filter = splitRegexp(patterns) + if err := filter.verify(name, matchString); err != nil { + fmt.Fprintf(os.Stderr, "testing: invalid regexp for %s\n", err) + os.Exit(1) + } + } + if skips == "" { + skip = alternationMatch{} // always false + } else { + skip = splitRegexp(skips) + if err := skip.verify("-test.skip", matchString); err != nil { + fmt.Fprintf(os.Stderr, "testing: invalid regexp for %v\n", err) + os.Exit(1) + } + } + return &matcher{ + filter: filter, + skip: skip, + matchFunc: matchString, + subNames: map[string]int32{}, + } +} + +func (m *matcher) fullName(c *common, subname string) (name string, ok, partial bool) { + name = subname + + m.mu.Lock() + defer m.mu.Unlock() + + if c != nil && c.level > 0 { + name = m.unique(c.name, rewrite(subname)) + } + + matchMutex.Lock() + defer matchMutex.Unlock() + + // We check the full array of paths each time to allow for the case that a pattern contains a '/'. + elem := strings.Split(name, "/") + + // filter must match. + // accept partial match that may produce full match later. + ok, partial = m.filter.matches(elem, m.matchFunc) + if !ok { + return name, false, false + } + + // skip must not match. + // ignore partial match so we can get to more precise match later. + skip, partialSkip := m.skip.matches(elem, m.matchFunc) + if skip && !partialSkip { + return name, false, false + } + + return name, ok, partial +} + +// clearSubNames clears the matcher's internal state, potentially freeing +// memory. After this is called, T.Name may return the same strings as it did +// for earlier subtests. +func (m *matcher) clearSubNames() { + m.mu.Lock() + defer m.mu.Unlock() + clear(m.subNames) +} + +func (m simpleMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) { + for i, s := range name { + if i >= len(m) { + break + } + if ok, _ := matchString(m[i], s); !ok { + return false, false + } + } + return true, len(name) < len(m) +} + +func (m simpleMatch) verify(name string, matchString func(pat, str string) (bool, error)) error { + for i, s := range m { + m[i] = rewrite(s) + } + // Verify filters before doing any processing. 
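+	// Each element is matched against a fixed non-empty string purely to
+	// surface regexp compilation errors; the match result itself is ignored.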
+ for i, s := range m { + if _, err := matchString(s, "non-empty"); err != nil { + return fmt.Errorf("element %d of %s (%q): %s", i, name, s, err) + } + } + return nil +} + +func (m alternationMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) { + for _, m := range m { + if ok, partial = m.matches(name, matchString); ok { + return ok, partial + } + } + return false, false +} + +func (m alternationMatch) verify(name string, matchString func(pat, str string) (bool, error)) error { + for i, m := range m { + if err := m.verify(name, matchString); err != nil { + return fmt.Errorf("alternation %d of %s", i, err) + } + } + return nil +} + +func splitRegexp(s string) filterMatch { + a := make(simpleMatch, 0, strings.Count(s, "/")) + b := make(alternationMatch, 0, strings.Count(s, "|")) + cs := 0 + cp := 0 + for i := 0; i < len(s); { + switch s[i] { + case '[': + cs++ + case ']': + if cs--; cs < 0 { // An unmatched ']' is legal. + cs = 0 + } + case '(': + if cs == 0 { + cp++ + } + case ')': + if cs == 0 { + cp-- + } + case '\\': + i++ + case '/': + if cs == 0 && cp == 0 { + a = append(a, s[:i]) + s = s[i+1:] + i = 0 + continue + } + case '|': + if cs == 0 && cp == 0 { + a = append(a, s[:i]) + s = s[i+1:] + i = 0 + b = append(b, a) + a = make(simpleMatch, 0, len(a)) + continue + } + } + i++ + } + + a = append(a, s) + if len(b) == 0 { + return a + } + return append(b, a) +} + +// unique creates a unique name for the given parent and subname by affixing it +// with one or more counts, if necessary. +func (m *matcher) unique(parent, subname string) string { + base := parent + "/" + subname + + for { + n := m.subNames[base] + if n < 0 { + panic("subtest count overflow") + } + m.subNames[base] = n + 1 + + if n == 0 && subname != "" { + prefix, nn := parseSubtestNumber(base) + if len(prefix) < len(base) && nn < m.subNames[prefix] { + // This test is explicitly named like "parent/subname#NN", + // and #NN was already used for the NNth occurrence of "parent/subname". + // Loop to add a disambiguating suffix. + continue + } + return base + } + + name := fmt.Sprintf("%s#%02d", base, n) + if m.subNames[name] != 0 { + // This is the nth occurrence of base, but the name "parent/subname#NN" + // collides with the first occurrence of a subtest *explicitly* named + // "parent/subname#NN". Try the next number. + continue + } + + return name + } +} + +// parseSubtestNumber splits a subtest name into a "#%02d"-formatted int32 +// suffix (if present), and a prefix preceding that suffix (always). +func parseSubtestNumber(s string) (prefix string, nn int32) { + i := strings.LastIndex(s, "#") + if i < 0 { + return s, 0 + } + + prefix, suffix := s[:i], s[i+1:] + if len(suffix) < 2 || (len(suffix) > 2 && suffix[0] == '0') { + // Even if suffix is numeric, it is not a possible output of a "%02" format + // string: it has either too few digits or too many leading zeroes. + return s, 0 + } + if suffix == "00" { + if !strings.HasSuffix(prefix, "/") { + // We only use "#00" as a suffix for subtests named with the empty + // string — it isn't a valid suffix if the subtest name is non-empty. + return s, 0 + } + } + + n, err := strconv.ParseInt(suffix, 10, 32) + if err != nil || n < 0 { + return s, 0 + } + return prefix, int32(n) +} + +// rewrite rewrites a subname to having only printable characters and no white +// space. 
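+// White space runes become '_', remaining non-printable runes are replaced by
+// their strconv.QuoteRune escape, and everything else is kept as-is.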
+func rewrite(s string) string { + b := []byte{} + for _, r := range s { + switch { + case isSpace(r): + b = append(b, '_') + case !strconv.IsPrint(r): + s := strconv.QuoteRune(r) + b = append(b, s[1:len(s)-1]...) + default: + b = append(b, string(r)...) + } + } + return string(b) +} + +func isSpace(r rune) bool { + if r < 0x2000 { + switch r { + // Note: not the same as Unicode Z class. + case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0, 0x1680: + return true + } + } else { + if r <= 0x200a { + return true + } + switch r { + case 0x2028, 0x2029, 0x202f, 0x205f, 0x3000: + return true + } + } + return false +} diff --git a/cmd/envtool/internal/testmatch/match_test.go b/cmd/envtool/internal/testmatch/match_test.go new file mode 100644 index 000000000000..a68b493df483 --- /dev/null +++ b/cmd/envtool/internal/testmatch/match_test.go @@ -0,0 +1,277 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is originally from +// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/src/testing/match_test.go +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file at +// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/LICENSE + +package testmatch + +import ( + "fmt" + "reflect" + "regexp" + "strings" + "testing" + "unicode" +) + +// Verify that our IsSpace agrees with unicode.IsSpace. +func TestIsSpace(t *testing.T) { + n := 0 + for r := rune(0); r <= unicode.MaxRune; r++ { + if isSpace(r) != unicode.IsSpace(r) { + t.Errorf("IsSpace(%U)=%t incorrect", r, isSpace(r)) + n++ + if n > 10 { + return + } + } + } +} + +func TestSplitRegexp(t *testing.T) { + res := func(s ...string) filterMatch { return simpleMatch(s) } + alt := func(m ...filterMatch) filterMatch { return alternationMatch(m) } + testCases := []struct { + pattern string + result filterMatch + }{ + // Correct patterns + // If a regexp pattern is correct, all split regexps need to be correct + // as well. + {"", res("")}, + {"/", res("", "")}, + {"//", res("", "", "")}, + {"A", res("A")}, + {"A/B", res("A", "B")}, + {"A/B/", res("A", "B", "")}, + {"/A/B/", res("", "A", "B", "")}, + {"[A]/(B)", res("[A]", "(B)")}, + {"[/]/[/]", res("[/]", "[/]")}, + {"[/]/[:/]", res("[/]", "[:/]")}, + {"/]", res("", "]")}, + {"]/", res("]", "")}, + {"]/[/]", res("]", "[/]")}, + {`([)/][(])`, res(`([)/][(])`)}, + {"[(]/[)]", res("[(]", "[)]")}, + + {"A/B|C/D", alt(res("A", "B"), res("C", "D"))}, + + // Faulty patterns + // Errors in original should produce at least one faulty regexp in results. 
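+		// For example, in ")/" the unmatched ')' drives the group counter
+		// negative, so the '/' that follows is not treated as a separator
+		// and the faulty pattern stays a single element.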
+ {")/", res(")/")}, + {")/(/)", res(")/(", ")")}, + {"a[/)b", res("a[/)b")}, + {"(/]", res("(/]")}, + {"(/", res("(/")}, + {"[/]/[/", res("[/]", "[/")}, + {`\p{/}`, res(`\p{`, "}")}, + {`\p/`, res(`\p`, "")}, + {`[[:/:]]`, res(`[[:/:]]`)}, + } + for _, tc := range testCases { + a := splitRegexp(tc.pattern) + if !reflect.DeepEqual(a, tc.result) { + t.Errorf("splitRegexp(%q) = %#v; want %#v", tc.pattern, a, tc.result) + } + + // If there is any error in the pattern, one of the returned subpatterns + // needs to have an error as well. + if _, err := regexp.Compile(tc.pattern); err != nil { + ok := true + if err := a.verify("", regexp.MatchString); err != nil { + ok = false + } + if ok { + t.Errorf("%s: expected error in any of %q", tc.pattern, a) + } + } + } +} + +func TestMatcher(t *testing.T) { + testCases := []struct { + pattern string + skip string + parent, sub string + ok bool + partial bool + }{ + // Behavior without subtests. + {"", "", "", "TestFoo", true, false}, + {"TestFoo", "", "", "TestFoo", true, false}, + {"TestFoo/", "", "", "TestFoo", true, true}, + {"TestFoo/bar/baz", "", "", "TestFoo", true, true}, + {"TestFoo", "", "", "TestBar", false, false}, + {"TestFoo/", "", "", "TestBar", false, false}, + {"TestFoo/bar/baz", "", "", "TestBar/bar/baz", false, false}, + {"", "TestBar", "", "TestFoo", true, false}, + {"", "TestBar", "", "TestBar", false, false}, + + // Skipping a non-existent test doesn't change anything. + {"", "TestFoo/skipped", "", "TestFoo", true, false}, + {"TestFoo", "TestFoo/skipped", "", "TestFoo", true, false}, + {"TestFoo/", "TestFoo/skipped", "", "TestFoo", true, true}, + {"TestFoo/bar/baz", "TestFoo/skipped", "", "TestFoo", true, true}, + {"TestFoo", "TestFoo/skipped", "", "TestBar", false, false}, + {"TestFoo/", "TestFoo/skipped", "", "TestBar", false, false}, + {"TestFoo/bar/baz", "TestFoo/skipped", "", "TestBar/bar/baz", false, false}, + + // with subtests + {"", "", "TestFoo", "x", true, false}, + {"TestFoo", "", "TestFoo", "x", true, false}, + {"TestFoo/", "", "TestFoo", "x", true, false}, + {"TestFoo/bar/baz", "", "TestFoo", "bar", true, true}, + + {"", "TestFoo/skipped", "TestFoo", "x", true, false}, + {"TestFoo", "TestFoo/skipped", "TestFoo", "x", true, false}, + {"TestFoo", "TestFoo/skipped", "TestFoo", "skipped", false, false}, + {"TestFoo/", "TestFoo/skipped", "TestFoo", "x", true, false}, + {"TestFoo/bar/baz", "TestFoo/skipped", "TestFoo", "bar", true, true}, + + // Subtest with a '/' in its name still allows for copy and pasted names + // to match. 
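+		// Here "bar/baz" is one subtest name; fullName splits the joined
+		// name "TestFoo/bar/baz" on '/', so a pattern copied from test
+		// output still matches element by element.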
+ {"TestFoo/bar/baz", "", "TestFoo", "bar/baz", true, false}, + {"TestFoo/bar/baz", "TestFoo/bar/baz", "TestFoo", "bar/baz", false, false}, + {"TestFoo/bar/baz", "TestFoo/bar/baz/skip", "TestFoo", "bar/baz", true, false}, + {"TestFoo/bar/baz", "", "TestFoo/bar", "baz", true, false}, + {"TestFoo/bar/baz", "", "TestFoo", "x", false, false}, + {"TestFoo", "", "TestBar", "x", false, false}, + {"TestFoo/", "", "TestBar", "x", false, false}, + {"TestFoo/bar/baz", "", "TestBar", "x/bar/baz", false, false}, + + {"A/B|C/D", "", "TestA", "B", true, false}, + {"A/B|C/D", "", "TestC", "D", true, false}, + {"A/B|C/D", "", "TestA", "C", false, false}, + + // subtests only + {"", "", "TestFoo", "x", true, false}, + {"/", "", "TestFoo", "x", true, false}, + {"./", "", "TestFoo", "x", true, false}, + {"./.", "", "TestFoo", "x", true, false}, + {"/bar/baz", "", "TestFoo", "bar", true, true}, + {"/bar/baz", "", "TestFoo", "bar/baz", true, false}, + {"//baz", "", "TestFoo", "bar/baz", true, false}, + {"//", "", "TestFoo", "bar/baz", true, false}, + {"/bar/baz", "", "TestFoo/bar", "baz", true, false}, + {"//foo", "", "TestFoo", "bar/baz", false, false}, + {"/bar/baz", "", "TestFoo", "x", false, false}, + {"/bar/baz", "", "TestBar", "x/bar/baz", false, false}, + } + + for _, tc := range testCases { + m := newMatcher(regexp.MatchString, tc.pattern, "-test.run", tc.skip) + + parent := &common{name: tc.parent} + if tc.parent != "" { + parent.level = 1 + } + if n, ok, partial := m.fullName(parent, tc.sub); ok != tc.ok || partial != tc.partial { + t.Errorf("for pattern %q, fullName(parent=%q, sub=%q) = %q, ok %v partial %v; want ok %v partial %v", + tc.pattern, tc.parent, tc.sub, n, ok, partial, tc.ok, tc.partial) + } + } +} + +var namingTestCases = []struct{ name, want string }{ + // Uniqueness + {"", "x/#00"}, + {"", "x/#01"}, + {"#0", "x/#0"}, // Doesn't conflict with #00 because the number of digits differs. + {"#00", "x/#00#01"}, // Conflicts with implicit #00 (used above), so add a suffix. + {"#", "x/#"}, + {"#", "x/##01"}, + + {"t", "x/t"}, + {"t", "x/t#01"}, + {"t", "x/t#02"}, + {"t#00", "x/t#00"}, // Explicit "#00" doesn't conflict with the unsuffixed first subtest. + + {"a#01", "x/a#01"}, // user has subtest with this name. + {"a", "x/a"}, // doesn't conflict with this name. + {"a", "x/a#02"}, // This string is claimed now, so resume + {"a", "x/a#03"}, // with counting. + {"a#02", "x/a#02#01"}, // We already used a#02 once, so add a suffix. + + {"b#00", "x/b#00"}, + {"b", "x/b"}, // Implicit 0 doesn't conflict with explicit "#00". + {"b", "x/b#01"}, + {"b#9223372036854775807", "x/b#9223372036854775807"}, // MaxInt64 + {"b", "x/b#02"}, + {"b", "x/b#03"}, + + // Sanitizing + {"A:1 B:2", "x/A:1_B:2"}, + {"s\t\r\u00a0", "x/s___"}, + {"\x01", `x/\x01`}, + {"\U0010ffff", `x/\U0010ffff`}, +} + +func TestNaming(t *testing.T) { + m := newMatcher(regexp.MatchString, "", "", "") + parent := &common{name: "x", level: 1} // top-level test. + + for i, tc := range namingTestCases { + if got, _, _ := m.fullName(parent, tc.name); got != tc.want { + t.Errorf("%d:%s: got %q; want %q", i, tc.name, got, tc.want) + } + } +} + +func FuzzNaming(f *testing.F) { + for _, tc := range namingTestCases { + f.Add(tc.name) + } + parent := &common{name: "x", level: 1} + var m *matcher + var seen map[string]string + reset := func() { + m = allMatcher() + seen = make(map[string]string) + } + reset() + + f.Fuzz(func(t *testing.T, subname string) { + if len(subname) > 10 { + // Long names attract the OOM killer. 
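+			// Every generated name is retained in the matcher's subNames map
+			// (and in seen below), so long names multiply memory use quickly.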
+ t.Skip() + } + name := m.unique(parent.name, subname) + if !strings.Contains(name, "/"+subname) { + t.Errorf("name %q does not contain subname %q", name, subname) + } + if prev, ok := seen[name]; ok { + t.Errorf("name %q generated by both %q and %q", name, prev, subname) + } + if len(seen) > 1e6 { + // Free up memory. + reset() + } + seen[name] = subname + }) +} + +// GoString returns a string that is more readable than the default, which makes +// it easier to read test errors. +func (m alternationMatch) GoString() string { + s := make([]string, len(m)) + for i, m := range m { + s[i] = fmt.Sprintf("%#v", m) + } + return fmt.Sprintf("(%s)", strings.Join(s, " | ")) +} diff --git a/cmd/envtool/tests.go b/cmd/envtool/tests.go index f45dcea99304..c730b1ec0204 100644 --- a/cmd/envtool/tests.go +++ b/cmd/envtool/tests.go @@ -24,7 +24,6 @@ import ( "io" "os" "os/exec" - "regexp" "slices" "sort" "strconv" @@ -38,6 +37,7 @@ import ( "go.uber.org/zap" "golang.org/x/exp/maps" + "github.com/FerretDB/FerretDB/cmd/envtool/internal/testmatch" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/util/observability" @@ -363,7 +363,7 @@ func testsRun(ctx context.Context, index, total uint, run, skip string, args []s return fmt.Errorf("--shard-index and --shard-total must be specified when --run is not") } - all, err := listTestFuncsWithRegex("", run, skip) + all, err := listTestFuncs("") if err != nil { return lazyerrors.Error(err) } @@ -372,12 +372,35 @@ func testsRun(ctx context.Context, index, total uint, run, skip string, args []s return fmt.Errorf("no tests to run") } - shard, err := shardTestFuncs(index, total, all) + var tests []string + + // Filter what top-level functions we want to test using the same logic as "go test". + m := testmatch.New(run, skip) + for _, t := range all { + if m.Match(t) { + tests = append(tests, t) + } + } + + // Then, shard all the tests but only run the ones that match the regex and that should + // be run on the specific shard. + shard, skipShard, err := shardTestFuncs(index, total, tests) if err != nil { return lazyerrors.Error(err) } - args = append(args, "-run="+buildGoTestRunRegex(shard)) + args = append(args, "-run="+run) + + if len(skipShard) > 0 { + if skip != "" { + skip += "|" + } + skip += "^(" + strings.Join(skipShard, "|") + ")$" + } + + if skip != "" { + args = append(args, "-skip="+skip) + } return runGoTest(ctx, args, len(shard), true, logger) } @@ -433,104 +456,40 @@ func listTestFuncs(dir string) ([]string, error) { return res, nil } -// listTestFuncsWithRegex returns regex-filtered names of all top-level test -// functions (tests, benchmarks, examples, fuzz functions) in the specified -// directory and subdirectories. -func listTestFuncsWithRegex(dir, run, skip string) ([]string, error) { - tests, err := listTestFuncs(dir) - if err != nil { - return nil, err - } - - if run == "" && skip == "" { - return tests, nil - } - - includeRegex, err := regexp.Compile(run) - if err != nil { - return nil, err - } - - if skip == "" { - return filterStringsByRegex(tests, includeRegex, nil), nil - } - - excludeRegex, err := regexp.Compile(skip) - if err != nil { - return nil, err - } - - return filterStringsByRegex(tests, includeRegex, excludeRegex), nil -} - -// filterStringsByRegex filters a slice of strings based on inclusion and exclusion -// criteria defined by regular expressions. 
-func filterStringsByRegex(tests []string, include, exclude *regexp.Regexp) []string { - res := []string{} - - for _, test := range tests { - if exclude != nil && exclude.MatchString(test) { - continue - } - - if include != nil && !include.MatchString(test) { - continue - } - - res = append(res, test) - } - - return res -} - -// buildGoTestRunRegex builds a regex for `go test -run` from the given test names. -func buildGoTestRunRegex(tests []string) string { - var sb strings.Builder - sb.WriteString("^(") - - for i, test := range tests { - if i != 0 { - sb.WriteString("|") - } - - sb.WriteString(test) - } - - sb.WriteString(")$") - - return sb.String() -} - // shardTestFuncs shards given top-level test functions. -func shardTestFuncs(index, total uint, testFuncs []string) ([]string, error) { +// It returns a slice of test functions to run and what test functions to skip for the given shard. +func shardTestFuncs(index, total uint, testFuncs []string) (run, skip []string, err error) { if index == 0 { - return nil, fmt.Errorf("index must be greater than 0") + return nil, nil, fmt.Errorf("index must be greater than 0") } if total == 0 { - return nil, fmt.Errorf("total must be greater than 0") + return nil, nil, fmt.Errorf("total must be greater than 0") } if index > total { - return nil, fmt.Errorf("cannot shard when index is greater than total (%d > %d)", index, total) + return nil, nil, fmt.Errorf("cannot shard when index is greater than total (%d > %d)", index, total) } l := uint(len(testFuncs)) if total > l { - return nil, fmt.Errorf("cannot shard when total is greater than a number of test functions (%d > %d)", total, l) + return nil, nil, fmt.Errorf("cannot shard when total is greater than a number of test functions (%d > %d)", total, l) } - res := make([]string, 0, l/total+1) + run = make([]string, 0, l/total+1) + skip = make([]string, 0, len(testFuncs)-len(run)) shard := uint(1) // use different shards for tests with similar names for better load balancing for _, test := range testFuncs { if index == shard { - res = append(res, test) + run = append(run, test) + } else { + skip = append(skip, test) } shard = shard%total + 1 } - return res, nil + return run, skip, nil } diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index 8d795366edf2..ab333b7f910c 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -222,200 +222,6 @@ func TestListTestFuncs(t *testing.T) { assert.Equal(t, expected, actual) } -func TestListTestFuncsWithRegex(t *testing.T) { - tests := []struct { - wantErr assert.ErrorAssertionFunc - name string - run string - skip string - expected []string - }{ - { - name: "NoRunNoSkip", - run: "", - skip: "", - expected: []string{ - "TestError1", - "TestError2", - "TestNormal1", - "TestNormal2", - "TestPanic1", - "TestSkip1", - }, - wantErr: assert.NoError, - }, - { - name: "Run", - run: "TestError", - skip: "", - expected: []string{ - "TestError1", - "TestError2", - }, - wantErr: assert.NoError, - }, - { - name: "Skip", - run: "", - skip: "TestError", - expected: []string{ - "TestNormal1", - "TestNormal2", - "TestPanic1", - "TestSkip1", - }, - wantErr: assert.NoError, - }, - { - name: "RunSkip", - run: "TestError", - skip: "TestError2", - expected: []string{ - "TestError1", - }, - wantErr: assert.NoError, - }, - { - name: "RunSkipAll", - run: "TestError", - skip: "TestError", - expected: []string{}, - wantErr: assert.NoError, - }, - { - name: "InvalidRun", - run: "[", - skip: "", - wantErr: func(t assert.TestingT, err error, i ...interface{}) 
bool { - return assert.Contains(t, err.Error(), "error parsing regexp") - }, - }, - { - name: "InvalidSkip", - run: "", - skip: "[", - wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { - return assert.Contains(t, err.Error(), "error parsing regexp") - }, - }, - } - - for _, tt := range tests { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - actual, err := listTestFuncsWithRegex("./testdata", tt.run, tt.skip) - tt.wantErr(t, err) - assert.Equal(t, tt.expected, actual) - }) - } -} - -func TestBuildGoTestRunRegex(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - expected string - tests []string - }{ - { - name: "Empty", - tests: []string{}, - expected: "^()$", - }, - { - name: "Single", - tests: []string{"Test1"}, - expected: "^(Test1)$", - }, - { - name: "Multiple", - tests: []string{"Test1", "Test2"}, - expected: "^(Test1|Test2)$", - }, - } - - for _, tt := range tests { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - actual := buildGoTestRunRegex(tt.tests) - assert.Equal(t, tt.expected, actual) - }) - } -} - -func TestFilterStringsByRegex(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - tests []string - include *regexp.Regexp - exclude *regexp.Regexp - expected []string - }{ - { - name: "Empty", - tests: []string{}, - include: nil, - exclude: nil, - expected: []string{}, - }, - { - name: "Include", - tests: []string{"Test1", "Test2"}, - include: regexp.MustCompile("Test1"), - exclude: nil, - expected: []string{"Test1"}, - }, - { - name: "Exclude", - tests: []string{"Test1", "Test2"}, - include: nil, - exclude: regexp.MustCompile("Test1"), - expected: []string{"Test2"}, - }, - { - name: "IncludeExclude", - tests: []string{"Test1", "Test2"}, - include: regexp.MustCompile("Test1"), - exclude: regexp.MustCompile("Test1"), - expected: []string{}, - }, - { - name: "IncludeExclude2", - tests: []string{"Test1", "Test2"}, - include: regexp.MustCompile("Test1"), - exclude: regexp.MustCompile("Test2"), - expected: []string{"Test1"}, - }, - { - name: "NotMatch", - tests: []string{"Test1", "Test2"}, - include: regexp.MustCompile("Test3"), - exclude: regexp.MustCompile("Test3"), - expected: []string{}, - }, - } - - for _, tt := range tests { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - actual := filterStringsByRegex(tt.tests, tt.include, tt.exclude) - assert.Equal(t, tt.expected, actual) - }) - } -} - func TestShardTestFuncs(t *testing.T) { t.Parallel() @@ -427,34 +233,34 @@ func TestShardTestFuncs(t *testing.T) { t.Run("InvalidIndex", func(t *testing.T) { t.Parallel() - _, err := shardTestFuncs(0, 3, testFuncs) + _, _, err := shardTestFuncs(0, 3, testFuncs) assert.EqualError(t, err, "index must be greater than 0") - _, err = shardTestFuncs(3, 3, testFuncs) + _, _, err = shardTestFuncs(3, 3, testFuncs) assert.NoError(t, err) - _, err = shardTestFuncs(4, 3, testFuncs) + _, _, err = shardTestFuncs(4, 3, testFuncs) assert.EqualError(t, err, "cannot shard when index is greater than total (4 > 3)") }) t.Run("InvalidTotal", func(t *testing.T) { t.Parallel() - _, err := shardTestFuncs(3, 1000, testFuncs[:42]) + _, _, err := shardTestFuncs(3, 1000, testFuncs[:42]) assert.EqualError(t, err, "cannot shard when total is greater than a number of test functions (1000 > 42)") }) t.Run("Valid", func(t *testing.T) { t.Parallel() - res, err := shardTestFuncs(1, 3, testFuncs) + res, _, err := shardTestFuncs(1, 3, testFuncs) require.NoError(t, err) assert.Equal(t, testFuncs[0], res[0]) 
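		// Sharding is round-robin: shard 1 of 3 receives testFuncs[0],
		// testFuncs[3], testFuncs[6], ..., so res[1] here is testFuncs[3].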
assert.NotEqual(t, testFuncs[1], res[1]) assert.NotEqual(t, testFuncs[2], res[1]) assert.Equal(t, testFuncs[3], res[1]) - res, err = shardTestFuncs(3, 3, testFuncs) + res, _, err = shardTestFuncs(3, 3, testFuncs) require.NoError(t, err) assert.NotEqual(t, testFuncs[0], res[0]) assert.NotEqual(t, testFuncs[1], res[0]) From 2d8fc80de5412b82d0ff9f5ef6666da7b3c2dabb Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Wed, 21 Feb 2024 02:34:17 +0100 Subject: [PATCH 02/13] Removing testmatch code --- cmd/envtool/internal/testmatch/api.go | 40 --- cmd/envtool/internal/testmatch/api_test.go | 65 ---- cmd/envtool/internal/testmatch/match.go | 334 ------------------- cmd/envtool/internal/testmatch/match_test.go | 277 --------------- cmd/envtool/tests.go | 19 +- cmd/envtool/tests_test.go | 3 +- 6 files changed, 18 insertions(+), 720 deletions(-) delete mode 100644 cmd/envtool/internal/testmatch/api.go delete mode 100644 cmd/envtool/internal/testmatch/api_test.go delete mode 100644 cmd/envtool/internal/testmatch/match.go delete mode 100644 cmd/envtool/internal/testmatch/match_test.go diff --git a/cmd/envtool/internal/testmatch/api.go b/cmd/envtool/internal/testmatch/api.go deleted file mode 100644 index 350221277944..000000000000 --- a/cmd/envtool/internal/testmatch/api.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package testmatch - -import "regexp" - -type Matcher struct { - m *matcher -} - -// New matcher. -func New(run, skip string) *Matcher { - return &Matcher{ - m: newMatcher(regexp.MatchString, run, "-test.run", skip), - } -} - -// Match top-level test function. -func (m *Matcher) Match(testFunction string) bool { - _, ok, _ := m.m.fullName(&common{}, testFunction) - return ok -} - -// common is used internally by the matcher. -type common struct { - name string // name of the test - level int // level of the test -} diff --git a/cmd/envtool/internal/testmatch/api_test.go b/cmd/envtool/internal/testmatch/api_test.go deleted file mode 100644 index 2d7cb9e4dc1f..000000000000 --- a/cmd/envtool/internal/testmatch/api_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests based on match_test.go -// This file is a modification of -// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/src/testing/match_test.go -// Copyright 2015 The Go Authors. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file at -// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/LICENSE - -package testmatch - -import ( - "testing" -) - -func TestMatcherAPI(t *testing.T) { - testCases := []struct { - pattern string - skip string - name string - ok bool - }{ - // Behavior without subtests. - {"", "", "TestFoo", true}, - {"TestFoo", "", "TestFoo", true}, - {"TestFoo/", "", "TestFoo", true}, - {"TestFoo/bar/baz", "", "TestFoo", true}, - {"TestFoo", "", "TestBar", false}, - {"TestFoo/", "", "TestBar", false}, - {"TestFoo/bar/baz", "", "TestBar/bar/baz", false}, - {"", "TestBar", "TestFoo", true}, - {"", "TestBar", "TestBar", false}, - - // Skipping a non-existent test doesn't change anything. - {"", "TestFoo/skipped", "TestFoo", true}, - {"TestFoo", "TestFoo/skipped", "TestFoo", true}, - {"TestFoo/", "TestFoo/skipped", "TestFoo", true}, - {"TestFoo/bar/baz", "TestFoo/skipped", "TestFoo", true}, - {"TestFoo", "TestFoo/skipped", "TestBar", false}, - {"TestFoo/", "TestFoo/skipped", "TestBar", false}, - {"TestFoo/bar/baz", "TestFoo/skipped", "TestBar/bar/baz", false}, - } - - for _, tc := range testCases { - m := New(tc.pattern, tc.skip) - - if ok := m.Match(tc.name); ok != tc.ok { - t.Errorf("for pattern %q, Match(%q) = %v; want ok %v", - tc.pattern, tc.name, ok, tc.ok) - } - } -} diff --git a/cmd/envtool/internal/testmatch/match.go b/cmd/envtool/internal/testmatch/match.go deleted file mode 100644 index 005fe3d6ce98..000000000000 --- a/cmd/envtool/internal/testmatch/match.go +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is originally from -// https://go.googlesource.com/go/+/3bc28402fae2a1646e4d2756344b5eb34994d25f/src/testing/match.go -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file at -// https://go.googlesource.com/go/+/3bc28402fae2a1646e4d2756344b5eb34994d25f/LICENSE - -package testmatch - -import ( - "fmt" - "os" - "strconv" - "strings" - "sync" -) - -// matcher sanitizes, uniques, and filters names of subtests and subbenchmarks. -type matcher struct { - filter filterMatch - skip filterMatch - matchFunc func(pat, str string) (bool, error) - - mu sync.Mutex - - // subNames is used to deduplicate subtest names. - // Each key is the subtest name joined to the deduplicated name of the parent test. - // Each value is the count of the number of occurrences of the given subtest name - // already seen. - subNames map[string]int32 -} - -type filterMatch interface { - // matches checks the name against the receiver's pattern strings using the - // given match function. - matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) - - // verify checks that the receiver's pattern strings are valid filters by - // calling the given match function. 
- verify(name string, matchString func(pat, str string) (bool, error)) error -} - -// simpleMatch matches a test name if all of the pattern strings match in -// sequence. -type simpleMatch []string - -// alternationMatch matches a test name if one of the alternations match. -type alternationMatch []filterMatch - -// TODO: fix test_main to avoid race and improve caching, also allowing to -// eliminate this Mutex. -var matchMutex sync.Mutex - -func allMatcher() *matcher { - return newMatcher(nil, "", "", "") -} - -func newMatcher(matchString func(pat, str string) (bool, error), patterns, name, skips string) *matcher { - var filter, skip filterMatch - if patterns == "" { - filter = simpleMatch{} // always partial true - } else { - filter = splitRegexp(patterns) - if err := filter.verify(name, matchString); err != nil { - fmt.Fprintf(os.Stderr, "testing: invalid regexp for %s\n", err) - os.Exit(1) - } - } - if skips == "" { - skip = alternationMatch{} // always false - } else { - skip = splitRegexp(skips) - if err := skip.verify("-test.skip", matchString); err != nil { - fmt.Fprintf(os.Stderr, "testing: invalid regexp for %v\n", err) - os.Exit(1) - } - } - return &matcher{ - filter: filter, - skip: skip, - matchFunc: matchString, - subNames: map[string]int32{}, - } -} - -func (m *matcher) fullName(c *common, subname string) (name string, ok, partial bool) { - name = subname - - m.mu.Lock() - defer m.mu.Unlock() - - if c != nil && c.level > 0 { - name = m.unique(c.name, rewrite(subname)) - } - - matchMutex.Lock() - defer matchMutex.Unlock() - - // We check the full array of paths each time to allow for the case that a pattern contains a '/'. - elem := strings.Split(name, "/") - - // filter must match. - // accept partial match that may produce full match later. - ok, partial = m.filter.matches(elem, m.matchFunc) - if !ok { - return name, false, false - } - - // skip must not match. - // ignore partial match so we can get to more precise match later. - skip, partialSkip := m.skip.matches(elem, m.matchFunc) - if skip && !partialSkip { - return name, false, false - } - - return name, ok, partial -} - -// clearSubNames clears the matcher's internal state, potentially freeing -// memory. After this is called, T.Name may return the same strings as it did -// for earlier subtests. -func (m *matcher) clearSubNames() { - m.mu.Lock() - defer m.mu.Unlock() - clear(m.subNames) -} - -func (m simpleMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) { - for i, s := range name { - if i >= len(m) { - break - } - if ok, _ := matchString(m[i], s); !ok { - return false, false - } - } - return true, len(name) < len(m) -} - -func (m simpleMatch) verify(name string, matchString func(pat, str string) (bool, error)) error { - for i, s := range m { - m[i] = rewrite(s) - } - // Verify filters before doing any processing. 
- for i, s := range m { - if _, err := matchString(s, "non-empty"); err != nil { - return fmt.Errorf("element %d of %s (%q): %s", i, name, s, err) - } - } - return nil -} - -func (m alternationMatch) matches(name []string, matchString func(pat, str string) (bool, error)) (ok, partial bool) { - for _, m := range m { - if ok, partial = m.matches(name, matchString); ok { - return ok, partial - } - } - return false, false -} - -func (m alternationMatch) verify(name string, matchString func(pat, str string) (bool, error)) error { - for i, m := range m { - if err := m.verify(name, matchString); err != nil { - return fmt.Errorf("alternation %d of %s", i, err) - } - } - return nil -} - -func splitRegexp(s string) filterMatch { - a := make(simpleMatch, 0, strings.Count(s, "/")) - b := make(alternationMatch, 0, strings.Count(s, "|")) - cs := 0 - cp := 0 - for i := 0; i < len(s); { - switch s[i] { - case '[': - cs++ - case ']': - if cs--; cs < 0 { // An unmatched ']' is legal. - cs = 0 - } - case '(': - if cs == 0 { - cp++ - } - case ')': - if cs == 0 { - cp-- - } - case '\\': - i++ - case '/': - if cs == 0 && cp == 0 { - a = append(a, s[:i]) - s = s[i+1:] - i = 0 - continue - } - case '|': - if cs == 0 && cp == 0 { - a = append(a, s[:i]) - s = s[i+1:] - i = 0 - b = append(b, a) - a = make(simpleMatch, 0, len(a)) - continue - } - } - i++ - } - - a = append(a, s) - if len(b) == 0 { - return a - } - return append(b, a) -} - -// unique creates a unique name for the given parent and subname by affixing it -// with one or more counts, if necessary. -func (m *matcher) unique(parent, subname string) string { - base := parent + "/" + subname - - for { - n := m.subNames[base] - if n < 0 { - panic("subtest count overflow") - } - m.subNames[base] = n + 1 - - if n == 0 && subname != "" { - prefix, nn := parseSubtestNumber(base) - if len(prefix) < len(base) && nn < m.subNames[prefix] { - // This test is explicitly named like "parent/subname#NN", - // and #NN was already used for the NNth occurrence of "parent/subname". - // Loop to add a disambiguating suffix. - continue - } - return base - } - - name := fmt.Sprintf("%s#%02d", base, n) - if m.subNames[name] != 0 { - // This is the nth occurrence of base, but the name "parent/subname#NN" - // collides with the first occurrence of a subtest *explicitly* named - // "parent/subname#NN". Try the next number. - continue - } - - return name - } -} - -// parseSubtestNumber splits a subtest name into a "#%02d"-formatted int32 -// suffix (if present), and a prefix preceding that suffix (always). -func parseSubtestNumber(s string) (prefix string, nn int32) { - i := strings.LastIndex(s, "#") - if i < 0 { - return s, 0 - } - - prefix, suffix := s[:i], s[i+1:] - if len(suffix) < 2 || (len(suffix) > 2 && suffix[0] == '0') { - // Even if suffix is numeric, it is not a possible output of a "%02" format - // string: it has either too few digits or too many leading zeroes. - return s, 0 - } - if suffix == "00" { - if !strings.HasSuffix(prefix, "/") { - // We only use "#00" as a suffix for subtests named with the empty - // string — it isn't a valid suffix if the subtest name is non-empty. - return s, 0 - } - } - - n, err := strconv.ParseInt(suffix, 10, 32) - if err != nil || n < 0 { - return s, 0 - } - return prefix, int32(n) -} - -// rewrite rewrites a subname to having only printable characters and no white -// space. 
-func rewrite(s string) string { - b := []byte{} - for _, r := range s { - switch { - case isSpace(r): - b = append(b, '_') - case !strconv.IsPrint(r): - s := strconv.QuoteRune(r) - b = append(b, s[1:len(s)-1]...) - default: - b = append(b, string(r)...) - } - } - return string(b) -} - -func isSpace(r rune) bool { - if r < 0x2000 { - switch r { - // Note: not the same as Unicode Z class. - case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0, 0x1680: - return true - } - } else { - if r <= 0x200a { - return true - } - switch r { - case 0x2028, 0x2029, 0x202f, 0x205f, 0x3000: - return true - } - } - return false -} diff --git a/cmd/envtool/internal/testmatch/match_test.go b/cmd/envtool/internal/testmatch/match_test.go deleted file mode 100644 index a68b493df483..000000000000 --- a/cmd/envtool/internal/testmatch/match_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// Copyright 2021 FerretDB Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is originally from -// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/src/testing/match_test.go -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file at -// https://go.googlesource.com/go/+/d31efbc95e6803742aaca39e3a825936791e6b5a/LICENSE - -package testmatch - -import ( - "fmt" - "reflect" - "regexp" - "strings" - "testing" - "unicode" -) - -// Verify that our IsSpace agrees with unicode.IsSpace. -func TestIsSpace(t *testing.T) { - n := 0 - for r := rune(0); r <= unicode.MaxRune; r++ { - if isSpace(r) != unicode.IsSpace(r) { - t.Errorf("IsSpace(%U)=%t incorrect", r, isSpace(r)) - n++ - if n > 10 { - return - } - } - } -} - -func TestSplitRegexp(t *testing.T) { - res := func(s ...string) filterMatch { return simpleMatch(s) } - alt := func(m ...filterMatch) filterMatch { return alternationMatch(m) } - testCases := []struct { - pattern string - result filterMatch - }{ - // Correct patterns - // If a regexp pattern is correct, all split regexps need to be correct - // as well. - {"", res("")}, - {"/", res("", "")}, - {"//", res("", "", "")}, - {"A", res("A")}, - {"A/B", res("A", "B")}, - {"A/B/", res("A", "B", "")}, - {"/A/B/", res("", "A", "B", "")}, - {"[A]/(B)", res("[A]", "(B)")}, - {"[/]/[/]", res("[/]", "[/]")}, - {"[/]/[:/]", res("[/]", "[:/]")}, - {"/]", res("", "]")}, - {"]/", res("]", "")}, - {"]/[/]", res("]", "[/]")}, - {`([)/][(])`, res(`([)/][(])`)}, - {"[(]/[)]", res("[(]", "[)]")}, - - {"A/B|C/D", alt(res("A", "B"), res("C", "D"))}, - - // Faulty patterns - // Errors in original should produce at least one faulty regexp in results. 
- {")/", res(")/")}, - {")/(/)", res(")/(", ")")}, - {"a[/)b", res("a[/)b")}, - {"(/]", res("(/]")}, - {"(/", res("(/")}, - {"[/]/[/", res("[/]", "[/")}, - {`\p{/}`, res(`\p{`, "}")}, - {`\p/`, res(`\p`, "")}, - {`[[:/:]]`, res(`[[:/:]]`)}, - } - for _, tc := range testCases { - a := splitRegexp(tc.pattern) - if !reflect.DeepEqual(a, tc.result) { - t.Errorf("splitRegexp(%q) = %#v; want %#v", tc.pattern, a, tc.result) - } - - // If there is any error in the pattern, one of the returned subpatterns - // needs to have an error as well. - if _, err := regexp.Compile(tc.pattern); err != nil { - ok := true - if err := a.verify("", regexp.MatchString); err != nil { - ok = false - } - if ok { - t.Errorf("%s: expected error in any of %q", tc.pattern, a) - } - } - } -} - -func TestMatcher(t *testing.T) { - testCases := []struct { - pattern string - skip string - parent, sub string - ok bool - partial bool - }{ - // Behavior without subtests. - {"", "", "", "TestFoo", true, false}, - {"TestFoo", "", "", "TestFoo", true, false}, - {"TestFoo/", "", "", "TestFoo", true, true}, - {"TestFoo/bar/baz", "", "", "TestFoo", true, true}, - {"TestFoo", "", "", "TestBar", false, false}, - {"TestFoo/", "", "", "TestBar", false, false}, - {"TestFoo/bar/baz", "", "", "TestBar/bar/baz", false, false}, - {"", "TestBar", "", "TestFoo", true, false}, - {"", "TestBar", "", "TestBar", false, false}, - - // Skipping a non-existent test doesn't change anything. - {"", "TestFoo/skipped", "", "TestFoo", true, false}, - {"TestFoo", "TestFoo/skipped", "", "TestFoo", true, false}, - {"TestFoo/", "TestFoo/skipped", "", "TestFoo", true, true}, - {"TestFoo/bar/baz", "TestFoo/skipped", "", "TestFoo", true, true}, - {"TestFoo", "TestFoo/skipped", "", "TestBar", false, false}, - {"TestFoo/", "TestFoo/skipped", "", "TestBar", false, false}, - {"TestFoo/bar/baz", "TestFoo/skipped", "", "TestBar/bar/baz", false, false}, - - // with subtests - {"", "", "TestFoo", "x", true, false}, - {"TestFoo", "", "TestFoo", "x", true, false}, - {"TestFoo/", "", "TestFoo", "x", true, false}, - {"TestFoo/bar/baz", "", "TestFoo", "bar", true, true}, - - {"", "TestFoo/skipped", "TestFoo", "x", true, false}, - {"TestFoo", "TestFoo/skipped", "TestFoo", "x", true, false}, - {"TestFoo", "TestFoo/skipped", "TestFoo", "skipped", false, false}, - {"TestFoo/", "TestFoo/skipped", "TestFoo", "x", true, false}, - {"TestFoo/bar/baz", "TestFoo/skipped", "TestFoo", "bar", true, true}, - - // Subtest with a '/' in its name still allows for copy and pasted names - // to match. 
- {"TestFoo/bar/baz", "", "TestFoo", "bar/baz", true, false}, - {"TestFoo/bar/baz", "TestFoo/bar/baz", "TestFoo", "bar/baz", false, false}, - {"TestFoo/bar/baz", "TestFoo/bar/baz/skip", "TestFoo", "bar/baz", true, false}, - {"TestFoo/bar/baz", "", "TestFoo/bar", "baz", true, false}, - {"TestFoo/bar/baz", "", "TestFoo", "x", false, false}, - {"TestFoo", "", "TestBar", "x", false, false}, - {"TestFoo/", "", "TestBar", "x", false, false}, - {"TestFoo/bar/baz", "", "TestBar", "x/bar/baz", false, false}, - - {"A/B|C/D", "", "TestA", "B", true, false}, - {"A/B|C/D", "", "TestC", "D", true, false}, - {"A/B|C/D", "", "TestA", "C", false, false}, - - // subtests only - {"", "", "TestFoo", "x", true, false}, - {"/", "", "TestFoo", "x", true, false}, - {"./", "", "TestFoo", "x", true, false}, - {"./.", "", "TestFoo", "x", true, false}, - {"/bar/baz", "", "TestFoo", "bar", true, true}, - {"/bar/baz", "", "TestFoo", "bar/baz", true, false}, - {"//baz", "", "TestFoo", "bar/baz", true, false}, - {"//", "", "TestFoo", "bar/baz", true, false}, - {"/bar/baz", "", "TestFoo/bar", "baz", true, false}, - {"//foo", "", "TestFoo", "bar/baz", false, false}, - {"/bar/baz", "", "TestFoo", "x", false, false}, - {"/bar/baz", "", "TestBar", "x/bar/baz", false, false}, - } - - for _, tc := range testCases { - m := newMatcher(regexp.MatchString, tc.pattern, "-test.run", tc.skip) - - parent := &common{name: tc.parent} - if tc.parent != "" { - parent.level = 1 - } - if n, ok, partial := m.fullName(parent, tc.sub); ok != tc.ok || partial != tc.partial { - t.Errorf("for pattern %q, fullName(parent=%q, sub=%q) = %q, ok %v partial %v; want ok %v partial %v", - tc.pattern, tc.parent, tc.sub, n, ok, partial, tc.ok, tc.partial) - } - } -} - -var namingTestCases = []struct{ name, want string }{ - // Uniqueness - {"", "x/#00"}, - {"", "x/#01"}, - {"#0", "x/#0"}, // Doesn't conflict with #00 because the number of digits differs. - {"#00", "x/#00#01"}, // Conflicts with implicit #00 (used above), so add a suffix. - {"#", "x/#"}, - {"#", "x/##01"}, - - {"t", "x/t"}, - {"t", "x/t#01"}, - {"t", "x/t#02"}, - {"t#00", "x/t#00"}, // Explicit "#00" doesn't conflict with the unsuffixed first subtest. - - {"a#01", "x/a#01"}, // user has subtest with this name. - {"a", "x/a"}, // doesn't conflict with this name. - {"a", "x/a#02"}, // This string is claimed now, so resume - {"a", "x/a#03"}, // with counting. - {"a#02", "x/a#02#01"}, // We already used a#02 once, so add a suffix. - - {"b#00", "x/b#00"}, - {"b", "x/b"}, // Implicit 0 doesn't conflict with explicit "#00". - {"b", "x/b#01"}, - {"b#9223372036854775807", "x/b#9223372036854775807"}, // MaxInt64 - {"b", "x/b#02"}, - {"b", "x/b#03"}, - - // Sanitizing - {"A:1 B:2", "x/A:1_B:2"}, - {"s\t\r\u00a0", "x/s___"}, - {"\x01", `x/\x01`}, - {"\U0010ffff", `x/\U0010ffff`}, -} - -func TestNaming(t *testing.T) { - m := newMatcher(regexp.MatchString, "", "", "") - parent := &common{name: "x", level: 1} // top-level test. - - for i, tc := range namingTestCases { - if got, _, _ := m.fullName(parent, tc.name); got != tc.want { - t.Errorf("%d:%s: got %q; want %q", i, tc.name, got, tc.want) - } - } -} - -func FuzzNaming(f *testing.F) { - for _, tc := range namingTestCases { - f.Add(tc.name) - } - parent := &common{name: "x", level: 1} - var m *matcher - var seen map[string]string - reset := func() { - m = allMatcher() - seen = make(map[string]string) - } - reset() - - f.Fuzz(func(t *testing.T, subname string) { - if len(subname) > 10 { - // Long names attract the OOM killer. 
- t.Skip() - } - name := m.unique(parent.name, subname) - if !strings.Contains(name, "/"+subname) { - t.Errorf("name %q does not contain subname %q", name, subname) - } - if prev, ok := seen[name]; ok { - t.Errorf("name %q generated by both %q and %q", name, prev, subname) - } - if len(seen) > 1e6 { - // Free up memory. - reset() - } - seen[name] = subname - }) -} - -// GoString returns a string that is more readable than the default, which makes -// it easier to read test errors. -func (m alternationMatch) GoString() string { - s := make([]string, len(m)) - for i, m := range m { - s[i] = fmt.Sprintf("%#v", m) - } - return fmt.Sprintf("(%s)", strings.Join(s, " | ")) -} diff --git a/cmd/envtool/tests.go b/cmd/envtool/tests.go index c730b1ec0204..b690b609065a 100644 --- a/cmd/envtool/tests.go +++ b/cmd/envtool/tests.go @@ -24,6 +24,7 @@ import ( "io" "os" "os/exec" + "regexp" "slices" "sort" "strconv" @@ -37,7 +38,6 @@ import ( "go.uber.org/zap" "golang.org/x/exp/maps" - "github.com/FerretDB/FerretDB/cmd/envtool/internal/testmatch" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/util/observability" @@ -375,9 +375,22 @@ func testsRun(ctx context.Context, index, total uint, run, skip string, args []s var tests []string // Filter what top-level functions we want to test using the same logic as "go test". - m := testmatch.New(run, skip) + var ( + rxRun *regexp.Regexp + rxSkip *regexp.Regexp + ) + + if run != "" { + rxRun = regexp.MustCompile(run) + } + + if skip != "" { + rxSkip = regexp.MustCompile(skip) + } + for _, t := range all { - if m.Match(t) { + if (skip == "" || !rxSkip.MatchString(t)) && + (run == "" || rxRun.MatchString(t)) { tests = append(tests, t) } } diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index ab333b7f910c..4ba8775d9d58 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -253,12 +253,13 @@ func TestShardTestFuncs(t *testing.T) { t.Run("Valid", func(t *testing.T) { t.Parallel() - res, _, err := shardTestFuncs(1, 3, testFuncs) + res, skip, err := shardTestFuncs(1, 3, testFuncs) require.NoError(t, err) assert.Equal(t, testFuncs[0], res[0]) assert.NotEqual(t, testFuncs[1], res[1]) assert.NotEqual(t, testFuncs[2], res[1]) assert.Equal(t, testFuncs[3], res[1]) + assert.NotEmpty(t, skip) res, _, err = shardTestFuncs(3, 3, testFuncs) require.NoError(t, err) From 438e31d9dd4a82628e6f08a34a205b94455bd580 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Wed, 21 Feb 2024 10:10:16 +0100 Subject: [PATCH 03/13] wip --- cmd/envtool/tests.go | 81 +++++++++++++------- cmd/envtool/tests_test.go | 157 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 28 deletions(-) diff --git a/cmd/envtool/tests.go b/cmd/envtool/tests.go index b690b609065a..f400e84b5680 100644 --- a/cmd/envtool/tests.go +++ b/cmd/envtool/tests.go @@ -363,38 +363,12 @@ func testsRun(ctx context.Context, index, total uint, run, skip string, args []s return fmt.Errorf("--shard-index and --shard-total must be specified when --run is not") } - all, err := listTestFuncs("") + tests, err := listTestFuncsWithRegex("", run, skip) + if err != nil { return lazyerrors.Error(err) } - if len(all) == 0 { - return fmt.Errorf("no tests to run") - } - - var tests []string - - // Filter what top-level functions we want to test using the same logic as "go test". 
- var ( - rxRun *regexp.Regexp - rxSkip *regexp.Regexp - ) - - if run != "" { - rxRun = regexp.MustCompile(run) - } - - if skip != "" { - rxSkip = regexp.MustCompile(skip) - } - - for _, t := range all { - if (skip == "" || !rxSkip.MatchString(t)) && - (run == "" || rxRun.MatchString(t)) { - tests = append(tests, t) - } - } - // Then, shard all the tests but only run the ones that match the regex and that should // be run on the specific shard. shard, skipShard, err := shardTestFuncs(index, total, tests) @@ -469,6 +443,57 @@ func listTestFuncs(dir string) ([]string, error) { return res, nil } +// listTestFuncsWithRegex returns regex-filtered names of all top-level test +// functions (tests, benchmarks, examples, fuzz functions) in the specified +// directory and subdirectories. +func listTestFuncsWithRegex(dir, run, skip string) ([]string, error) { + all, err := listTestFuncs(dir) + if err != nil { + return nil, err + } + + if len(all) == 0 { + return nil, fmt.Errorf("no tests to run") + } + + // Filter what top-level functions we want to test using the same logic as "go test". + var ( + rxRun *regexp.Regexp + rxSkip *regexp.Regexp + ) + + if run != "" { + rxRun, err = regexp.Compile(run) + if err != nil { + return nil, err + } + } + + if skip != "" { + rxSkip, err = regexp.Compile(skip) + if err != nil { + return nil, err + } + } + + return filterStringsByRegex(all, rxRun, rxSkip), nil +} + +// filterStringsByRegex filters a slice of strings based on inclusion and exclusion +// criteria defined by regular expressions. +func filterStringsByRegex(tests []string, include, exclude *regexp.Regexp) []string { + res := []string{} + + for _, t := range tests { + if (exclude == nil || !exclude.MatchString(t)) && + (include == nil || include.MatchString(t)) { + res = append(res, t) + } + } + + return res +} + // shardTestFuncs shards given top-level test functions. // It returns a slice of test functions to run and what test functions to skip for the given shard. 
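// For example, a hypothetical sketch of the round-robin distribution that the
// tests in tests_test.go exercise (names and values are illustrative only, not
// part of the original patch):
//
//	run, skip, _ := shardTestFuncs(1, 3, []string{"TestA", "TestB", "TestC", "TestD"})
//	// run:  [TestA TestD] (shard 1 of 3 takes indices 0, 3, ...)
//	// skip: [TestB TestC]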
func shardTestFuncs(index, total uint, testFuncs []string) (run, skip []string, err error) { diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index 4ba8775d9d58..aae64b2b6fc9 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -222,6 +222,163 @@ func TestListTestFuncs(t *testing.T) { assert.Equal(t, expected, actual) } +func TestListTestFuncsWithRegex(t *testing.T) { + tests := []struct { + wantErr assert.ErrorAssertionFunc + name string + run string + skip string + expected []string + }{ + { + name: "NoRunNoSkip", + run: "", + skip: "", + expected: []string{ + "TestError1", + "TestError2", + "TestNormal1", + "TestNormal2", + "TestPanic1", + "TestSkip1", + }, + wantErr: assert.NoError, + }, + { + name: "Run", + run: "TestError", + skip: "", + expected: []string{ + "TestError1", + "TestError2", + }, + wantErr: assert.NoError, + }, + { + name: "Skip", + run: "", + skip: "TestError", + expected: []string{ + "TestNormal1", + "TestNormal2", + "TestPanic1", + "TestSkip1", + }, + wantErr: assert.NoError, + }, + { + name: "RunSkip", + run: "TestError", + skip: "TestError2", + expected: []string{ + "TestError1", + }, + wantErr: assert.NoError, + }, + { + name: "RunSkipAll", + run: "TestError", + skip: "TestError", + expected: []string{}, + wantErr: assert.NoError, + }, + { + name: "InvalidRun", + run: "[", + skip: "", + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.Contains(t, err.Error(), "error parsing regexp") + }, + }, + { + name: "InvalidSkip", + run: "", + skip: "[", + wantErr: func(t assert.TestingT, err error, i ...interface{}) bool { + return assert.Contains(t, err.Error(), "error parsing regexp") + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + actual, err := listTestFuncsWithRegex("./testdata", tt.run, tt.skip) + tt.wantErr(t, err) + assert.Equal(t, tt.expected, actual) + }) + } +} + +func TestFilterStringsByRegex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tests []string + include *regexp.Regexp + exclude *regexp.Regexp + expected []string + }{ + { + name: "Empty", + tests: []string{}, + include: nil, + exclude: nil, + expected: []string{}, + }, + { + name: "Include", + tests: []string{"Test1", "Test2"}, + include: regexp.MustCompile("Test1"), + exclude: nil, + expected: []string{"Test1"}, + }, + { + name: "Exclude", + tests: []string{"Test1", "Test2"}, + include: nil, + exclude: regexp.MustCompile("Test1"), + expected: []string{"Test2"}, + }, + { + name: "IncludeExclude", + tests: []string{"Test1", "Test2"}, + include: regexp.MustCompile("Test1"), + exclude: regexp.MustCompile("Test1"), + expected: []string{}, + }, + { + name: "IncludeExclude2", + tests: []string{"Test1", "Test2"}, + include: regexp.MustCompile("Test1"), + exclude: regexp.MustCompile("Test2"), + expected: []string{"Test1"}, + }, + { + name: "NotMatch", + tests: []string{"Test1", "Test2"}, + include: regexp.MustCompile("Test3"), + exclude: regexp.MustCompile("Test3"), + expected: []string{}, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + actual := filterStringsByRegex(tt.tests, tt.include, tt.exclude) + assert.Equal(t, tt.expected, actual) + }) + } +} + func TestShardTestFuncs(t *testing.T) { t.Parallel() From eda2ab3f89accd7264f888b36e8b51bf7c3de195 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Wed, 21 Feb 2024 10:18:30 +0100 Subject: [PATCH 04/13] linter --- cmd/envtool/tests.go 
| 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/envtool/tests.go b/cmd/envtool/tests.go index f400e84b5680..20e6f5c25785 100644 --- a/cmd/envtool/tests.go +++ b/cmd/envtool/tests.go @@ -364,7 +364,6 @@ func testsRun(ctx context.Context, index, total uint, run, skip string, args []s } tests, err := listTestFuncsWithRegex("", run, skip) - if err != nil { return lazyerrors.Error(err) } From 9334c5fa56f6c7b45234d5bdc2c25c204baa82ae Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Wed, 21 Feb 2024 12:22:06 +0100 Subject: [PATCH 05/13] wip --- .golangci.yml | 2 -- cmd/envtool/tests.go | 20 ++++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 7c7fbc84be45..4ca6047af25f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,8 +3,6 @@ run: timeout: 3m - skip-dirs: - - cmd/envtool/internal/testmatch # due to files match.go and match_test.go from the Go standard library linters-settings: # asciicheck diff --git a/cmd/envtool/tests.go b/cmd/envtool/tests.go index 20e6f5c25785..a6fcfdf289e1 100644 --- a/cmd/envtool/tests.go +++ b/cmd/envtool/tests.go @@ -446,16 +446,15 @@ func listTestFuncs(dir string) ([]string, error) { // functions (tests, benchmarks, examples, fuzz functions) in the specified // directory and subdirectories. func listTestFuncsWithRegex(dir, run, skip string) ([]string, error) { - all, err := listTestFuncs(dir) + tests, err := listTestFuncs(dir) if err != nil { return nil, err } - if len(all) == 0 { + if len(tests) == 0 { return nil, fmt.Errorf("no tests to run") } - // Filter what top-level functions we want to test using the same logic as "go test". var ( rxRun *regexp.Regexp rxSkip *regexp.Regexp @@ -475,7 +474,7 @@ func listTestFuncsWithRegex(dir, run, skip string) ([]string, error) { } } - return filterStringsByRegex(all, rxRun, rxSkip), nil + return filterStringsByRegex(tests, rxRun, rxSkip), nil } // filterStringsByRegex filters a slice of strings based on inclusion and exclusion @@ -483,11 +482,16 @@ func listTestFuncsWithRegex(dir, run, skip string) ([]string, error) { func filterStringsByRegex(tests []string, include, exclude *regexp.Regexp) []string { res := []string{} - for _, t := range tests { - if (exclude == nil || !exclude.MatchString(t)) && - (include == nil || include.MatchString(t)) { - res = append(res, t) + for _, test := range tests { + if exclude != nil && exclude.MatchString(test) { + continue } + + if include != nil && !include.MatchString(test) { + continue + } + + res = append(res, test) } return res From a02e6ba96f35202fc71b8d99654e1d663c18e055 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Wed, 21 Feb 2024 13:37:14 +0100 Subject: [PATCH 06/13] wip --- cmd/envtool/tests_test.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index aae64b2b6fc9..1a53d842c9e2 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -418,10 +418,14 @@ func TestShardTestFuncs(t *testing.T) { assert.Equal(t, testFuncs[3], res[1]) assert.NotEmpty(t, skip) - res, _, err = shardTestFuncs(3, 3, testFuncs) + lastRes, lastSkip, err := shardTestFuncs(3, 3, testFuncs) require.NoError(t, err) - assert.NotEqual(t, testFuncs[0], res[0]) - assert.NotEqual(t, testFuncs[1], res[0]) - assert.Equal(t, testFuncs[2], res[0]) + assert.NotEqual(t, testFuncs[0], lastRes[0]) + assert.NotEqual(t, testFuncs[1], lastRes[0]) + assert.Equal(t, testFuncs[2], lastRes[0]) + assert.NotEmpty(t, lastSkip) + + 
assert.NotEqual(t, res, lastRes) + assert.NotEqual(t, skip, lastSkip) }) } From ee288a75b8e7d4614f889f5f60b7561f3cb5e1e3 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Fri, 23 Feb 2024 09:05:37 +0100 Subject: [PATCH 07/13] wip --- cmd/envtool/testdata/subtest_test.go | 13 ++++++ cmd/envtool/tests_test.go | 69 ++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 cmd/envtool/testdata/subtest_test.go diff --git a/cmd/envtool/testdata/subtest_test.go b/cmd/envtool/testdata/subtest_test.go new file mode 100644 index 000000000000..f40054a952ef --- /dev/null +++ b/cmd/envtool/testdata/subtest_test.go @@ -0,0 +1,13 @@ +package testdata + +import "testing" + +func TestWithSubtest(t *testing.T) { + t.Run("First", func(t *testing.T) {}) + t.Run("Second", func(t *testing.T) {}) + t.Run("Third", func(t *testing.T) { + t.Run("NestedOne", func(t *testing.T) {}) + t.Run("NestedTwo", func(t *testing.T) {}) + t.Run("NestedThree", func(t *testing.T) {}) + }) +} diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index 1a53d842c9e2..dc502bfbf66b 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -19,6 +19,7 @@ import ( "os/exec" "path/filepath" "regexp" + "sort" "strings" "testing" @@ -98,6 +99,51 @@ func TestRunGoTest(t *testing.T) { assert.Equal(t, expected, actual, "actual:\n%s", strings.Join(actual, "\n")) }) + t.Run("SubtestsPartial", func(t *testing.T) { + t.Parallel() + + var actual []string + logger, err := makeTestLogger(&actual) + require.NoError(t, err) + + err = runGoTest(context.TODO(), []string{"./testdata", "-count=1", "-run=TestWithSubtest/Third"}, 1, false, logger.Sugar()) + require.NoError(t, err) + + expected := []string{ + "PASS TestWithSubtest 1/1", + "PASS", + "ok github.com/FerretDB/FerretDB/cmd/envtool/testdata s", + "PASS github.com/FerretDB/FerretDB/cmd/envtool/testdata", + } + + cleanup(actual) + + assert.Equal(t, expected, actual, "actual:\n%s", strings.Join(actual, "\n")) + }) + + t.Run("SubtestsNotFound", func(t *testing.T) { + t.Parallel() + + var actual []string + logger, err := makeTestLogger(&actual) + require.NoError(t, err) + + err = runGoTest(context.TODO(), []string{"./testdata", "-count=1", "-run=TestWithSubtest/None"}, 1, false, logger.Sugar()) + require.NoError(t, err) + + expected := []string{ + "PASS TestWithSubtest 1/1", + "testing: warning: no tests to run", + "PASS", + "ok github.com/FerretDB/FerretDB/cmd/envtool/testdata s [no tests to run]", + "PASS github.com/FerretDB/FerretDB/cmd/envtool/testdata", + } + + cleanup(actual) + + assert.Equal(t, expected, actual, "actual:\n%s", strings.Join(actual, "\n")) + }) + t.Run("Error", func(t *testing.T) { t.Parallel() @@ -218,6 +264,7 @@ func TestListTestFuncs(t *testing.T) { "TestNormal2", "TestPanic1", "TestSkip1", + "TestWithSubtest", } assert.Equal(t, expected, actual) } @@ -241,6 +288,7 @@ func TestListTestFuncsWithRegex(t *testing.T) { "TestNormal2", "TestPanic1", "TestSkip1", + "TestWithSubtest", }, wantErr: assert.NoError, }, @@ -263,6 +311,7 @@ func TestListTestFuncsWithRegex(t *testing.T) { "TestNormal2", "TestPanic1", "TestSkip1", + "TestWithSubtest", }, wantErr: assert.NoError, }, @@ -429,3 +478,23 @@ func TestShardTestFuncs(t *testing.T) { assert.NotEqual(t, skip, lastSkip) }) } + +func TestListTestFuncsWithSkip(t *testing.T) { + t.Parallel() + + testFuncs, err := listTestFuncsWithRegex(filepath.Join("testdata"), "", "Skip") + require.NoError(t, err) + + sort.Strings(testFuncs) + + res, skip, err := shardTestFuncs(1, 2, testFuncs) + + 
assert.Equal(t, []string{"TestError2", "TestNormal2", "TestWithSubtest"}, skip) + assert.Equal(t, []string{"TestError1", "TestNormal1", "TestPanic1"}, res) + assert.Nil(t, err) + + lastRes, lastSkip, err := shardTestFuncs(3, 3, testFuncs) + assert.Equal(t, []string{"TestNormal1", "TestSubtest]"}, lastRes) + assert.Equal(t, []string{"TestError1", "TestError2", "TestNormal2", "TestPanic1"}, lastSkip) + require.NoError(t, err) +} From 7055e01b40f1cf4cce70145c0bcd29c062880804 Mon Sep 17 00:00:00 2001 From: Alexander Tobi Fashakin Date: Wed, 21 Feb 2024 16:57:54 +0100 Subject: [PATCH 08/13] Fix Codapi file error (#4077) Closes #4058. --- ...-mongodb-crud-operations-with-ferretdb.mdx | 26 +++++++++---------- website/docusaurus.config-blog.js | 1 - website/static/codapi/init.js | 20 -------------- 3 files changed, 13 insertions(+), 34 deletions(-) delete mode 100644 website/static/codapi/init.js diff --git a/website/blog/2022-11-14-mongodb-crud-operations-with-ferretdb.mdx b/website/blog/2022-11-14-mongodb-crud-operations-with-ferretdb.mdx index 155b68745f70..40b62a5b2c67 100644 --- a/website/blog/2022-11-14-mongodb-crud-operations-with-ferretdb.mdx +++ b/website/blog/2022-11-14-mongodb-crud-operations-with-ferretdb.mdx @@ -46,7 +46,7 @@ And if the database does not exist, FerretDB creates a new database. use league ``` - + If there's no existing database with this name, a new database (**league**) is created in your FerretDB storage backend on PostgreSQL. @@ -78,7 +78,7 @@ db.league.insertOne({ }) ``` - + This line of code creates a new document in your collection. @@ -130,7 +130,7 @@ db.league.insertMany([ ]) ``` - + ## Read operation @@ -150,7 +150,7 @@ First, let's select all the documents in the **league** collection created earli db.league.find({}) ``` - + This operation retrieves and displays all the documents present in the collection. @@ -160,7 +160,7 @@ Now, let's add a query parameter to the `find()` operation to filter for a speci db.league.find({ club: 'PSG' }) ``` - + You can also filter a collection in FerretDB using any of the commonly used MongoDB operators: @@ -189,7 +189,7 @@ Let's filter the `league` data for teams with 80 or 60 `points`: db.league.find({ points: { $in: [80, 60] } }) ``` - + #### Find documents using the `$lt` operator @@ -201,7 +201,7 @@ For example, let's select the documents with less than 60 _points_ : db.league.find({ points: { $lt: 60 } }) ``` - + ### findOne() @@ -213,7 +213,7 @@ For instance, let's filter the collection for documents with the _qualified_ set db.league.findOne({ qualified: true }) ``` - + Even though two documents match this query, the result only displays the first document. 
@@ -240,7 +240,7 @@ This update operation will only affect the first document that's retrieved in th db.league.updateOne({ club: 'PSG' }, { $set: { points: 35 } }) ``` - + ### updateMany() @@ -252,7 +252,7 @@ For example, let's update all documents with a _points_ field that's less than o db.league.updateMany({ points: { $lte: 90 } }, { $set: { qualified: false } }) ``` - + ### replaceOne() @@ -272,7 +272,7 @@ db.league.replaceOne( ) ``` - + ## Delete operation @@ -289,7 +289,7 @@ Note that this operation only deletes the first document that matches the query db.league.deleteOne({ club: 'Arsenal' }) ``` - + This operation deletes one document from the collection: @@ -303,7 +303,7 @@ The operation takes in a query and then filters and deletes all the documents ma db.league.deleteMany({ qualified: false }) ``` - + ## Get started with FerretDB diff --git a/website/docusaurus.config-blog.js b/website/docusaurus.config-blog.js index b44ef35982e1..08cba9350fa7 100644 --- a/website/docusaurus.config-blog.js +++ b/website/docusaurus.config-blog.js @@ -29,7 +29,6 @@ const config = { scripts: [ {src: 'https://plausible.io/js/script.js', defer: true, "data-domain": "blog.ferretdb.io"}, {src: '/codapi/snippet.js', defer: true}, - {src: '/codapi/init.js', defer: true}, ], plugins: [ diff --git a/website/static/codapi/init.js b/website/static/codapi/init.js deleted file mode 100644 index 4e06d886f991..000000000000 --- a/website/static/codapi/init.js +++ /dev/null @@ -1,20 +0,0 @@ -function initCodapi() { - setTimeout(() => { - document.querySelectorAll("codapi-snippet").forEach((el) => { - const snippet = document.createElement("codapi-snippet"); - setAttribute(snippet, el, "sandbox"); - setAttribute(snippet, el, "editor"); - setAttribute(snippet, el, "template"); - el.replaceWith(snippet); - }); - }, 500); -} - -function setAttribute(dst, src, attrName) { - if (!src.hasAttribute(attrName)) { - return; - } - dst.setAttribute(attrName, src.getAttribute(attrName)); -} - -addEventListener("load", initCodapi); From 84b6eec390768f73c5f03391267c7bc8a227bdf6 Mon Sep 17 00:00:00 2001 From: Alexander Tobi Fashakin Date: Wed, 21 Feb 2024 17:14:41 +0100 Subject: [PATCH 09/13] Add Tembo QA blog post (#4081) --- ...sql-can-it-be-a-database-for-everything.md | 145 ++++++++++++++++++ website/static/img/blog/ferretdb-tembo-qa.jpg | 3 + website/static/img/blog/samay-sharma.jpeg | 3 + 3 files changed, 151 insertions(+) create mode 100644 website/blog/2024-02-20-postgresql-can-it-be-a-database-for-everything.md create mode 100644 website/static/img/blog/ferretdb-tembo-qa.jpg create mode 100644 website/static/img/blog/samay-sharma.jpeg diff --git a/website/blog/2024-02-20-postgresql-can-it-be-a-database-for-everything.md b/website/blog/2024-02-20-postgresql-can-it-be-a-database-for-everything.md new file mode 100644 index 000000000000..0aa9f177fb16 --- /dev/null +++ b/website/blog/2024-02-20-postgresql-can-it-be-a-database-for-everything.md @@ -0,0 +1,145 @@ +--- +slug: postgresql-can-it-be-a-database-for-everything +title: 'PostgreSQL - can it be a database for everything?' 
+authors:
+  - name: Marcin Gwóźdź
+    title: Director of Strategic Alliances at FerretDB
+    url: https://www.linkedin.com/in/marcin-gwóźdź-277abaa9
+    image_url: /img/blog/marcin-gwozdz.jpeg
+  - name: Samay Sharma
+    title: Chief Technology Officer at Tembo
+    image_url: /img/blog/samay-sharma.jpeg
+description: >
+  We recently had the opportunity to speak with the Tembo team and ask about their thoughts on the PostgreSQL ecosystem, how it can be a database for everything, and how FerretDB can be used with Tembo.
+image: /img/blog/ferretdb-tembo-qa.jpg
+tags:
+  [
+    open source,
+    community,
+    document databases,
+    compatible applications,
+    postgresql tools,
+    cloud
+  ]
+---
+
+![PostgreSQL - can it be a database for everything?](/img/blog/ferretdb-tembo-qa.jpg)
+
+[PostgreSQL](https://www.postgresql.org/) is one of the most popular databases around the world.
+A lot of companies have decided to build a business around that.
+
+
+
+There are many PostgreSQL experts around the world, and we can't even count the number of applications powered by this database.
+
+Why is this possible?
+PostgreSQL is open-source, so anyone can use it and learn easily and without limitations.
+Every year more companies are founded and trying to find their niche.
+
+We recently had the opportunity to speak with the [Tembo](https://tembo.io/) team and ask about their thoughts on the PostgreSQL ecosystem.
+
+**What is Tembo? What is unique?**
+
+_At Tembo, our goal is to productize the entire Postgres ecosystem into a developer-friendly platform, so that developers need to use fewer tools in their data stack. The modern data stack has become a sprawling landscape even though Postgres, with its extension ecosystem, can solve a lot of those problems. However, it's very hard to do in practice._
+
+_We want it to be easy to use Postgres for non-typical use cases and, over time, for everything._
+
+_To that end, we provide [Tembo Stacks](https://tembo.io/blog/tembo-stacks-intro), which are curated selections of extensions, apps, and Postgres configurations that are designed to address particular use cases. They are available as [open source](https://github.com/tembo-io/tembo/tree/main/tembo-operator/src/stacks), deployable with our Kubernetes operator, and are also available for a single-click deploy on [Tembo Cloud](https://cloud.tembo.io/)._
+
+**Why did you decide to start such a company? What is the feedback from early adopters? Why PostgreSQL?**
+
+_As we outline in the [Tembo Manifesto](https://tembo.io/blog/tembo-manifesto/), the modern data stack has too many moving parts. Each tool promises a solution to a data problem, yet collectively they contribute to ever-growing complexity and cost._
+
+_PostgreSQL is a stable, community-supported, reliable open source project with a well-earned reputation. With extensions, it can support a large number of these use cases. E.g., you can do vector search with pgvector, ML with postgresml, OLAP with columnar, geospatial with PostGIS, and mongo with ferretdb._
+
+_However, you can easily get stuck discovering which extensions to use, how to install and put them together, and how to configure Postgres appropriately for your use case. We aim to make it extremely easy for developers to discover and deploy extensions, so that they can use Postgres for everything._
+
+_Overall feedback has been very encouraging. We launched our managed Cloud platform ~6 months ago, and developers from over 750 organizations have tried out Tembo Cloud.
We've also received positive feedback for our community contributions, including pgmq, trunk, pg_later, and pg_vectorize._ + +**How do you believe Tembo is helping develop the PostgreSQL community?** + +_More often than not, we hear that developers use a very small fraction of PostgreSQL capabilities, including extensions. It's hard for developers to discover, evaluate, trust, install, and successfully use extensions. Our goal is to enable every PostgreSQL user to use more extensions and to use Postgres for more use cases._ + +_We're now working to design a community solution to address the challenges related to extension discovery and distribution, and an initial version of the [proposal](https://gist.github.com/theory/898c8802937ad8361ccbcc313054c29d) is already under discussion._ + +_We've also contributed a number of extensions to the ecosystem._ + +**Why did you choose an open-source license for Tembo?** + +_We're first and foremost, a Postgres company, and our philosophies are aligned with the community. While Trunk and Tembo Stacks are integrated into our platform, they are PostgreSQL licensed, and usable without our SaaS._ + +_We have also released all our extensions (namely pgmq, pg_later, pg_vectorize, prometheus_fdw, clerk_fdw) under the PostgreSQL license._ + +_Open sourcing software projects offer transparency and encourage community participation. We want to enable everybody to help us improve our products via issues or contributions. This allows us to build the best products and benefit from the wisdom of others._ + +**What is important for you in the open-source project?** + +_We want to encourage collaboration on our open source projects. We want developers to feel a part of the community and make it easy for them to contribute to our projects._ + +_In fact, we now have a person who is a codeowner on pgmq who is not employed by Tembo. We consider this an important metric for the project's success._ + +**How do you see the future of open-source databases?** + +_Open-source databases are here to stay. Both PostgreSQL and MySQL have been around for decades, and most new databases are also open source projects. Vendor lock-in is a genuine problem, especially for databases because of their critical-ness for a business._ + +_Being based on an open-source database gives us a lot of options and flexibility and the ability to report issues and contribute back as we see fit._ + +_Postgres is even more unique because it is genuinely a community-run project, and in spite of so many forks which have come and gone, Postgres has outlasted all of them. The future of open-source databases is bright._ + +**What is interesting for you in FerretDB? What use cases can you imagine for using Tembo and FerretDB together?** + +_FerretDB is a perfect example of how one can use Postgres for more use cases than what it's typically been used for. Using FerretDB, developers can use Postgres to power their document store workloads without having to make application changes. That makes it much more attractive for them to migrate._ + +_Using FerretDB, we built our Mongo Alternative on Postgres stack, to allow users to benefit from getting a managed Postgres experience with the API being powered by Ferret. That way, developers who aren't experts in setting up and running Postgres operationally can also benefit from Ferret in a one-click manner._ + +**What was the biggest challenge during the integration of FerretDB with Tembo?** + +_Integrating FerretDB with Tembo was an exciting process and a team effort. 
One of the key steps was to introduce new routing configurations to our Kubernetes operator. Before working on the FerretDB integration, the [tembo-operator](https://github.com/tembo-io/tembo/tree/main/tembo-operator) only supported TCP ingress for services communicating with Postgres, and HTTP ingress for any [application services](https://tembo.io/blog/tembo-operator-apps) running in pods next to Postgres. So in order for our users to successfully communicate with the FerretDB container, we need to build an API to allow appServices to request a TCP ingress from Kubernetes. All of this comes together to support a user experience which only requires downloading an SSL certificate and running a mongosh connection string to get things up and running._ + +**What are your expectations from FerretDB in the future?** + +_We would love to partner with FerretDB even more deeply. We would like to hear about best practices and how we could optimize Postgres to give the best MongoDB alternative experience to our users. We are also excited to share any feedback we receive from our users about FerretDB with the team to improve the product._ + +**We understand Tembo believes in the philosophy of PostgreSQL being a "one database for everything". What are some key challenges you think need to be overcome to make that a reality?** + +_First of all, Postgres has been designed for extendability, so the potential is clearly established._ + +_Next, Postgres can be very performant when configured well for a specific workload. That's why we combined the both of them to build stacks, which are optimized Postgres instances to tackle specific workloads._ + +_A key challenge is going to be comparing PostgreSQL to all other databases, so we can prove to our users that Postgres and its ecosystem of extensions is actually enough for most use cases. Once they see the potential of what these stacks can power, there's nothing which will stop them from picking "Postgres for everything"._ + +## Conclusion + +As we can see, PostgreSQL is a powerful database with limitless capabilities. +The idea of open source allows us to create new ideas around it, and the whole ecosystem provides the necessary pieces to bring the ideas to life. + +FerretDB is one of the pieces you can use to enhance your solution - by adding MongoDB compatibility to your PostgreSQL environment, you're adding flexibility to make life easier for your developers. +The key to success is to provide tools as simple as can be so developers will want to use them in different use cases. + +The community may suggest the direction for evolving the project, which is another great advantage of open source philosophy and can speed up the process. +And who knows? +Maybe at one point, PostgreSQL may become one database for everything. + +[Check out FerretDB on Github](https://github.com/FerretDB/FerretDB). + +[Check out Tembo on GitHub](https://github.com/tembo-io/tembo). + +### About speakers + +- Samay Sharma - Chief Technology Officer - Tembo +- Marcin Gwozdz - Director of Strategic Alliances at FerretDB + +### About Tembo + +Tembo is the Postgres developer platform for building every data service. +We collapse the database sprawl and empower users with a high-performance, fully-extensible managed Postgres service. +With Tembo, developers can quickly create specialized data services using Stacks, pre-built Postgres configurations and deploy without complex builds or additional data teams. 
+ +### About FerretDB + +FerretDB is a truly open-source alternative to MongoDB built on Postgres. +FerretDB allows you to use MongoDB drivers seamlessly with PostgreSQL as the database backend. +Use all tools, drivers, UIs, and the same query language and stay open-source. +Our mission is to enable the open-source community and developers to reap the benefits of an easy-to-use document database while avoiding vendor lock-in and faux pen licenses. + +We are not affiliated, associated, authorized, endorsed by, or in any way officially connected with MongoDB Inc., or any of its subsidiaries or its affiliates. diff --git a/website/static/img/blog/ferretdb-tembo-qa.jpg b/website/static/img/blog/ferretdb-tembo-qa.jpg new file mode 100644 index 000000000000..755781c7df84 --- /dev/null +++ b/website/static/img/blog/ferretdb-tembo-qa.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e64ed65047216b61e66c4cbcbc14bd2c2b90ff1a0a2ccd3dff08855617d45023 +size 52438 diff --git a/website/static/img/blog/samay-sharma.jpeg b/website/static/img/blog/samay-sharma.jpeg new file mode 100644 index 000000000000..8d305500e445 --- /dev/null +++ b/website/static/img/blog/samay-sharma.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4831b1c17f9f327794453636482c1e12a16d480749903db1aa1d3facca785faa +size 89833 From eafa245fb53fe45e2a0408ed32618b3414b8ad0b Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Wed, 21 Feb 2024 20:18:13 +0400 Subject: [PATCH 10/13] Refactor `bson2` package (#4105) --- internal/bson2/bson2.go | 14 +- .../bson2/{document_test.go => bson2_test.go} | 816 ++++++++++-------- internal/bson2/decode.go | 130 +++ internal/bson2/decodemode_string.go | 5 +- internal/bson2/document.go | 149 +--- internal/bson2/encode.go | 164 ++++ internal/bson2/raw_array.go | 14 - internal/bson2/raw_document.go | 189 +--- 8 files changed, 780 insertions(+), 701 deletions(-) rename internal/bson2/{document_test.go => bson2_test.go} (54%) create mode 100644 internal/bson2/decode.go create mode 100644 internal/bson2/encode.go diff --git a/internal/bson2/bson2.go b/internal/bson2/bson2.go index 727143809ae7..f289017ea04c 100644 --- a/internal/bson2/bson2.go +++ b/internal/bson2/bson2.go @@ -41,7 +41,6 @@ package bson2 import ( "fmt" - "math" "time" "github.com/cristalhq/bson/bsonproto" @@ -99,10 +98,6 @@ const ( // DecodeDeep represents a mode in which nested documents and arrays are decoded recursively; // RawDocuments and RawArrays are never returned. decodeDeep - - // DecodeCheckOnly represents a mode in which only validity checks are performed (recursively) - // and no decoding happens. - decodeCheckOnly ) var ( @@ -144,19 +139,14 @@ type CompositeType interface { *Document | *Array | RawDocument | RawArray } -// validBSON checks if v is a valid BSON value (including values of raw types). -func validBSON(v any) error { +// validBSONType checks if v is a valid BSON type (including raw types). 
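+// For example (an illustrative sketch, not part of the original patch):
+// validBSONType("foo") and validBSONType(int32(42)) should return nil,
+// while a Go value with no BSON counterpart, such as int(42), should
+// produce a non-nil error.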
+func validBSONType(v any) error { switch v := v.(type) { case *Document: case RawDocument: case *Array: case RawArray: - case float64: - if noNaN && math.IsNaN(v) { - return lazyerrors.New("invalid float64 value NaN") - } - case string: case Binary: case ObjectID: diff --git a/internal/bson2/document_test.go b/internal/bson2/bson2_test.go similarity index 54% rename from internal/bson2/document_test.go rename to internal/bson2/bson2_test.go index efe2c1fbe870..238d56a9f611 100644 --- a/internal/bson2/document_test.go +++ b/internal/bson2/bson2_test.go @@ -17,7 +17,6 @@ package bson2 import ( "bufio" "bytes" - "encoding/hex" "io" "testing" "time" @@ -31,21 +30,36 @@ import ( "github.com/FerretDB/FerretDB/internal/util/testutil" ) -// testCase represents a single test case. +// normalTestCase represents a single test case for successful decoding/encoding. // //nolint:vet // for readability -type testCase struct { - name string - raw RawDocument - doc *types.Document - decodeErr error +type normalTestCase struct { + name string + raw RawDocument + tdoc *types.Document } -var ( - handshake1 = testCase{ +// decodeTestCase represents a single test case for unsuccessful decoding. +// +//nolint:vet // for readability +type decodeTestCase struct { + name string + raw RawDocument + + oldOk bool + + findRawErr error + findRawL int + decodeErr error + decodeDeepErr error // defaults to decodeErr +} + +// normalTestCases represents test cases for successful decoding/encoding. +var normalTestCases = []normalTestCase{ + { name: "handshake1", raw: testutil.MustParseDumpFile("testdata", "handshake1.hex"), - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "ismaster", true, "client", must.NotFail(types.NewDocument( "driver", must.NotFail(types.NewDocument( @@ -66,12 +80,11 @@ var ( "compression", must.NotFail(types.NewArray("none")), "loadBalanced", false, )), - } - - handshake2 = testCase{ + }, + { name: "handshake2", raw: testutil.MustParseDumpFile("testdata", "handshake2.hex"), - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "ismaster", true, "client", must.NotFail(types.NewDocument( "driver", must.NotFail(types.NewDocument( @@ -92,12 +105,11 @@ var ( "compression", must.NotFail(types.NewArray("none")), "loadBalanced", false, )), - } - - handshake3 = testCase{ + }, + { name: "handshake3", raw: testutil.MustParseDumpFile("testdata", "handshake3.hex"), - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "buildInfo", int32(1), "lsid", must.NotFail(types.NewDocument( "id", types.Binary{ @@ -110,12 +122,11 @@ var ( )), "$db", "admin", )), - } - - handshake4 = testCase{ + }, + { name: "handshake4", raw: testutil.MustParseDumpFile("testdata", "handshake4.hex"), - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "version", "5.0.0", "gitVersion", "1184f004a99660de6f5e745573419bda8a28c0e9", "modules", must.NotFail(types.NewArray()), @@ -156,12 +167,11 @@ var ( "storageEngines", must.NotFail(types.NewArray("devnull", "ephemeralForTest", "wiredTiger")), "ok", float64(1), )), - } - - all = testCase{ + }, + { name: "all", raw: testutil.MustParseDumpFile("testdata", "all.hex"), - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "array", must.NotFail(types.NewArray( must.NotFail(types.NewArray("")), must.NotFail(types.NewArray("foo")), @@ -186,9 +196,8 @@ var ( "string", must.NotFail(types.NewArray("foo", "")), "timestamp", must.NotFail(types.NewArray(types.Timestamp(42), 
types.Timestamp(0))), )), - } - - float64Doc = testCase{ + }, + { name: "float64Doc", raw: RawDocument{ 0x10, 0x00, 0x00, 0x00, @@ -196,12 +205,11 @@ var ( 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", float64(3.141592653589793), )), - } - - stringDoc = testCase{ + }, + { name: "stringDoc", raw: RawDocument{ 0x0e, 0x00, 0x00, 0x00, @@ -210,12 +218,11 @@ var ( 0x76, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", "v", )), - } - - binaryDoc = testCase{ + }, + { name: "binaryDoc", raw: RawDocument{ 0x0e, 0x00, 0x00, 0x00, @@ -225,12 +232,11 @@ var ( 0x76, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", types.Binary{B: []byte("v"), Subtype: types.BinaryUser}, )), - } - - objectIDDoc = testCase{ + }, + { name: "objectIDDoc", raw: RawDocument{ 0x14, 0x00, 0x00, 0x00, @@ -238,12 +244,11 @@ var ( 0x62, 0x56, 0xc5, 0xba, 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", types.ObjectID{0x62, 0x56, 0xc5, 0xba, 0x18, 0x2d, 0x44, 0x54, 0xfb, 0x21, 0x09, 0x40}, )), - } - - boolDoc = testCase{ + }, + { name: "boolDoc", raw: RawDocument{ 0x09, 0x00, 0x00, 0x00, @@ -251,12 +256,11 @@ var ( 0x01, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", true, )), - } - - timeDoc = testCase{ + }, + { name: "timeDoc", raw: RawDocument{ 0x10, 0x00, 0x00, 0x00, @@ -264,24 +268,22 @@ var ( 0x0b, 0xce, 0x82, 0x18, 0x8d, 0x01, 0x00, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", time.Date(2024, 1, 17, 17, 40, 42, 123000000, time.UTC), )), - } - - nullDoc = testCase{ + }, + { name: "nullDoc", raw: RawDocument{ 0x08, 0x00, 0x00, 0x00, 0x0a, 0x66, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", types.Null, )), - } - - regexDoc = testCase{ + }, + { name: "regexDoc", raw: RawDocument{ 0x0c, 0x00, 0x00, 0x00, @@ -290,12 +292,11 @@ var ( 0x6f, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", types.Regex{Pattern: "p", Options: "o"}, )), - } - - int32Doc = testCase{ + }, + { name: "int32Doc", raw: RawDocument{ 0x0c, 0x00, 0x00, 0x00, @@ -303,12 +304,11 @@ var ( 0xa1, 0xb0, 0xb9, 0x12, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", int32(314159265), )), - } - - timestampDoc = testCase{ + }, + { name: "timestampDoc", raw: RawDocument{ 0x10, 0x00, 0x00, 0x00, @@ -316,12 +316,11 @@ var ( 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", types.Timestamp(42), )), - } - - int64Doc = testCase{ + }, + { name: "int64Doc", raw: RawDocument{ 0x10, 0x00, 0x00, 0x00, @@ -329,18 +328,11 @@ var ( 0x21, 0x6d, 0x25, 0xa, 0x43, 0x29, 0xb, 0x00, 0x00, }, - doc: must.NotFail(types.NewDocument( + tdoc: must.NotFail(types.NewDocument( "f", int64(3141592653589793), )), - } - - eof = testCase{ - name: "EOF", - raw: RawDocument{0x00}, - decodeErr: ErrDecodeShortInput, - } - - smallDoc = testCase{ + }, + { name: "smallDoc", raw: RawDocument{ 0x0f, 0x00, 0x00, 0x00, // document length @@ -348,218 +340,281 @@ var ( 0x05, 0x00, 0x00, 0x00, 0x00, // subdocument length and end of subdocument 0x00, // end of document }, - doc: must.NotFail(types.NewDocument( + 
tdoc: must.NotFail(types.NewDocument( "foo", must.NotFail(types.NewDocument()), )), - } - - shortDoc = testCase{ - name: "shortDoc", + }, + { + name: "smallArray", raw: RawDocument{ 0x0f, 0x00, 0x00, 0x00, // document length - 0x03, 0x66, 0x6f, 0x6f, 0x00, // subdocument "foo" - 0x06, 0x00, 0x00, 0x00, 0x00, // invalid subdocument length and end of subdocument + 0x04, 0x66, 0x6f, 0x6f, 0x00, // subarray "foo" + 0x05, 0x00, 0x00, 0x00, 0x00, // subarray length and end of subarray 0x00, // end of document }, - decodeErr: ErrDecodeShortInput, - } + tdoc: must.NotFail(types.NewDocument( + "foo", must.NotFail(types.NewArray()), + )), + }, + { + name: "duplicateKeys", + raw: RawDocument{ + 0x0b, 0x00, 0x00, 0x00, // document length + 0x08, 0x00, 0x00, // "": false + 0x08, 0x00, 0x01, // "": true + 0x00, // end of document + }, + tdoc: must.NotFail(types.NewDocument( + "", false, + "", true, + )), + }, +} - invalidDoc = testCase{ - name: "invalidDoc", +// decodeTestCases represents test cases for unsuccessful decoding. +var decodeTestCases = []decodeTestCase{ + { + name: "EOF", + raw: RawDocument{0x00}, + findRawErr: ErrDecodeShortInput, + decodeErr: ErrDecodeShortInput, + }, + { + name: "invalidLength", raw: RawDocument{ - 0x0f, 0x00, 0x00, 0x00, // document length - 0x03, 0x66, 0x6f, 0x6f, 0x00, // subdocument "foo" - 0x05, 0x00, 0x00, 0x00, // subdocument length - 0x30, // invalid end of subdocument + 0x00, 0x00, 0x00, 0x00, // invalid document length + 0x00, // end of document + }, + findRawErr: ErrDecodeInvalidInput, + decodeErr: ErrDecodeInvalidInput, + }, + { + name: "missingByte", + raw: RawDocument{ + 0x06, 0x00, 0x00, 0x00, // document length 0x00, // end of document }, + findRawErr: ErrDecodeShortInput, + decodeErr: ErrDecodeShortInput, + }, + { + name: "extraByte", + raw: RawDocument{ + 0x05, 0x00, 0x00, 0x00, // document length + 0x00, // end of document + 0x00, // extra byte + }, + oldOk: true, + findRawL: 5, decodeErr: ErrDecodeInvalidInput, - } - - smallArray = testCase{ - name: "smallArray", + }, + { + name: "unexpectedTag", raw: RawDocument{ - 0x0f, 0x00, 0x00, 0x00, // document length - 0x04, 0x66, 0x6f, 0x6f, 0x00, // subarray "foo" - 0x05, 0x00, 0x00, 0x00, 0x00, // subarray length and end of subarray + 0x06, 0x00, 0x00, 0x00, // document length + 0xdd, // unexpected tag 0x00, // end of document }, - doc: must.NotFail(types.NewDocument( - "foo", must.NotFail(types.NewArray()), - )), - } - - shortArray = testCase{ - name: "shortArray", + findRawL: 6, + decodeErr: ErrDecodeInvalidInput, + }, + { + name: "invalidTag", + raw: RawDocument{ + 0x06, 0x00, 0x00, 0x00, // document length + 0x00, // invalid tag + 0x00, // end of document + }, + findRawL: 6, + decodeErr: ErrDecodeInvalidInput, + }, + { + name: "shortDoc", raw: RawDocument{ 0x0f, 0x00, 0x00, 0x00, // document length - 0x04, 0x66, 0x6f, 0x6f, 0x00, // subarray "foo" - 0x06, 0x00, 0x00, 0x00, 0x00, // invalid subarray length and end of subarray + 0x03, 0x66, 0x6f, 0x6f, 0x00, // subdocument "foo" + 0x06, 0x00, 0x00, 0x00, // invalid subdocument length + 0x00, // end of subdocument 0x00, // end of document }, - decodeErr: ErrDecodeShortInput, - } - - invalidArray = testCase{ - name: "invalidArray", + findRawL: 15, + decodeErr: ErrDecodeShortInput, + decodeDeepErr: ErrDecodeInvalidInput, + }, + { + name: "invalidDoc", raw: RawDocument{ 0x0f, 0x00, 0x00, 0x00, // document length - 0x04, 0x66, 0x6f, 0x6f, 0x00, // subarray "foo" - 0x05, 0x00, 0x00, 0x00, // subarray length - 0x30, // invalid end of subarray + 0x03, 0x66, 0x6f, 
0x6f, 0x00, // subdocument "foo" + 0x05, 0x00, 0x00, 0x00, // subdocument length + 0x30, // invalid end of subdocument 0x00, // end of document }, + findRawL: 15, decodeErr: ErrDecodeInvalidInput, - } - - duplicateKeys = testCase{ - name: "duplicateKeys", + }, + { + name: "invalidDocTag", raw: RawDocument{ - 0x0b, 0x00, 0x00, 0x00, // document length - 0x08, 0x00, 0x00, // "": false - 0x08, 0x00, 0x01, // "": true + 0x10, 0x00, 0x00, 0x00, // document length + 0x03, 0x66, 0x6f, 0x6f, 0x00, // subdocument "foo" + 0x06, 0x00, 0x00, 0x00, // subdocument length + 0x00, // invalid tag + 0x00, // end of subdocument 0x00, // end of document }, - doc: must.NotFail(types.NewDocument( - "", false, - "", true, - )), - } - - documentTestCases = []testCase{ - handshake1, handshake2, handshake3, handshake4, all, - float64Doc, stringDoc, binaryDoc, objectIDDoc, boolDoc, timeDoc, nullDoc, regexDoc, int32Doc, timestampDoc, int64Doc, - eof, smallDoc, shortDoc, invalidDoc, smallArray, shortArray, invalidArray, duplicateKeys, - } -) + findRawL: 16, + decodeDeepErr: ErrDecodeInvalidInput, + }, +} -func TestDocument(t *testing.T) { +func TestNormal(t *testing.T) { prev := noNaN noNaN = false t.Cleanup(func() { noNaN = prev }) - for _, tc := range documentTestCases { - tc := tc - + for _, tc := range normalTestCases { t.Run(tc.name, func(t *testing.T) { - require.NotEqual(t, tc.doc == nil, tc.decodeErr == nil) + t.Run("bson", func(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { + var doc bson.Document + buf := bufio.NewReader(bytes.NewReader(tc.raw)) + err := doc.ReadFrom(buf) + require.NoError(t, err) - t.Run("FindRawDocument", func(t *testing.T) { - assert.Nil(t, FindRawDocument(nil)) - assert.Nil(t, FindRawDocument(tc.raw[:0])) - assert.Nil(t, FindRawDocument(tc.raw[:1])) + _, err = buf.ReadByte() + assert.Equal(t, err, io.EOF) - if tc.name != "EOF" { - assert.Nil(t, FindRawDocument(tc.raw[:5])) - assert.Nil(t, FindRawDocument(tc.raw[:len(tc.raw)-1])) + tdoc, err := types.ConvertDocument(&doc) + require.NoError(t, err) + testutil.AssertEqual(t, tc.tdoc, tdoc) + }) - assert.Equal(t, tc.raw, FindRawDocument(tc.raw)) + t.Run("MarshalBinary", func(t *testing.T) { + doc, err := bson.ConvertDocument(tc.tdoc) + require.NoError(t, err) - b := append([]byte(nil), tc.raw...) 
- b = append(b, 0) - assert.Equal(t, tc.raw, FindRawDocument(b)) - } + raw, err := doc.MarshalBinary() + require.NoError(t, err) + assert.Equal(t, []byte(tc.raw), raw) + }) }) - t.Run("Encode", func(t *testing.T) { - if tc.doc == nil { - t.Skip() - } + t.Run("bson2", func(t *testing.T) { + t.Run("FindRaw", func(t *testing.T) { + ls := tc.raw.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - t.Run("bson", func(t *testing.T) { - doc, err := bson.ConvertDocument(tc.doc) + l, err := FindRaw(tc.raw) require.NoError(t, err) + require.Len(t, tc.raw, l) + }) - actual, err := doc.MarshalBinary() + t.Run("DecodeEncode", func(t *testing.T) { + doc, err := tc.raw.Decode() require.NoError(t, err) - assert.Equal(t, []byte(tc.raw), actual, "actual:\n%s", hex.Dump(actual)) - }) - t.Run("bson2", func(t *testing.T) { - doc, err := ConvertDocument(tc.doc) + ls := doc.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") + + tdoc, err := doc.Convert() + require.NoError(t, err) + testutil.AssertEqual(t, tc.tdoc, tdoc) + + raw, err := doc.Encode() require.NoError(t, err) + assert.Equal(t, tc.raw, raw) + }) - actual, err := doc.Encode() + t.Run("DecodeDeepEncode", func(t *testing.T) { + doc, err := tc.raw.DecodeDeep() require.NoError(t, err) - assert.Equal(t, tc.raw, actual, "actual:\n%s", hex.Dump(actual)) ls := doc.LogValue().Resolve().String() assert.NotContains(t, ls, "panicked") assert.NotContains(t, ls, "called too many times") - }) - }) - t.Run("Decode", func(t *testing.T) { - t.Run("bson", func(t *testing.T) { - var doc bson.Document - buf := bufio.NewReader(bytes.NewReader(tc.raw)) - err := doc.ReadFrom(buf) - - if tc.decodeErr != nil { - require.Error(t, err) - return - } + tdoc, err := doc.Convert() require.NoError(t, err) + testutil.AssertEqual(t, tc.tdoc, tdoc) - _, err = buf.ReadByte() - assert.Equal(t, err, io.EOF) - - actual, err := types.ConvertDocument(&doc) + raw, err := doc.Encode() require.NoError(t, err) - testutil.AssertEqual(t, tc.doc, actual) + assert.Equal(t, tc.raw, raw) }) - t.Run("bson2", func(t *testing.T) { - raw := RawDocument(tc.raw) + t.Run("ConvertEncode", func(t *testing.T) { + doc, err := ConvertDocument(tc.tdoc) + require.NoError(t, err) - t.Run("Check", func(t *testing.T) { - err := raw.Check() + raw, err := doc.Encode() + require.NoError(t, err) + assert.Equal(t, tc.raw, raw) + }) + }) + }) + } +} - if tc.decodeErr != nil { - require.Error(t, err, "b:\n\n%s\n%#v", hex.Dump(tc.raw), tc.raw) - require.ErrorIs(t, err, tc.decodeErr) +func TestDecode(t *testing.T) { + prev := noNaN + noNaN = false - return - } + t.Cleanup(func() { noNaN = prev }) - require.NoError(t, err) - }) + for _, tc := range decodeTestCases { + if tc.decodeDeepErr == nil { + tc.decodeDeepErr = tc.decodeErr + } - t.Run("Decode", func(t *testing.T) { - doc, err := raw.Decode() + require.NotNil(t, tc.decodeDeepErr, "invalid test case %q", tc.name) - if tc.decodeErr != nil { - return - } + t.Run(tc.name, func(t *testing.T) { + t.Run("bson", func(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { + var doc bson.Document + buf := bufio.NewReader(bytes.NewReader(tc.raw)) + err := doc.ReadFrom(buf) + if tc.oldOk { require.NoError(t, err) + return + } - actual, err := doc.Convert() - require.NoError(t, err) - testutil.AssertEqual(t, tc.doc, actual) - }) + require.Error(t, err) + }) + }) - t.Run("DecodeDeep", func(t *testing.T) { - doc, err := raw.DecodeDeep() + 
t.Run("bson2", func(t *testing.T) { + t.Run("FindRaw", func(t *testing.T) { + l, err := FindRaw(tc.raw) - if tc.decodeErr != nil { - require.Error(t, err, "b:\n\n%s\n%#v", hex.Dump(tc.raw), tc.raw) - require.ErrorIs(t, err, tc.decodeErr) + if tc.findRawErr != nil { + require.ErrorIs(t, err, tc.findRawErr) + return + } - return - } + require.NoError(t, err) + require.Equal(t, tc.findRawL, l) + }) - require.NoError(t, err) + t.Run("Decode", func(t *testing.T) { + _, err := tc.raw.Decode() - actual, err := doc.Convert() - require.NoError(t, err) - testutil.AssertEqual(t, tc.doc, actual) + if tc.decodeErr != nil { + require.ErrorIs(t, err, tc.decodeErr) + return + } - ls := doc.LogValue().Resolve().String() - assert.NotContains(t, ls, "panicked") - assert.NotContains(t, ls, "called too many times") - }) + require.NoError(t, err) + }) + + t.Run("DecodeDeep", func(t *testing.T) { + _, err := tc.raw.DecodeDeep() + require.ErrorIs(t, err, tc.decodeDeepErr) }) }) }) @@ -572,247 +627,254 @@ func BenchmarkDocument(b *testing.B) { b.Cleanup(func() { noNaN = prev }) - for _, tc := range documentTestCases { - tc := tc - + for _, tc := range normalTestCases { b.Run(tc.name, func(b *testing.B) { - b.Run("Encode", func(b *testing.B) { - if tc.doc == nil { - b.Skip() - } - - b.Run("bson", func(b *testing.B) { - doc, err := bson.ConvertDocument(tc.doc) - require.NoError(b, err) - - var actual []byte + b.Run("bson", func(b *testing.B) { + b.Run("ReadFrom", func(b *testing.B) { + var doc bson.Document + var buf *bufio.Reader + var err error + br := bytes.NewReader(tc.raw) + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - actual, err = doc.MarshalBinary() + for range b.N { + _, _ = br.Seek(0, io.SeekStart) + buf = bufio.NewReader(br) + err = doc.ReadFrom(buf) } b.StopTimer() require.NoError(b, err) - assert.NotNil(b, actual) }) - b.Run("bson2", func(b *testing.B) { - doc, err := ConvertDocument(tc.doc) + b.Run("MarshalBinary", func(b *testing.B) { + doc, err := bson.ConvertDocument(tc.tdoc) require.NoError(b, err) - var actual []byte + var raw []byte + b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - actual, err = doc.Encode() + for range b.N { + raw, err = doc.MarshalBinary() } b.StopTimer() require.NoError(b, err) - assert.NotNil(b, actual) + assert.NotNil(b, raw) }) }) - b.Run("Decode", func(b *testing.B) { - b.Run("bson/ReadFrom", func(b *testing.B) { - var doc bson.Document - var buf *bufio.Reader - var err error - br := bytes.NewReader(tc.raw) + b.Run("bson2", func(b *testing.B) { + var doc *Document + var raw []byte + var err error - b.ResetTimer() + b.Run("Decode", func(b *testing.B) { + b.ReportAllocs() - for i := 0; i < b.N; i++ { - _, _ = br.Seek(0, io.SeekStart) - buf = bufio.NewReader(br) - err = doc.ReadFrom(buf) + for range b.N { + doc, err = tc.raw.Decode() } b.StopTimer() - if tc.decodeErr != nil { - require.Error(b, err) - return - } - require.NoError(b, err) + require.NotNil(b, doc) }) - b.Run("bson2", func(b *testing.B) { - raw := RawDocument(tc.raw) - - var doc *Document - var err error + b.Run("Encode", func(b *testing.B) { + doc, err = tc.raw.Decode() + require.NoError(b, err) - b.Run("Check", func(b *testing.B) { - for i := 0; i < b.N; i++ { - err = raw.Check() - } + b.ReportAllocs() + b.ResetTimer() - b.StopTimer() + for range b.N { + raw, err = doc.Encode() + } - if tc.decodeErr != nil { - require.Error(b, err) - return - } + b.StopTimer() - require.NoError(b, err) - }) + require.NoError(b, err) + assert.NotNil(b, raw) + }) - b.Run("Decode", func(b 
*testing.B) { - for i := 0; i < b.N; i++ { - doc, err = raw.Decode() - } + b.Run("DecodeDeep", func(b *testing.B) { + b.ReportAllocs() - b.StopTimer() + for range b.N { + doc, err = tc.raw.DecodeDeep() + } - if tc.decodeErr != nil { - return - } + b.StopTimer() - require.NoError(b, err) - require.NotNil(b, doc) - }) + require.NoError(b, err) + require.NotNil(b, doc) + }) - b.Run("DecodeDeep", func(b *testing.B) { - for i := 0; i < b.N; i++ { - doc, err = raw.DecodeDeep() - } + b.Run("EncodeDeep", func(b *testing.B) { + doc, err = tc.raw.DecodeDeep() + require.NoError(b, err) - b.StopTimer() + b.ReportAllocs() + b.ResetTimer() - if tc.decodeErr != nil { - require.Error(b, err) - require.Nil(b, doc) + for range b.N { + raw, err = doc.Encode() + } - return - } + b.StopTimer() - require.NoError(b, err) - require.NotNil(b, doc) - }) + require.NoError(b, err) + assert.NotNil(b, raw) }) }) }) } } -func FuzzDocument(f *testing.F) { - noNaN = false +// testRawDocument tests a single RawDocument (that might or might not be valid). +// It is adapted from tests above. +func testRawDocument(t *testing.T, rawDoc RawDocument) { + t.Helper() - for _, tc := range documentTestCases { - f.Add([]byte(tc.raw)) - } + t.Run("bson2", func(t *testing.T) { + t.Run("FindRaw", func(t *testing.T) { + ls := rawDoc.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - f.Fuzz(func(t *testing.T, b []byte) { - t.Parallel() + _, _ = FindRaw(rawDoc) + }) - raw := RawDocument(b) + t.Run("DecodeEncode", func(t *testing.T) { + doc, err := rawDoc.Decode() + if err != nil { + _, err = rawDoc.DecodeDeep() + assert.Error(t, err) // it might be different - t.Run("bson2", func(t *testing.T) { - t.Parallel() + return + } - t.Run("Check", func(t *testing.T) { - t.Parallel() + ls := doc.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - _ = raw.Check() - }) + _, _ = doc.Convert() - t.Run("Decode", func(t *testing.T) { - t.Parallel() + raw, err := doc.Encode() + if err == nil { + assert.Equal(t, rawDoc, raw) + } + }) - doc, err := raw.Decode() - if err != nil { - t.Skip() - } + t.Run("DecodeDeepEncode", func(t *testing.T) { + doc, err := rawDoc.DecodeDeep() + if err != nil { + return + } - actual, err := doc.Encode() - require.NoError(t, err) - assert.Equal(t, raw, actual, "actual:\n%s", hex.Dump(actual)) - }) + ls := doc.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - t.Run("DecodeDeep", func(t *testing.T) { - t.Parallel() + _, err = doc.Convert() + require.NoError(t, err) - doc, err := raw.DecodeDeep() - if err != nil { - t.Skip() - } + raw, err := doc.Encode() + require.NoError(t, err) + assert.Equal(t, rawDoc, raw) + }) + }) - actual, err := doc.Encode() - require.NoError(t, err) - assert.Equal(t, raw, actual, "actual:\n%s", hex.Dump(actual)) + t.Run("cross", func(t *testing.T) { + br := bytes.NewReader(rawDoc) + bufr := bufio.NewReader(br) - ls := doc.LogValue().Resolve().String() - assert.NotContains(t, ls, "panicked") - assert.NotContains(t, ls, "called too many times") - }) - }) + var doc1 bson.Document + err1 := doc1.ReadFrom(bufr) - t.Run("cross", func(t *testing.T) { - t.Parallel() + if err1 != nil { + _, err2 := rawDoc.DecodeDeep() + require.Error(t, err2, "bson1 err = %v", err1) - br := bytes.NewReader(b) - bufr := bufio.NewReader(br) + return + } - var bdoc1 bson.Document - err1 := bdoc1.ReadFrom(bufr) + 
// remove extra tail + b := []byte(rawDoc[:len(rawDoc)-bufr.Buffered()-br.Len()]) + l, err := FindRaw(rawDoc) + require.NoError(t, err) + require.Equal(t, b, []byte(rawDoc[:l])) - if err1 != nil { - _, err2 := raw.DecodeDeep() - require.Error(t, err2, "bson1 err = %v", err1) - return - } + // decode - // remove extra tail - cb := b[:len(b)-bufr.Buffered()-br.Len()] - assert.Equal(t, cb, []byte(FindRawDocument(b))) + bdoc2, err2 := RawDocument(b).DecodeDeep() + require.NoError(t, err2) - // decode + ls := bdoc2.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - checkErr := RawDocument(cb).Check() - require.NoError(t, checkErr) + tdoc1, err := types.ConvertDocument(&doc1) + require.NoError(t, err) - bdoc2, err2 := RawDocument(cb).DecodeDeep() - require.NoError(t, err2) + tdoc2, err := bdoc2.Convert() + require.NoError(t, err) - ls := bdoc2.LogValue().Resolve().String() - assert.NotContains(t, ls, "panicked") - assert.NotContains(t, ls, "called too many times") + testutil.AssertEqual(t, tdoc1, tdoc2) - doc1, err := types.ConvertDocument(&bdoc1) - require.NoError(t, err) + // encode - doc2, err := bdoc2.Convert() - require.NoError(t, err) + doc1e, err := bson.ConvertDocument(tdoc1) + require.NoError(t, err) - testutil.AssertEqual(t, doc1, doc2) + doc2e, err := ConvertDocument(tdoc2) + require.NoError(t, err) - // encode + ls = doc2e.LogValue().Resolve().String() + assert.NotContains(t, ls, "panicked") + assert.NotContains(t, ls, "called too many times") - bdoc1e, err := bson.ConvertDocument(doc1) - require.NoError(t, err) + b1, err := doc1e.MarshalBinary() + require.NoError(t, err) - bdoc2e, err := ConvertDocument(doc2) - require.NoError(t, err) + b2, err := doc2e.Encode() + require.NoError(t, err) - ls = bdoc2e.LogValue().Resolve().String() - assert.NotContains(t, ls, "panicked") - assert.NotContains(t, ls, "called too many times") + assert.Equal(t, b1, []byte(b2)) + assert.Equal(t, b, []byte(b2)) + }) +} - b1, err := bdoc1e.MarshalBinary() - require.NoError(t, err) +func FuzzDocument(f *testing.F) { + prev := noNaN + noNaN = false - b2, err := bdoc2e.Encode() - require.NoError(t, err) + f.Cleanup(func() { noNaN = prev }) - assert.Equal(t, b1, []byte(b2)) - assert.Equal(t, cb, []byte(b2)) - }) + for _, tc := range normalTestCases { + f.Add([]byte(tc.raw)) + } + + for _, tc := range decodeTestCases { + f.Add([]byte(tc.raw)) + } + + f.Fuzz(func(t *testing.T, b []byte) { + t.Parallel() + + testRawDocument(t, RawDocument(b)) + + l, err := FindRaw(b) + if err == nil { + testRawDocument(t, RawDocument(b[:l])) + } }) } diff --git a/internal/bson2/decode.go b/internal/bson2/decode.go new file mode 100644 index 000000000000..ae1039da6643 --- /dev/null +++ b/internal/bson2/decode.go @@ -0,0 +1,130 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
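+
+// A minimal usage sketch of this file's FindRaw (illustrative only; buf and
+// its contents are hypothetical):
+//
+//	l, err := FindRaw(buf)
+//	if err == nil {
+//		doc := RawDocument(buf[:l]) // may still be invalid; check via Decode
+//	}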
+
+package bson2
+
+import (
+	"encoding/binary"
+	"math"
+
+	"github.com/cristalhq/bson/bsonproto"
+
+	"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
+)
+
+// FindRaw finds the first raw BSON document or array in b and returns its length l.
+// It should start from the first byte of b.
+// RawDocument(b[:l]) / RawArray(b[:l]) might not be valid. It is the caller's responsibility to check it.
+//
+// Use RawDocument(b) / RawArray(b) conversion instead if b contains exactly one document/array and no extra bytes.
+func FindRaw(b []byte) (int, error) {
+	bl := len(b)
+	if bl < 5 {
+		return 0, lazyerrors.Errorf("len(b) = %d: %w", bl, ErrDecodeShortInput)
+	}
+
+	dl := int(binary.LittleEndian.Uint32(b))
+	if dl < 5 {
+		return 0, lazyerrors.Errorf("dl = %d: %w", dl, ErrDecodeInvalidInput)
+	}
+
+	if bl < dl {
+		return 0, lazyerrors.Errorf("len(b) = %d, dl = %d: %w", bl, dl, ErrDecodeShortInput)
+	}
+
+	if b[dl-1] != 0 {
+		return 0, lazyerrors.Errorf("invalid last byte: %w", ErrDecodeInvalidInput)
+	}
+
+	return dl, nil
+}
+
+// decodeCheckOffset checks that b has enough bytes to decode size bytes starting from offset.
+func decodeCheckOffset(b []byte, offset, size int) error {
+	if l := len(b); l < offset+size {
+		return lazyerrors.Errorf("len(b) = %d, offset = %d, size = %d: %w", l, offset, size, ErrDecodeShortInput)
+	}
+
+	return nil
+}
+
+func decodeScalarField(b []byte, t tag) (v any, size int, err error) {
+	switch t {
+	case tagFloat64:
+		var f float64
+		f, err = bsonproto.DecodeFloat64(b)
+		v = f
+		size = bsonproto.SizeFloat64
+
+		if noNaN && math.IsNaN(f) {
+			err = lazyerrors.Errorf("got NaN value: %w", ErrDecodeInvalidInput)
+		}
+
+	case tagString:
+		var s string
+		s, err = bsonproto.DecodeString(b)
+		v = s
+		size = bsonproto.SizeString(s)
+
+	case tagBinary:
+		var bin Binary
+		bin, err = bsonproto.DecodeBinary(b)
+		v = bin
+		size = bsonproto.SizeBinary(bin)
+
+	case tagObjectID:
+		v, err = bsonproto.DecodeObjectID(b)
+		size = bsonproto.SizeObjectID
+
+	case tagBool:
+		v, err = bsonproto.DecodeBool(b)
+		size = bsonproto.SizeBool
+
+	case tagTime:
+		v, err = bsonproto.DecodeTime(b)
+		size = bsonproto.SizeTime
+
+	case tagNull:
+		v = Null
+
+	case tagRegex:
+		var re Regex
+		re, err = bsonproto.DecodeRegex(b)
+		v = re
+		size = bsonproto.SizeRegex(re)
+
+	case tagInt32:
+		v, err = bsonproto.DecodeInt32(b)
+		size = bsonproto.SizeInt32
+
+	case tagTimestamp:
+		v, err = bsonproto.DecodeTimestamp(b)
+		size = bsonproto.SizeTimestamp
+
+	case tagInt64:
+		v, err = bsonproto.DecodeInt64(b)
+		size = bsonproto.SizeInt64
+
+	case tagUndefined, tagDBPointer, tagJavaScript, tagSymbol, tagJavaScriptScope, tagDecimal, tagMinKey, tagMaxKey:
+		err = lazyerrors.Errorf("unsupported tag %s: %w", t, ErrDecodeInvalidInput)
+
+	case tagDocument, tagArray:
+		err = lazyerrors.Errorf("non-scalar tag: %s", t)
+
+	default:
+		err = lazyerrors.Errorf("unexpected tag %s: %w", t, ErrDecodeInvalidInput)
+	}
+
+	return
+}
diff --git a/internal/bson2/decodemode_string.go b/internal/bson2/decodemode_string.go
index 993f9ee95927..da8d19b38a7f 100644
--- a/internal/bson2/decodemode_string.go
+++ b/internal/bson2/decodemode_string.go
@@ -10,12 +10,11 @@ func _() {
 	var x [1]struct{}
 	_ = x[decodeShallow-1]
 	_ = x[decodeDeep-2]
-	_ = x[decodeCheckOnly-3]
 }
 
-const _decodeMode_name = "decodeShallowdecodeDeepdecodeCheckOnly"
+const _decodeMode_name = "decodeShallowdecodeDeep"
 
-var _decodeMode_index = [...]uint8{0, 13, 23, 38}
+var _decodeMode_index = [...]uint8{0, 13, 23}
 
 func (i decodeMode) String() string {
 	i -= 1
diff --git 
a/internal/bson2/document.go b/internal/bson2/document.go index 60d75c8d490a..67a52471da43 100644 --- a/internal/bson2/document.go +++ b/internal/bson2/document.go @@ -18,12 +18,8 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "log/slog" "math" - "time" - - "github.com/cristalhq/bson/bsonproto" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/iterator" @@ -140,10 +136,16 @@ func (doc *Document) Get(name string) any { // add adds a new field to the Document. func (doc *Document) add(name string, value any) error { - if err := validBSON(value); err != nil { + if err := validBSONType(value); err != nil { return lazyerrors.Errorf("%q: %w", name, err) } + if f, ok := value.(float64); ok { + if noNaN && math.IsNaN(f) { + return lazyerrors.New("invalid float64 value NaN") + } + } + doc.fields = append(doc.fields, field{ name: name, value: value, @@ -183,143 +185,6 @@ func (doc *Document) LogValue() slog.Value { return slogValue(doc) } -// encodeField encodes document field. -// -// It panics if v is not a valid type. -func encodeField(buf *bytes.Buffer, name string, v any) error { - switch v := v.(type) { - case *Document: - if err := buf.WriteByte(byte(tagDocument)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, bsonproto.SizeCString(name)) - bsonproto.EncodeCString(b, name) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - b, err := v.Encode() - if err != nil { - return lazyerrors.Error(err) - } - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - case RawDocument: - if err := buf.WriteByte(byte(tagDocument)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, bsonproto.SizeCString(name)) - bsonproto.EncodeCString(b, name) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - if _, err := buf.Write(v); err != nil { - return lazyerrors.Error(err) - } - - case *Array: - if err := buf.WriteByte(byte(tagArray)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, bsonproto.SizeCString(name)) - bsonproto.EncodeCString(b, name) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - b, err := v.Encode() - if err != nil { - return lazyerrors.Error(err) - } - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - case RawArray: - if err := buf.WriteByte(byte(tagArray)); err != nil { - return lazyerrors.Error(err) - } - - b := make([]byte, bsonproto.SizeCString(name)) - bsonproto.EncodeCString(b, name) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - if _, err := buf.Write(v); err != nil { - return lazyerrors.Error(err) - } - - default: - return encodeScalarField(buf, name, v) - } - - return nil -} - -// encodeScalarField encodes scalar document field. -// -// It panics if v is not a scalar value. 
-func encodeScalarField(buf *bytes.Buffer, name string, v any) error { - switch v := v.(type) { - case float64: - if noNaN && math.IsNaN(v) { - return lazyerrors.Errorf("got NaN value") - } - buf.WriteByte(byte(tagFloat64)) - case string: - buf.WriteByte(byte(tagString)) - case Binary: - buf.WriteByte(byte(tagBinary)) - case ObjectID: - buf.WriteByte(byte(tagObjectID)) - case bool: - buf.WriteByte(byte(tagBool)) - case time.Time: - buf.WriteByte(byte(tagTime)) - case NullType: - buf.WriteByte(byte(tagNull)) - case Regex: - buf.WriteByte(byte(tagRegex)) - case int32: - buf.WriteByte(byte(tagInt32)) - case Timestamp: - buf.WriteByte(byte(tagTimestamp)) - case int64: - buf.WriteByte(byte(tagInt64)) - default: - panic(fmt.Sprintf("invalid BSON type %T", v)) - } - - b := make([]byte, bsonproto.SizeCString(name)) - bsonproto.EncodeCString(b, name) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - b = make([]byte, bsonproto.SizeAny(v)) - bsonproto.EncodeAny(b, v) - - if _, err := buf.Write(b); err != nil { - return lazyerrors.Error(err) - } - - return nil -} - // check interfaces var ( _ slog.LogValuer = (*Document)(nil) diff --git a/internal/bson2/encode.go b/internal/bson2/encode.go new file mode 100644 index 000000000000..d5c335ff1f5b --- /dev/null +++ b/internal/bson2/encode.go @@ -0,0 +1,164 @@ +// Copyright 2021 FerretDB Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bson2 + +import ( + "bytes" + "fmt" + "math" + "time" + + "github.com/cristalhq/bson/bsonproto" + + "github.com/FerretDB/FerretDB/internal/util/lazyerrors" +) + +// encodeField encodes document/array field. +// +// It panics if v is not a valid type. 
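// Each field encoded below follows the same layout: one tag byte, the
// field name as a NUL-terminated cstring, then the value's type-specific
// encoding. As a concrete sketch, a field "v" holding float64(3.0) comes
// out as (little-endian IEEE 754):
//
//	01                         tagFloat64
//	76 00                      cstring "v"
//	00 00 00 00 00 00 08 40    3.0
//
// *Document, RawDocument, *Array, and RawArray values embed their own
// complete encoding (size prefix and trailing zero byte included), so
// nesting needs no extra framing; everything else is a scalar handled by
// encodeScalarField.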
+func encodeField(buf *bytes.Buffer, name string, v any) error { + switch v := v.(type) { + case *Document: + if err := buf.WriteByte(byte(tagDocument)); err != nil { + return lazyerrors.Error(err) + } + + b := make([]byte, SizeCString(name)) + EncodeCString(b, name) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + b, err := v.Encode() + if err != nil { + return lazyerrors.Error(err) + } + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + case RawDocument: + if err := buf.WriteByte(byte(tagDocument)); err != nil { + return lazyerrors.Error(err) + } + + b := make([]byte, SizeCString(name)) + EncodeCString(b, name) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + if _, err := buf.Write(v); err != nil { + return lazyerrors.Error(err) + } + + case *Array: + if err := buf.WriteByte(byte(tagArray)); err != nil { + return lazyerrors.Error(err) + } + + b := make([]byte, SizeCString(name)) + EncodeCString(b, name) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + b, err := v.Encode() + if err != nil { + return lazyerrors.Error(err) + } + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + case RawArray: + if err := buf.WriteByte(byte(tagArray)); err != nil { + return lazyerrors.Error(err) + } + + b := make([]byte, SizeCString(name)) + EncodeCString(b, name) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + if _, err := buf.Write(v); err != nil { + return lazyerrors.Error(err) + } + + default: + return encodeScalarField(buf, name, v) + } + + return nil +} + +// encodeScalarField encodes scalar document field. +// +// It panics if v is not a scalar value. +func encodeScalarField(buf *bytes.Buffer, name string, v any) error { + switch v := v.(type) { + case float64: + if noNaN && math.IsNaN(v) { + return lazyerrors.Errorf("got NaN value") + } + + buf.WriteByte(byte(tagFloat64)) + case string: + buf.WriteByte(byte(tagString)) + case Binary: + buf.WriteByte(byte(tagBinary)) + case ObjectID: + buf.WriteByte(byte(tagObjectID)) + case bool: + buf.WriteByte(byte(tagBool)) + case time.Time: + buf.WriteByte(byte(tagTime)) + case NullType: + buf.WriteByte(byte(tagNull)) + case Regex: + buf.WriteByte(byte(tagRegex)) + case int32: + buf.WriteByte(byte(tagInt32)) + case Timestamp: + buf.WriteByte(byte(tagTimestamp)) + case int64: + buf.WriteByte(byte(tagInt64)) + default: + panic(fmt.Sprintf("invalid BSON type %T", v)) + } + + b := make([]byte, SizeCString(name)) + EncodeCString(b, name) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + b = make([]byte, bsonproto.SizeAny(v)) + bsonproto.EncodeAny(b, v) + + if _, err := buf.Write(b); err != nil { + return lazyerrors.Error(err) + } + + return nil +} diff --git a/internal/bson2/raw_array.go b/internal/bson2/raw_array.go index d262c405dc59..504b89ccc5d0 100644 --- a/internal/bson2/raw_array.go +++ b/internal/bson2/raw_array.go @@ -58,16 +58,6 @@ func (raw RawArray) DecodeDeep() (*Array, error) { return res, nil } -// Check recursively checks that the whole byte slice contains a single valid BSON document. -func (raw RawArray) Check() error { - _, err := raw.decode(decodeCheckOnly) - if err != nil { - return lazyerrors.Error(err) - } - - return nil -} - // Convert converts a single valid BSON array that takes the whole byte slice into [*types.Array]. 
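// The two remaining decode modes differ only in how nested values are
// materialized:
//
//	decodeShallow: nested documents/arrays are kept as RawDocument/RawArray
//	               slices (cheap; errors inside them surface on later access)
//	decodeDeep:    nested documents/arrays are decoded recursively
//	               (the whole byte slice is validated up front)
//
// Convert below builds on decodeShallow; DecodeDeep above uses decodeDeep.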
func (raw RawArray) Convert() (*types.Array, error) { arr, err := raw.decode(decodeShallow) @@ -90,10 +80,6 @@ func (raw RawArray) decode(mode decodeMode) (*Array, error) { return nil, lazyerrors.Error(err) } - if mode == decodeCheckOnly { - return nil, nil - } - res := &Array{ elements: make([]any, len(doc.fields)), } diff --git a/internal/bson2/raw_document.go b/internal/bson2/raw_document.go index 34c24fa776d7..1f234f3106a4 100644 --- a/internal/bson2/raw_document.go +++ b/internal/bson2/raw_document.go @@ -15,11 +15,7 @@ package bson2 import ( - "encoding/binary" "log/slog" - "math" - - "github.com/cristalhq/bson/bsonproto" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" @@ -31,30 +27,6 @@ import ( // It generally references a part of a larger slice, not a copy. type RawDocument []byte -// FindRawDocument returns the first BSON document in the byte slice. -// It should start at the first byte. -// -// Returned document might not be valid. It is the caller's responsibility to check it. -// -// Use RawDocument(b) conversion instead of b contains exactly one document and no extra bytes. -func FindRawDocument(b []byte) RawDocument { - bl := len(b) - if bl < 5 { - return nil - } - - dl := int(binary.LittleEndian.Uint32(b)) - if bl < dl { - return nil - } - - if b[dl-1] != 0 { - return nil - } - - return b[:dl] -} - // LogValue implements slog.LogValuer interface. func (doc RawDocument) LogValue() slog.Value { return slogValue(doc) @@ -86,16 +58,6 @@ func (raw RawDocument) DecodeDeep() (*Document, error) { return res, nil } -// Check recursively checks that the whole byte slice contains a single valid BSON document. -func (raw RawDocument) Check() error { - _, err := raw.decode(decodeCheckOnly) - if err != nil { - return lazyerrors.Error(err) - } - - return nil -} - // Convert converts a single valid BSON document that takes the whole byte slice into [*types.Document]. func (raw RawDocument) Convert() (*types.Document, error) { doc, err := raw.decode(decodeShallow) @@ -113,172 +75,93 @@ func (raw RawDocument) Convert() (*types.Document, error) { // decode decodes a single BSON document that takes the whole byte slice. 
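// decode walks the standard BSON layout: a little-endian int32 total size,
// a sequence of (tag byte, cstring name, value) fields, and a final zero
// byte. For example, {"a": int32(1)} is these 12 bytes:
//
//	0c 00 00 00    size (12, counting itself and the final zero byte)
//	10             tagInt32
//	61 00          cstring "a"
//	01 00 00 00    int32 1
//	00             document terminator
//
// FindRaw checks the size prefix and the last byte; the field loop below
// validates everything in between.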
 func (raw RawDocument) decode(mode decodeMode) (*Document, error) {
-	bl := len(raw)
-	if bl < 5 {
-		return nil, lazyerrors.Errorf("len(b) = %d: %w", bl, ErrDecodeShortInput)
-	}
-
-	if dl := int(binary.LittleEndian.Uint32(raw)); bl != dl {
-		return nil, lazyerrors.Errorf("len(b) = %d, document length = %d: %w", bl, dl, ErrDecodeInvalidInput)
+	l, err := FindRaw(raw)
+	if err != nil {
+		return nil, lazyerrors.Error(err)
 	}
 
-	if last := raw[bl-1]; last != 0 {
-		return nil, lazyerrors.Errorf("last = %d: %w", last, ErrDecodeInvalidInput)
+	if rl := len(raw); rl != l {
+		return nil, lazyerrors.Errorf("len(raw) = %d, l = %d: %w", rl, l, ErrDecodeInvalidInput)
 	}
 
-	var res *Document
-	if mode != decodeCheckOnly {
-		res = MakeDocument(1)
-	}
+	res := MakeDocument(1)
 
 	offset := 4
-	for offset != len(raw)-1 {
+
+	for {
 		if err := decodeCheckOffset(raw, offset, 1); err != nil {
 			return nil, lazyerrors.Error(err)
 		}
 
 		t := tag(raw[offset])
+		if t == 0 {
+			if rl := len(raw); rl != offset+1 {
+				return nil, lazyerrors.Errorf("len(raw) = %d, offset = %d, got %s: %w", rl, offset, t, ErrDecodeInvalidInput)
+			}
+
+			return res, nil
+		}
+
 		offset++
 
 		if err := decodeCheckOffset(raw, offset, 1); err != nil {
 			return nil, lazyerrors.Error(err)
 		}
 
-		name, err := bsonproto.DecodeCString(raw[offset:])
+		name, err := DecodeCString(raw[offset:])
 		if err != nil {
 			return nil, lazyerrors.Error(err)
 		}
 
-		offset += len(name) + 1
+		offset += SizeCString(name)
 
 		var v any
 
-		switch t {
-		case tagFloat64:
-			var f float64
-			f, err = bsonproto.DecodeFloat64(raw[offset:])
-			offset += bsonproto.SizeFloat64
-			v = f
-
-			if noNaN && math.IsNaN(f) {
-				return nil, lazyerrors.Errorf("got NaN value: %w", ErrDecodeInvalidInput)
-			}
-
-		case tagString:
-			var s string
-			s, err = bsonproto.DecodeString(raw[offset:])
-			offset += bsonproto.SizeString(s)
-			v = s
+		// to check if we can even `raw[offset:]` below
+		if err = decodeCheckOffset(raw, offset, 0); err != nil {
+			return nil, lazyerrors.Error(err)
+		}
 
+		switch t { //nolint:exhaustive // other tags are handled by decodeScalarField
 		case tagDocument:
-			if err = decodeCheckOffset(raw, offset, 4); err != nil {
-				return nil, lazyerrors.Error(err)
+			if l, err = FindRaw(raw[offset:]); err != nil {
+				return nil, lazyerrors.Errorf("no document at offset = %d: %w", offset, err)
 			}
 
-			l := int(binary.LittleEndian.Uint32(raw[offset:]))
-
-			if err = decodeCheckOffset(raw, offset, l); err != nil {
-				return nil, lazyerrors.Error(err)
-			}
-
-			doc := RawDocument(raw[offset : offset+l])
+			rawDoc := RawDocument(raw[offset : offset+l])
 			offset += l
 
 			switch mode {
 			case decodeShallow:
-				v = doc
+				v = rawDoc
 			case decodeDeep:
-				v, err = doc.decode(decodeDeep)
-			case decodeCheckOnly:
-				_, err = doc.decode(decodeCheckOnly)
+				v, err = rawDoc.decode(decodeDeep)
 			}
 
 		case tagArray:
-			if err = decodeCheckOffset(raw, offset, 4); err != nil {
-				return nil, lazyerrors.Error(err)
-			}
-
-			l := int(binary.LittleEndian.Uint32(raw[offset:]))
-
-			if err = decodeCheckOffset(raw, offset, l); err != nil {
-				return nil, lazyerrors.Error(err)
+			if l, err = FindRaw(raw[offset:]); err != nil {
+				return nil, lazyerrors.Errorf("no array at offset = %d: %w", offset, err)
 			}
 
-			raw := RawArray(raw[offset : offset+l])
+			rawArr := RawArray(raw[offset : offset+l])
 			offset += l
 
 			switch mode {
 			case decodeShallow:
-				v = raw
+				v = rawArr
 			case decodeDeep:
-				v, err = raw.decode(decodeDeep)
-			case decodeCheckOnly:
-				_, err = raw.decode(decodeCheckOnly)
+				v, err = rawArr.decode(decodeDeep)
 			}
 
-		case tagBinary:
-			var s Binary
-			s, err = bsonproto.DecodeBinary(raw[offset:])
-
offset += bsonproto.SizeBinary(s) - v = s - - case tagObjectID: - v, err = bsonproto.DecodeObjectID(raw[offset:]) - offset += bsonproto.SizeObjectID - - case tagBool: - v, err = bsonproto.DecodeBool(raw[offset:]) - offset += bsonproto.SizeBool - - case tagTime: - v, err = bsonproto.DecodeTime(raw[offset:]) - offset += bsonproto.SizeTime - - case tagNull: - v = Null - - case tagRegex: - var s Regex - s, err = bsonproto.DecodeRegex(raw[offset:]) - offset += bsonproto.SizeRegex(s) - v = s - - case tagInt32: - v, err = bsonproto.DecodeInt32(raw[offset:]) - offset += bsonproto.SizeInt32 - - case tagTimestamp: - v, err = bsonproto.DecodeTimestamp(raw[offset:]) - offset += bsonproto.SizeTimestamp - - case tagInt64: - v, err = bsonproto.DecodeInt64(raw[offset:]) - offset += bsonproto.SizeInt64 - - case tagUndefined, tagDBPointer, tagJavaScript, tagSymbol, tagJavaScriptScope, tagDecimal, tagMinKey, tagMaxKey: - return nil, lazyerrors.Errorf("unsupported tag: %s", t) - default: - return nil, lazyerrors.Errorf("unexpected tag: %s", t) + v, l, err = decodeScalarField(raw[offset:], t) + offset += l } if err != nil { return nil, lazyerrors.Error(err) } - if mode != decodeCheckOnly { - must.NoError(res.add(name, v)) - } + must.NoError(res.add(name, v)) } - - return res, nil -} - -// decodeCheckOffset checks that b has enough bytes to decode size bytes starting from offset. -func decodeCheckOffset(b []byte, offset, size int) error { - if len(b[offset:]) < size+1 { - return lazyerrors.Errorf("offset = %d, size = %d: %w", offset, size, ErrDecodeShortInput) - } - - return nil } From b0b7a58bab48bc1f80abf0cc5e736ec33c8017b5 Mon Sep 17 00:00:00 2001 From: Alexey Palazhchenko Date: Thu, 22 Feb 2024 05:23:35 +0400 Subject: [PATCH 11/13] Use `bson2` package for wire queries and replies (#4108) --- internal/handler/cmd_query.go | 4 +- internal/handler/common/ismaster.go | 4 +- internal/handler/sjson/sjson_test.go | 7 +- internal/wire/msg_body.go | 11 +- internal/wire/op_msg.go | 18 +- internal/wire/op_msg_test.go | 823 ++++++++++++++------------- internal/wire/op_query.go | 167 +++--- internal/wire/op_query_test.go | 149 ++--- internal/wire/op_reply.go | 156 +++-- internal/wire/op_reply_test.go | 131 ++--- internal/wire/wire_test.go | 9 + 11 files changed, 758 insertions(+), 721 deletions(-) diff --git a/internal/handler/cmd_query.go b/internal/handler/cmd_query.go index c1c563938ddc..57fb7ab8a3c1 100644 --- a/internal/handler/cmd_query.go +++ b/internal/handler/cmd_query.go @@ -41,9 +41,7 @@ func (h *Handler) CmdQuery(ctx context.Context, query *wire.OpQuery) (*wire.OpRe if cmd == "saslStart" && strings.HasSuffix(collection, ".$cmd") { var emptyPayload types.Binary - reply := wire.OpReply{ - NumberReturned: 1, - } + var reply wire.OpReply reply.SetDocument(must.NotFail(types.NewDocument( "conversationId", int32(1), "done", true, diff --git a/internal/handler/common/ismaster.go b/internal/handler/common/ismaster.go index 69c83f5a58cc..acba20d5c985 100644 --- a/internal/handler/common/ismaster.go +++ b/internal/handler/common/ismaster.go @@ -31,9 +31,7 @@ func IsMaster(ctx context.Context, query *types.Document, tcpHost, name string) return nil, lazyerrors.Error(err) } - reply := wire.OpReply{ - NumberReturned: 1, - } + var reply wire.OpReply reply.SetDocument(IsMasterDocument(tcpHost, name)) return &reply, nil diff --git a/internal/handler/sjson/sjson_test.go b/internal/handler/sjson/sjson_test.go index da040a29c454..bf58fbc65f24 100644 --- a/internal/handler/sjson/sjson_test.go +++ 
b/internal/handler/sjson/sjson_test.go @@ -216,7 +216,12 @@ func addRecordedFuzzDocs(f *testing.F, needDocument, needSchema bool) int { } case *wire.OpReply: - docs = append(docs, b.Documents()...) + doc, err := b.Document() + require.NoError(f, err) + + if doc != nil { + docs = append(docs, doc) + } } for _, doc := range docs { diff --git a/internal/wire/msg_body.go b/internal/wire/msg_body.go index 005dda69ed80..658441ac890c 100644 --- a/internal/wire/msg_body.go +++ b/internal/wire/msg_body.go @@ -32,8 +32,9 @@ import ( type MsgBody interface { msgbody() // seal for sumtype - readFrom(*bufio.Reader) error - encoding.BinaryUnmarshaler + // UnmarshalBinaryNocopy is a variant of [encoding.BinaryUnmarshaler] that does not have to copy the data. + UnmarshalBinaryNocopy([]byte) error + encoding.BinaryMarshaler fmt.Stringer } @@ -59,7 +60,7 @@ func ReadMessage(r *bufio.Reader) (*MsgHeader, MsgBody, error) { switch header.OpCode { case OpCodeReply: // not sent by clients, but we should be able to read replies from a proxy var reply OpReply - if err := reply.UnmarshalBinary(b); err != nil { + if err := reply.UnmarshalBinaryNocopy(b); err != nil { return nil, nil, lazyerrors.Error(err) } @@ -71,7 +72,7 @@ func ReadMessage(r *bufio.Reader) (*MsgHeader, MsgBody, error) { } var msg OpMsg - if err := msg.UnmarshalBinary(b); err != nil { + if err := msg.UnmarshalBinaryNocopy(b); err != nil { return &header, nil, lazyerrors.Error(err) } @@ -79,7 +80,7 @@ func ReadMessage(r *bufio.Reader) (*MsgHeader, MsgBody, error) { case OpCodeQuery: var query OpQuery - if err := query.UnmarshalBinary(b); err != nil { + if err := query.UnmarshalBinaryNocopy(b); err != nil { return nil, nil, lazyerrors.Error(err) } diff --git a/internal/wire/op_msg.go b/internal/wire/op_msg.go index 8fa3baae3a2b..2801d66ef4e0 100644 --- a/internal/wire/op_msg.go +++ b/internal/wire/op_msg.go @@ -164,7 +164,11 @@ func (msg *OpMsg) RawDocument() (bson2.RawDocument, error) { func (msg *OpMsg) msgbody() {} -func (msg *OpMsg) readFrom(bufr *bufio.Reader) error { +// UnmarshalBinaryNocopy implements [MsgBody] interface. +func (msg *OpMsg) UnmarshalBinaryNocopy(b []byte) error { + br := bytes.NewReader(b) + bufr := bufio.NewReader(br) + if err := binary.Read(bufr, binary.LittleEndian, &msg.FlagBits); err != nil { return lazyerrors.Error(err) } @@ -255,18 +259,6 @@ func (msg *OpMsg) readFrom(bufr *bufio.Reader) error { return err } - return nil -} - -// UnmarshalBinary reads an OpMsg from a byte array. 
-func (msg *OpMsg) UnmarshalBinary(b []byte) error { - br := bytes.NewReader(b) - bufr := bufio.NewReader(br) - - if err := msg.readFrom(bufr); err != nil { - return lazyerrors.Error(err) - } - if _, err := bufr.Peek(1); err != io.EOF { return lazyerrors.Errorf("unexpected end of the OpMsg: %v", err) } diff --git a/internal/wire/op_msg_test.go b/internal/wire/op_msg_test.go index d812b8fde506..d371f8f59eb1 100644 --- a/internal/wire/op_msg_test.go +++ b/internal/wire/op_msg_test.go @@ -23,430 +23,455 @@ import ( "github.com/FerretDB/FerretDB/internal/util/testutil" ) -var msgTestCases = []testCase{{ - name: "handshake5", - headerB: testutil.MustParseDumpFile("testdata", "handshake5_header.hex"), - bodyB: testutil.MustParseDumpFile("testdata", "handshake5_body.hex"), - msgHeader: &MsgHeader{ - MessageLength: 92, - RequestID: 3, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - sections: []OpMsgSection{{ - documents: []*types.Document{must.NotFail(types.NewDocument( - "buildInfo", int32(1), - "lsid", must.NotFail(types.NewDocument( - "id", types.Binary{ - Subtype: types.BinaryUUID, - B: []byte{ - 0xa3, 0x19, 0xf2, 0xb4, 0xa1, 0x75, 0x40, 0xc7, - 0xb8, 0xe7, 0xa3, 0xa3, 0x2e, 0xc2, 0x56, 0xbe, +var msgTestCases = []testCase{ + { + name: "handshake5", + headerB: testutil.MustParseDumpFile("testdata", "handshake5_header.hex"), + bodyB: testutil.MustParseDumpFile("testdata", "handshake5_body.hex"), + msgHeader: &MsgHeader{ + MessageLength: 92, + RequestID: 3, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + sections: []OpMsgSection{{ + documents: []*types.Document{must.NotFail(types.NewDocument( + "buildInfo", int32(1), + "lsid", must.NotFail(types.NewDocument( + "id", types.Binary{ + Subtype: types.BinaryUUID, + B: []byte{ + 0xa3, 0x19, 0xf2, 0xb4, 0xa1, 0x75, 0x40, 0xc7, + 0xb8, 0xe7, 0xa3, 0xa3, 0x2e, 0xc2, 0x56, 0xbe, + }, }, - }, - )), - "$db", "admin", - ))}, - }}, - }, - command: "buildInfo", -}, { - name: "handshake6", - headerB: testutil.MustParseDumpFile("testdata", "handshake6_header.hex"), - bodyB: testutil.MustParseDumpFile("testdata", "handshake6_body.hex"), - msgHeader: &MsgHeader{ - MessageLength: 1931, - RequestID: 292, - ResponseTo: 3, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - sections: []OpMsgSection{{ - documents: []*types.Document{must.NotFail(types.NewDocument( - "version", "5.0.0", - "gitVersion", "1184f004a99660de6f5e745573419bda8a28c0e9", - "modules", must.NotFail(types.NewArray()), - "allocator", "tcmalloc", - "javascriptEngine", "mozjs", - "sysInfo", "deprecated", - "versionArray", must.NotFail(types.NewArray(int32(5), int32(0), int32(0), int32(0))), - "openssl", must.NotFail(types.NewDocument( - "running", "OpenSSL 1.1.1f 31 Mar 2020", - "compiled", "OpenSSL 1.1.1f 31 Mar 2020", - )), - "buildEnvironment", must.NotFail(types.NewDocument( - "distmod", "ubuntu2004", - "distarch", "x86_64", - "cc", "/opt/mongodbtoolchain/v3/bin/gcc: gcc (GCC) 8.5.0", - "ccflags", "-Werror -include mongo/platform/basic.h -fasynchronous-unwind-tables "+ - "-ggdb -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -fno-omit-frame-pointer "+ - "-fno-strict-aliasing -O2 -march=sandybridge -mtune=generic -mprefer-vector-width=128 "+ - "-Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations "+ - "-Wno-unused-const-variable -Wno-unused-but-set-variable -Wno-missing-braces "+ - "-fstack-protector-strong -Wa,--nocompress-debug-sections -fno-builtin-memcmp", - "cxx", "/opt/mongodbtoolchain/v3/bin/g++: g++ (GCC) 8.5.0", - "cxxflags", "-Woverloaded-virtual 
-Wno-maybe-uninitialized -fsized-deallocation -std=c++17", - "linkflags", "-Wl,--fatal-warnings -pthread -Wl,-z,now -fuse-ld=gold -fstack-protector-strong "+ - "-Wl,--no-threads -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack "+ - "-Wl,-z,relro -Wl,--compress-debug-sections=none -Wl,-z,origin -Wl,--enable-new-dtags", - "target_arch", "x86_64", - "target_os", "linux", - "cppdefines", "SAFEINT_USE_INTRINSICS 0 PCRE_STATIC NDEBUG _XOPEN_SOURCE 700 "+ - "_GNU_SOURCE _REENTRANT 1 _FORTIFY_SOURCE 2 BOOST_THREAD_VERSION 5 "+ - "BOOST_THREAD_USES_DATETIME BOOST_SYSTEM_NO_DEPRECATED "+ - "BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS BOOST_ENABLE_ASSERT_DEBUG_HANDLER "+ - "BOOST_LOG_NO_SHORTHAND_NAMES BOOST_LOG_USE_NATIVE_SYSLOG "+ - "BOOST_LOG_WITHOUT_THREAD_ATTR ABSL_FORCE_ALIGNED_ACCESS", - )), - "bits", int32(64), - "debug", false, - "maxBsonObjectSize", int32(16777216), - "storageEngines", must.NotFail(types.NewArray("devnull", "ephemeralForTest", "wiredTiger")), - "ok", float64(1), - ))}, - }}, + )), + "$db", "admin", + ))}, + }}, + }, + command: "buildInfo", }, - command: "version", -}, { - name: "import", - expectedB: testutil.MustParseDumpFile("testdata", "import.hex"), - msgHeader: &MsgHeader{ - MessageLength: 327, - RequestID: 7, - OpCode: OpCodeMsg, + { + name: "handshake6", + headerB: testutil.MustParseDumpFile("testdata", "handshake6_header.hex"), + bodyB: testutil.MustParseDumpFile("testdata", "handshake6_body.hex"), + msgHeader: &MsgHeader{ + MessageLength: 1931, + RequestID: 292, + ResponseTo: 3, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + sections: []OpMsgSection{{ + documents: []*types.Document{must.NotFail(types.NewDocument( + "version", "5.0.0", + "gitVersion", "1184f004a99660de6f5e745573419bda8a28c0e9", + "modules", must.NotFail(types.NewArray()), + "allocator", "tcmalloc", + "javascriptEngine", "mozjs", + "sysInfo", "deprecated", + "versionArray", must.NotFail(types.NewArray(int32(5), int32(0), int32(0), int32(0))), + "openssl", must.NotFail(types.NewDocument( + "running", "OpenSSL 1.1.1f 31 Mar 2020", + "compiled", "OpenSSL 1.1.1f 31 Mar 2020", + )), + "buildEnvironment", must.NotFail(types.NewDocument( + "distmod", "ubuntu2004", + "distarch", "x86_64", + "cc", "/opt/mongodbtoolchain/v3/bin/gcc: gcc (GCC) 8.5.0", + "ccflags", "-Werror -include mongo/platform/basic.h -fasynchronous-unwind-tables "+ + "-ggdb -Wall -Wsign-compare -Wno-unknown-pragmas -Winvalid-pch -fno-omit-frame-pointer "+ + "-fno-strict-aliasing -O2 -march=sandybridge -mtune=generic -mprefer-vector-width=128 "+ + "-Wno-unused-local-typedefs -Wno-unused-function -Wno-deprecated-declarations "+ + "-Wno-unused-const-variable -Wno-unused-but-set-variable -Wno-missing-braces "+ + "-fstack-protector-strong -Wa,--nocompress-debug-sections -fno-builtin-memcmp", + "cxx", "/opt/mongodbtoolchain/v3/bin/g++: g++ (GCC) 8.5.0", + "cxxflags", "-Woverloaded-virtual -Wno-maybe-uninitialized -fsized-deallocation -std=c++17", + "linkflags", "-Wl,--fatal-warnings -pthread -Wl,-z,now -fuse-ld=gold -fstack-protector-strong "+ + "-Wl,--no-threads -Wl,--build-id -Wl,--hash-style=gnu -Wl,-z,noexecstack -Wl,--warn-execstack "+ + "-Wl,-z,relro -Wl,--compress-debug-sections=none -Wl,-z,origin -Wl,--enable-new-dtags", + "target_arch", "x86_64", + "target_os", "linux", + "cppdefines", "SAFEINT_USE_INTRINSICS 0 PCRE_STATIC NDEBUG _XOPEN_SOURCE 700 "+ + "_GNU_SOURCE _REENTRANT 1 _FORTIFY_SOURCE 2 BOOST_THREAD_VERSION 5 "+ + "BOOST_THREAD_USES_DATETIME BOOST_SYSTEM_NO_DEPRECATED "+ + 
"BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS BOOST_ENABLE_ASSERT_DEBUG_HANDLER "+ + "BOOST_LOG_NO_SHORTHAND_NAMES BOOST_LOG_USE_NATIVE_SYSLOG "+ + "BOOST_LOG_WITHOUT_THREAD_ATTR ABSL_FORCE_ALIGNED_ACCESS", + )), + "bits", int32(64), + "debug", false, + "maxBsonObjectSize", int32(16777216), + "storageEngines", must.NotFail(types.NewArray("devnull", "ephemeralForTest", "wiredTiger")), + "ok", float64(1), + ))}, + }}, + }, + command: "version", }, - msgBody: &OpMsg{ - sections: []OpMsgSection{{ - documents: []*types.Document{must.NotFail(types.NewDocument( - "insert", "actor", - "ordered", true, - "writeConcern", must.NotFail(types.NewDocument( - "w", "majority", - )), - "$db", "monila", - ))}, - }, { - Kind: 1, - Identifier: "documents", - documents: []*types.Document{ - must.NotFail(types.NewDocument( - "_id", types.ObjectID{0x61, 0x2e, 0xc2, 0x80, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}, - "actor_id", int32(1), - "first_name", "PENELOPE", - "last_name", "GUINESS", - "last_update", lastUpdate, - )), - must.NotFail(types.NewDocument( - "_id", types.ObjectID{0x61, 0x2e, 0xc2, 0x80, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02}, - "actor_id", int32(2), - "first_name", "NICK", - "last_name", "WAHLBERG", - "last_update", lastUpdate, - )), + { + name: "import", + expectedB: testutil.MustParseDumpFile("testdata", "import.hex"), + msgHeader: &MsgHeader{ + MessageLength: 327, + RequestID: 7, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + sections: []OpMsgSection{ + { + documents: []*types.Document{must.NotFail(types.NewDocument( + "insert", "actor", + "ordered", true, + "writeConcern", must.NotFail(types.NewDocument( + "w", "majority", + )), + "$db", "monila", + ))}, + }, + { + Kind: 1, + Identifier: "documents", + documents: []*types.Document{ + must.NotFail(types.NewDocument( + "_id", types.ObjectID{0x61, 0x2e, 0xc2, 0x80, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01}, + "actor_id", int32(1), + "first_name", "PENELOPE", + "last_name", "GUINESS", + "last_update", lastUpdate, + )), + must.NotFail(types.NewDocument( + "_id", types.ObjectID{0x61, 0x2e, 0xc2, 0x80, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02}, + "actor_id", int32(2), + "first_name", "NICK", + "last_name", "WAHLBERG", + "last_update", lastUpdate, + )), + }, + }, }, - }}, - }, - command: "insert", -}, { - name: "msg_fuzz1", - expectedB: testutil.MustParseDumpFile("testdata", "msg_fuzz1.hex"), - err: `wire.OpMsg.readFrom: invalid kind 1 section length -13619152`, -}, { - name: "NaN", - expectedB: []byte{ - 0x79, 0x00, 0x00, 0x00, // MessageLength - 0x11, 0x00, 0x00, 0x00, // RequestID - 0x00, 0x00, 0x00, 0x00, // ResponseTo - 0xdd, 0x07, 0x00, 0x00, // OpCode - 0x00, 0x00, 0x00, 0x00, // FlagBits - 0x00, // section kind - 0x64, 0x00, 0x00, 0x00, // document size - 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" - 0x07, 0x00, 0x00, 0x00, // "values" length - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, // "values" - 0x04, 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // array "documents" - 0x29, 0x00, 0x00, 0x00, // document (array) size - 0x03, 0x30, 0x00, // element 0 (document) - 0x21, 0x00, 0x00, 0x00, // element 0 size - 0x01, 0x76, 0x00, // double "v" - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0x7f, // NaN - 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" - 0x63, 0x77, 0xf2, 0x13, 0x75, 0x7c, 0x0b, 0xab, 0xde, 0xbc, 0x2f, 0x6a, // ObjectID value - 0x00, // end of element 0 (document) - 0x00, // end of document (array) - 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" 
true - 0x02, 0x24, 0x64, 0x62, 0x00, // "$db" - 0x05, 0x00, 0x00, 0x00, // "test" length - 0x74, 0x65, 0x73, 0x74, 0x00, // "test" - 0x00, // end of document + }, + command: "insert", }, - msgHeader: &MsgHeader{ - MessageLength: 121, - RequestID: 17, - OpCode: OpCodeMsg, + { + name: "msg_fuzz1", + expectedB: testutil.MustParseDumpFile("testdata", "msg_fuzz1.hex"), + err: `wire.OpMsg.readFrom: invalid kind 1 section length -13619152`, }, - msgBody: &OpMsg{ - sections: []OpMsgSection{{ - documents: []*types.Document{must.NotFail(types.NewDocument( - "insert", "values", - "documents", must.NotFail(types.NewArray( - must.NotFail(types.NewDocument( - "v", math.NaN(), - "_id", types.ObjectID{0x63, 0x77, 0xf2, 0x13, 0x75, 0x7c, 0x0b, 0xab, 0xde, 0xbc, 0x2f, 0x6a}, + { + name: "NaN", + expectedB: []byte{ + 0x79, 0x00, 0x00, 0x00, // MessageLength + 0x11, 0x00, 0x00, 0x00, // RequestID + 0x00, 0x00, 0x00, 0x00, // ResponseTo + 0xdd, 0x07, 0x00, 0x00, // OpCode + 0x00, 0x00, 0x00, 0x00, // FlagBits + 0x00, // section kind + 0x64, 0x00, 0x00, 0x00, // document size + 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" + 0x07, 0x00, 0x00, 0x00, // "values" length + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x00, // "values" + 0x04, 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // array "documents" + 0x29, 0x00, 0x00, 0x00, // document (array) size + 0x03, 0x30, 0x00, // element 0 (document) + 0x21, 0x00, 0x00, 0x00, // element 0 size + 0x01, 0x76, 0x00, // double "v" + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf8, 0x7f, // NaN + 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" + 0x63, 0x77, 0xf2, 0x13, 0x75, 0x7c, 0x0b, 0xab, 0xde, 0xbc, 0x2f, 0x6a, // ObjectID value + 0x00, // end of element 0 (document) + 0x00, // end of document (array) + 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true + 0x02, 0x24, 0x64, 0x62, 0x00, // "$db" + 0x05, 0x00, 0x00, 0x00, // "test" length + 0x74, 0x65, 0x73, 0x74, 0x00, // "test" + 0x00, // end of document + }, + msgHeader: &MsgHeader{ + MessageLength: 121, + RequestID: 17, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + sections: []OpMsgSection{{ + documents: []*types.Document{must.NotFail(types.NewDocument( + "insert", "values", + "documents", must.NotFail(types.NewArray( + must.NotFail(types.NewDocument( + "v", math.NaN(), + "_id", types.ObjectID{0x63, 0x77, 0xf2, 0x13, 0x75, 0x7c, 0x0b, 0xab, 0xde, 0xbc, 0x2f, 0x6a}, + )), )), - )), - "ordered", true, - "$db", "test", - ))}, - }}, + "ordered", true, + "$db", "test", + ))}, + }}, + }, + err: `wire.OpMsg.Document: validation failed for { insert: "values", documents: ` + + `[ { v: nan.0, _id: ObjectId('6377f213757c0babdebc2f6a') } ], ordered: true, $db: "test" }` + + ` with: NaN is not supported`, }, - err: `wire.OpMsg.Document: validation failed for { insert: "values", documents: ` + - `[ { v: nan.0, _id: ObjectId('6377f213757c0babdebc2f6a') } ], ordered: true, $db: "test" }` + - ` with: NaN is not supported`, -}, { - name: "negative zero", - expectedB: []byte{ - 0x8b, 0x00, 0x00, 0x00, - 0x0c, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0xdd, 0x07, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, - 0x46, 0x00, 0x00, 0x00, - 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, - 0x11, 0x00, 0x00, 0x00, - 0x54, 0x65, 0x73, 0x74, 0x49, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x00, - 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, - 0x64, 0x00, 0x01, 0x02, 0x24, 0x64, 0x62, 0x00, 0x11, 0x00, - 0x00, 0x00, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x73, 0x65, - 0x72, 
0x74, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x00, 0x00, - 0x01, - 0x2f, 0x00, 0x00, 0x00, - 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, - 0x21, 0x00, 0x00, 0x00, - 0x07, 0x5f, 0x69, 0x64, 0x00, - 0x63, 0x7c, 0xfa, 0xd8, 0x8d, 0xc3, 0xce, 0xcd, 0xe3, 0x8e, 0x1e, 0x6b, - 0x01, 0x76, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, - 0x00, - }, - msgHeader: &MsgHeader{ - MessageLength: 139, - RequestID: 12, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - sections: []OpMsgSection{{ - documents: []*types.Document{must.NotFail(types.NewDocument( - "insert", "TestInsertSimple", - "ordered", true, - "$db", "testinsertsimple", - ))}, - }, { - Kind: 1, - Identifier: "documents", - documents: []*types.Document{must.NotFail(types.NewDocument( - "_id", types.ObjectID{0x63, 0x7c, 0xfa, 0xd8, 0x8d, 0xc3, 0xce, 0xcd, 0xe3, 0x8e, 0x1e, 0x6b}, - "v", math.Copysign(0, -1), - ))}, - }}, + { + name: "negative zero", + expectedB: []byte{ + 0x8b, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0xdd, 0x07, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, + 0x46, 0x00, 0x00, 0x00, + 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, + 0x11, 0x00, 0x00, 0x00, + 0x54, 0x65, 0x73, 0x74, 0x49, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x00, + 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, + 0x64, 0x00, 0x01, 0x02, 0x24, 0x64, 0x62, 0x00, 0x11, 0x00, + 0x00, 0x00, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x73, 0x65, + 0x72, 0x74, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x00, 0x00, + 0x01, + 0x2f, 0x00, 0x00, 0x00, + 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, + 0x21, 0x00, 0x00, 0x00, + 0x07, 0x5f, 0x69, 0x64, 0x00, + 0x63, 0x7c, 0xfa, 0xd8, 0x8d, 0xc3, 0xce, 0xcd, 0xe3, 0x8e, 0x1e, 0x6b, + 0x01, 0x76, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, + 0x00, + }, + msgHeader: &MsgHeader{ + MessageLength: 139, + RequestID: 12, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + sections: []OpMsgSection{ + { + documents: []*types.Document{must.NotFail(types.NewDocument( + "insert", "TestInsertSimple", + "ordered", true, + "$db", "testinsertsimple", + ))}, + }, + { + Kind: 1, + Identifier: "documents", + documents: []*types.Document{must.NotFail(types.NewDocument( + "_id", types.ObjectID{0x63, 0x7c, 0xfa, 0xd8, 0x8d, 0xc3, 0xce, 0xcd, 0xe3, 0x8e, 0x1e, 0x6b}, + "v", math.Copysign(0, -1), + ))}, + }, + }, + }, + command: "insert", }, - command: "insert", -}, { - name: "MultiSectionInsert", - expectedB: []byte{ - 0x76, 0x00, 0x00, 0x00, // MessageLength - 0x0f, 0x00, 0x00, 0x00, // RequestID - 0x00, 0x00, 0x00, 0x00, // ResponseTo - 0xdd, 0x07, 0x00, 0x00, // OpCode - 0x01, 0x00, 0x00, 0x00, // FlagBits + { + name: "MultiSectionInsert", + expectedB: []byte{ + 0x76, 0x00, 0x00, 0x00, // MessageLength + 0x0f, 0x00, 0x00, 0x00, // RequestID + 0x00, 0x00, 0x00, 0x00, // ResponseTo + 0xdd, 0x07, 0x00, 0x00, // OpCode + 0x01, 0x00, 0x00, 0x00, // FlagBits - 0x01, // section kind - 0x2f, 0x00, 0x00, 0x00, // section size - 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // section identifier "documents" - 0x21, 0x00, 0x00, 0x00, // document size - 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" - 0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29, // ObjectID value - 0x01, 0x61, 0x00, // double "a" - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x40, // 3.0 - 0x00, // end of document + 0x01, // section kind + 0x2f, 0x00, 0x00, 0x00, // section size + 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // section identifier 
"documents" + 0x21, 0x00, 0x00, 0x00, // document size + 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" + 0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29, // ObjectID value + 0x01, 0x61, 0x00, // double "a" + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x40, // 3.0 + 0x00, // end of document - 0x00, // section kind - 0x2d, 0x00, 0x00, 0x00, // document size - 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" - 0x04, 0x00, 0x00, 0x00, // "foo" length - 0x66, 0x6f, 0x6f, 0x00, // "foo" - 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true - 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" - 0x05, 0x00, 0x00, 0x00, // "test" length - 0x74, 0x65, 0x73, 0x74, 0x00, // "test" - 0x00, // end of document + 0x00, // section kind + 0x2d, 0x00, 0x00, 0x00, // document size + 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" + 0x04, 0x00, 0x00, 0x00, // "foo" length + 0x66, 0x6f, 0x6f, 0x00, // "foo" + 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true + 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" + 0x05, 0x00, 0x00, 0x00, // "test" length + 0x74, 0x65, 0x73, 0x74, 0x00, // "test" + 0x00, // end of document - 0xe2, 0xb7, 0x90, 0x67, // checksum - }, - msgHeader: &MsgHeader{ - MessageLength: 118, - RequestID: 15, - ResponseTo: 0, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - FlagBits: OpMsgFlags(OpMsgChecksumPresent), - sections: []OpMsgSection{{ - Kind: 1, - Identifier: "documents", - documents: []*types.Document{must.NotFail(types.NewDocument( - "_id", types.ObjectID{0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29}, - "a", float64(3), - ))}, - }, { - documents: []*types.Document{must.NotFail(types.NewDocument( - "insert", "foo", - "ordered", true, - "$db", "test", - ))}, - }}, - checksum: 1737537506, + 0xe2, 0xb7, 0x90, 0x67, // checksum + }, + msgHeader: &MsgHeader{ + MessageLength: 118, + RequestID: 15, + ResponseTo: 0, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + FlagBits: OpMsgFlags(OpMsgChecksumPresent), + sections: []OpMsgSection{ + { + Kind: 1, + Identifier: "documents", + documents: []*types.Document{must.NotFail(types.NewDocument( + "_id", types.ObjectID{0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29}, + "a", float64(3), + ))}, + }, + { + documents: []*types.Document{must.NotFail(types.NewDocument( + "insert", "foo", + "ordered", true, + "$db", "test", + ))}, + }, + }, + checksum: 1737537506, + }, + command: "insert", }, - command: "insert", -}, { - name: "MultiSectionUpdate", - expectedB: []byte{ - 0x9a, 0x00, 0x00, 0x00, // MessageLength - 0x0b, 0x00, 0x00, 0x00, // RequestID - 0x00, 0x00, 0x00, 0x00, // ResponseTo - 0xdd, 0x07, 0x00, 0x00, // OpCode - 0x01, 0x00, 0x00, 0x00, // FlagBits + { + name: "MultiSectionUpdate", + expectedB: []byte{ + 0x9a, 0x00, 0x00, 0x00, // MessageLength + 0x0b, 0x00, 0x00, 0x00, // RequestID + 0x00, 0x00, 0x00, 0x00, // ResponseTo + 0xdd, 0x07, 0x00, 0x00, // OpCode + 0x01, 0x00, 0x00, 0x00, // FlagBits - 0x01, // section kind - 0x53, 0x00, 0x00, 0x00, // section size - 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x73, 0x00, // section identifier "updates" - 0x47, 0x00, 0x00, 0x00, // document size + 0x01, // section kind + 0x53, 0x00, 0x00, 0x00, // section size + 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x73, 0x00, // section identifier "updates" + 0x47, 0x00, 0x00, 0x00, // document size - 0x03, 0x71, 0x00, // document "q" - 0x10, 0x00, 0x00, 0x00, // document size - 0x01, 0x61, 0x00, // double "a" - 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x34, 0x40, // 20.0 - 0x00, // end of document + 0x03, 0x71, 0x00, // document "q" + 0x10, 0x00, 0x00, 0x00, // document size + 0x01, 0x61, 0x00, // double "a" + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x34, 0x40, // 20.0 + 0x00, // end of document - 0x03, 0x75, 0x00, // document "u" - 0x1b, 0x00, 0x00, 0x00, // document size - 0x03, 0x24, 0x69, 0x6e, 0x63, 0x00, // document "$inc" - 0x10, 0x00, 0x00, 0x00, // document size - 0x01, 0x61, 0x00, // double "a" - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, // 1.0 - 0x00, // end of document - 0x00, // end of document + 0x03, 0x75, 0x00, // document "u" + 0x1b, 0x00, 0x00, 0x00, // document size + 0x03, 0x24, 0x69, 0x6e, 0x63, 0x00, // document "$inc" + 0x10, 0x00, 0x00, 0x00, // document size + 0x01, 0x61, 0x00, // double "a" + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, // 1.0 + 0x00, // end of document + 0x00, // end of document - 0x08, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x00, 0x00, // "multi" false - 0x08, 0x75, 0x70, 0x73, 0x65, 0x72, 0x74, 0x00, 0x00, // "upsert" false + 0x08, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x00, 0x00, // "multi" false + 0x08, 0x75, 0x70, 0x73, 0x65, 0x72, 0x74, 0x00, 0x00, // "upsert" false - 0x00, // end of document + 0x00, // end of document - 0x00, // section kind - 0x2d, 0x00, 0x00, 0x00, // document size - 0x02, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x00, // string "update" - 0x04, 0x00, 0x00, 0x00, // "foo" length - 0x66, 0x6f, 0x6f, 0x00, // "foo" - 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true - 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" - 0x05, 0x00, 0x00, 0x00, // "test" length - 0x74, 0x65, 0x73, 0x74, 0x00, // "test" - 0x00, // end of document + 0x00, // section kind + 0x2d, 0x00, 0x00, 0x00, // document size + 0x02, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x00, // string "update" + 0x04, 0x00, 0x00, 0x00, // "foo" length + 0x66, 0x6f, 0x6f, 0x00, // "foo" + 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true + 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" + 0x05, 0x00, 0x00, 0x00, // "test" length + 0x74, 0x65, 0x73, 0x74, 0x00, // "test" + 0x00, // end of document - 0xf1, 0xfc, 0xd1, 0xae, // checksum - }, - msgHeader: &MsgHeader{ - MessageLength: 154, - RequestID: 11, - ResponseTo: 0, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - FlagBits: OpMsgFlags(OpMsgChecksumPresent), - sections: []OpMsgSection{{ - Kind: 1, - Identifier: "updates", - documents: []*types.Document{must.NotFail(types.NewDocument( - "q", must.NotFail(types.NewDocument( - "a", float64(20), - )), - "u", must.NotFail(types.NewDocument( - "$inc", must.NotFail(types.NewDocument( - "a", float64(1), - )), - )), - "multi", false, - "upsert", false, - ))}, - }, { - documents: []*types.Document{must.NotFail(types.NewDocument( - "update", "foo", - "ordered", true, - "$db", "test", - ))}, - }}, - checksum: 2932997361, + 0xf1, 0xfc, 0xd1, 0xae, // checksum + }, + msgHeader: &MsgHeader{ + MessageLength: 154, + RequestID: 11, + ResponseTo: 0, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + FlagBits: OpMsgFlags(OpMsgChecksumPresent), + sections: []OpMsgSection{ + { + Kind: 1, + Identifier: "updates", + documents: []*types.Document{must.NotFail(types.NewDocument( + "q", must.NotFail(types.NewDocument( + "a", float64(20), + )), + "u", must.NotFail(types.NewDocument( + "$inc", must.NotFail(types.NewDocument( + "a", float64(1), + )), + )), + "multi", false, + "upsert", false, + ))}, + }, + { + documents: []*types.Document{must.NotFail(types.NewDocument( + "update", "foo", + 
"ordered", true, + "$db", "test", + ))}, + }, + }, + checksum: 2932997361, + }, + command: "update", }, - command: "update", -}, { - name: "InvalidChecksum", - expectedB: []byte{ - 0x77, 0x00, 0x00, 0x00, // MessageLength - 0x0f, 0x00, 0x00, 0x00, // RequestID - 0x00, 0x00, 0x00, 0x00, // ResponseTo - 0xdd, 0x07, 0x00, 0x00, // OpCode - 0x01, 0x00, 0x00, 0x00, // FlagBits + { + name: "InvalidChecksum", + expectedB: []byte{ + 0x77, 0x00, 0x00, 0x00, // MessageLength + 0x0f, 0x00, 0x00, 0x00, // RequestID + 0x00, 0x00, 0x00, 0x00, // ResponseTo + 0xdd, 0x07, 0x00, 0x00, // OpCode + 0x01, 0x00, 0x00, 0x00, // FlagBits - 0x01, // section kind - 0x2f, 0x00, 0x00, 0x00, // section size - 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // section identifier "documents" - 0x21, 0x00, 0x00, 0x00, // document size - 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" - 0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29, // ObjectID value - 0x01, 0x61, 0x00, // double "a" - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x40, // 3.0 - 0x00, // end of document + 0x01, // section kind + 0x2f, 0x00, 0x00, 0x00, // section size + 0x64, 0x6f, 0x63, 0x75, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x00, // section identifier "documents" + 0x21, 0x00, 0x00, 0x00, // document size + 0x07, 0x5f, 0x69, 0x64, 0x00, // ObjectID "_id" + 0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29, // ObjectID value + 0x01, 0x61, 0x00, // double "a" + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x40, // 3.0 + 0x00, // end of document - 0x00, // section kind - 0x2d, 0x00, 0x00, 0x00, // document size - 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" - 0x04, 0x00, 0x00, 0x00, // "foo" length - 0x66, 0x6f, 0x6f, 0x6f, 0x00, // "fooo" - 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true - 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" - 0x05, 0x00, 0x00, 0x00, // "test" length - 0x74, 0x65, 0x73, 0x74, 0x00, // "test" - 0x00, // end of document + 0x00, // section kind + 0x2d, 0x00, 0x00, 0x00, // document size + 0x02, 0x69, 0x6e, 0x73, 0x65, 0x72, 0x74, 0x00, // string "insert" + 0x04, 0x00, 0x00, 0x00, // "foo" length + 0x66, 0x6f, 0x6f, 0x6f, 0x00, // "fooo" + 0x08, 0x6f, 0x72, 0x64, 0x65, 0x72, 0x65, 0x64, 0x00, 0x01, // "ordered" true + 0x02, 0x24, 0x64, 0x62, 0x00, // string "$db" + 0x05, 0x00, 0x00, 0x00, // "test" length + 0x74, 0x65, 0x73, 0x74, 0x00, // "test" + 0x00, // end of document - 0xe2, 0xb7, 0x90, 0x67, // invalid checksum value - }, - msgHeader: &MsgHeader{ - MessageLength: 119, - RequestID: 15, - ResponseTo: 0, - OpCode: OpCodeMsg, - }, - msgBody: &OpMsg{ - FlagBits: OpMsgFlags(OpMsgChecksumPresent), - sections: []OpMsgSection{{ - Kind: 1, - Identifier: "documents", - documents: []*types.Document{must.NotFail(types.NewDocument( - "_id", types.ObjectID{0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 0x70, 0x10, 0x54, 0x29}, - "a", float64(3), - ))}, - }, { - documents: []*types.Document{must.NotFail(types.NewDocument( - "insert", "fooo", - "ordered", true, - "$db", "test", - ))}, - }}, - checksum: 1737537506, + 0xe2, 0xb7, 0x90, 0x67, // invalid checksum value + }, + msgHeader: &MsgHeader{ + MessageLength: 119, + RequestID: 15, + ResponseTo: 0, + OpCode: OpCodeMsg, + }, + msgBody: &OpMsg{ + FlagBits: OpMsgFlags(OpMsgChecksumPresent), + sections: []OpMsgSection{ + { + Kind: 1, + Identifier: "documents", + documents: []*types.Document{must.NotFail(types.NewDocument( + "_id", types.ObjectID{0x63, 0x8c, 0xec, 0x46, 0xaa, 0x77, 0x8b, 0xf3, 
0x70, 0x10, 0x54, 0x29},
+						"a", float64(3),
+					))},
+				},
+				{
+					documents: []*types.Document{must.NotFail(types.NewDocument(
+						"insert", "fooo",
+						"ordered", true,
+						"$db", "test",
+					))},
+				},
+			},
+			checksum: 1737537506,
+		},
+		err: "OP_MSG checksum does not match contents.",
+	},
+}
-	err: "OP_MSG checksum does not match contents.",
-}}
 
 func TestMsg(t *testing.T) {
 	t.Parallel()
diff --git a/internal/wire/op_query.go b/internal/wire/op_query.go
index 5839261537da..068c906736a8 100644
--- a/internal/wire/op_query.go
+++ b/internal/wire/op_query.go
@@ -15,140 +15,139 @@ package wire
 
 import (
-	"bufio"
-	"bytes"
 	"encoding/binary"
 	"encoding/json"
-	"fmt"
-	"io"
 
-	"github.com/FerretDB/FerretDB/internal/bson"
+	"github.com/FerretDB/FerretDB/internal/bson2"
 	"github.com/FerretDB/FerretDB/internal/types"
 	"github.com/FerretDB/FerretDB/internal/types/fjson"
+	"github.com/FerretDB/FerretDB/internal/util/debugbuild"
 	"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
 	"github.com/FerretDB/FerretDB/internal/util/must"
 )
 
-// OpQuery is used to query the database for documents in a collection.
+// OpQuery is a deprecated request message type.
 type OpQuery struct {
 	Flags              OpQueryFlags
 	FullCollectionName string
 	NumberToSkip       int32
 	NumberToReturn     int32
 
-	query                *types.Document
-	returnFieldsSelector *types.Document // may be nil
+	query                bson2.RawDocument
+	returnFieldsSelector bson2.RawDocument // may be nil
 }
 
 func (query *OpQuery) msgbody() {}
 
-// readFrom composes an OpQuery from a buffered reader.
-// It may return ValidationError if the document read from bufr is invalid.
-func (query *OpQuery) readFrom(bufr *bufio.Reader) error {
-	if err := binary.Read(bufr, binary.LittleEndian, &query.Flags); err != nil {
-		return lazyerrors.Errorf("wire.OpQuery.ReadFrom (binary.Read): %w", err)
+
+// check checks if the query is valid.
+func (query *OpQuery) check() error {
+	if !debugbuild.Enabled {
+		return nil
 	}
 
-	var coll bson.CString
-	if err := coll.ReadFrom(bufr); err != nil {
-		return err
+	if d := query.query; d != nil {
+		if _, err := d.DecodeDeep(); err != nil {
+			return lazyerrors.Error(err)
+		}
 	}
-	query.FullCollectionName = string(coll)
 
-	if err := binary.Read(bufr, binary.LittleEndian, &query.NumberToSkip); err != nil {
-		return err
-	}
-	if err := binary.Read(bufr, binary.LittleEndian, &query.NumberToReturn); err != nil {
-		return err
+	if s := query.returnFieldsSelector; s != nil {
+		if _, err := s.DecodeDeep(); err != nil {
+			return lazyerrors.Error(err)
+		}
 	}
 
-	var q bson.Document
+	return nil
+}
 
-	if err := q.ReadFrom(bufr); err != nil {
-		return err
+
+// UnmarshalBinaryNocopy implements [MsgBody] interface.
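// The OP_QUERY body decoded below has a fixed layout: int32 flags, a
// cstring fullCollectionName, int32 numberToSkip, int32 numberToReturn,
// a BSON query document, and an optional BSON returnFieldsSelector.
// For example, with the collection "admin.$cmd" the offsets are:
//
//	[0:4)    flags
//	[4:15)   "admin.$cmd" plus NUL (SizeCString = 11)
//	[15:19)  numberToSkip
//	[19:23)  numberToReturn
//	[23:)    query document, then the optional selector (if bytes remain)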
+func (query *OpQuery) UnmarshalBinaryNocopy(b []byte) error {
+	if len(b) < 4 {
+		return lazyerrors.Errorf("len=%d", len(b))
 	}
 
-	doc := must.NotFail(types.ConvertDocument(&q))
+	query.Flags = OpQueryFlags(binary.LittleEndian.Uint32(b[0:4]))
+
+	var err error
 
-	if err := validateValue(doc); err != nil {
-		return newValidationError(fmt.Errorf("wire.OpQuery.ReadFrom: validation failed for %v with: %v", doc, err))
+	query.FullCollectionName, err = bson2.DecodeCString(b[4:])
+	if err != nil {
+		return lazyerrors.Error(err)
 	}
 
-	query.query = doc
+	numberLow := 4 + bson2.SizeCString(query.FullCollectionName)
+	if len(b) < numberLow+8 {
+		return lazyerrors.Errorf("len=%d, can't unmarshal numbers", len(b))
+	}
 
-	if _, err := bufr.Peek(1); err == nil {
-		var r bson.Document
-		if err := r.ReadFrom(bufr); err != nil {
-			return err
-		}
+	query.NumberToSkip = int32(binary.LittleEndian.Uint32(b[numberLow : numberLow+4]))
+	query.NumberToReturn = int32(binary.LittleEndian.Uint32(b[numberLow+4 : numberLow+8]))
 
-		tr := must.NotFail(types.ConvertDocument(&r))
-		query.returnFieldsSelector = tr
+	l, err := bson2.FindRaw(b[numberLow+8:])
+	if err != nil {
+		return lazyerrors.Error(err)
 	}
+	query.query = b[numberLow+8 : numberLow+8+l]
 
-	return nil
-}
-
-// UnmarshalBinary reads an OpQuery from a byte array.
-func (query *OpQuery) UnmarshalBinary(b []byte) error {
-	br := bytes.NewReader(b)
-	bufr := bufio.NewReader(br)
+	selectorLow := numberLow + 8 + l
+	if len(b) != selectorLow {
+		l, err = bson2.FindRaw(b[selectorLow:])
+		if err != nil {
+			return lazyerrors.Error(err)
+		}
 
-	if err := query.readFrom(bufr); err != nil {
-		return lazyerrors.Errorf("wire.OpQuery.UnmarshalBinary: %w", err)
+		if len(b) != selectorLow+l {
+			return lazyerrors.Errorf("len=%d, expected=%d", len(b), selectorLow+l)
+		}
+		query.returnFieldsSelector = b[selectorLow:]
 	}
 
-	if _, err := bufr.Peek(1); err != io.EOF {
-		return lazyerrors.Errorf("unexpected end of the OpQuery: %v", err)
+	if err := query.check(); err != nil {
+		return lazyerrors.Error(err)
 	}
 
 	return nil
 }
 
-// MarshalBinary writes an OpQuery to a byte array.
+// MarshalBinary implements [MsgBody] interface.
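// MarshalBinary can preallocate the exact body size because both documents
// are already-encoded bson2.RawDocument slices: 12 fixed bytes (flags plus
// the two int32 numbers) plus the cstring name and the raw documents.
// A worked example, assuming the "admin.$cmd" collection above, a 57-byte
// query, and no selector (the document sizes are illustrative):
//
//	len(b) = 12 + 11 + 57 + 0 = 80
//
// Marshaling is then two copies with no re-encoding or intermediate buffer.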
func (query *OpQuery) MarshalBinary() ([]byte, error) { - var buf bytes.Buffer - bufw := bufio.NewWriter(&buf) - - if err := binary.Write(bufw, binary.LittleEndian, query.Flags); err != nil { - return nil, err + if err := query.check(); err != nil { + return nil, lazyerrors.Error(err) } - if err := bson.CString(query.FullCollectionName).WriteTo(bufw); err != nil { - return nil, err - } + nameSize := bson2.SizeCString(query.FullCollectionName) + b := make([]byte, 12+nameSize+len(query.query)+len(query.returnFieldsSelector)) - if err := binary.Write(bufw, binary.LittleEndian, query.NumberToSkip); err != nil { - return nil, err - } - if err := binary.Write(bufw, binary.LittleEndian, query.NumberToReturn); err != nil { - return nil, err - } + binary.LittleEndian.PutUint32(b[0:4], uint32(query.Flags)) - if err := must.NotFail(bson.ConvertDocument(query.query)).WriteTo(bufw); err != nil { - return nil, err - } + nameHigh := 4 + nameSize + bson2.EncodeCString(b[4:nameHigh], query.FullCollectionName) - if query.returnFieldsSelector != nil { - if err := must.NotFail(bson.ConvertDocument(query.returnFieldsSelector)).WriteTo(bufw); err != nil { - return nil, err - } - } + binary.LittleEndian.PutUint32(b[nameHigh:nameHigh+4], uint32(query.NumberToSkip)) + binary.LittleEndian.PutUint32(b[nameHigh+4:nameHigh+8], uint32(query.NumberToReturn)) - if err := bufw.Flush(); err != nil { - return nil, err - } + queryHigh := nameHigh + 8 + len(query.query) + copy(b[nameHigh+8:queryHigh], query.query) + copy(b[queryHigh:], query.returnFieldsSelector) - return buf.Bytes(), nil + return b, nil } // Query returns the query document. func (query *OpQuery) Query() *types.Document { - return query.query + if query.query == nil { + return nil + } + + return must.NotFail(query.query.Convert()) } // ReturnFieldsSelector returns the fields selector document (that may be nil). func (query *OpQuery) ReturnFieldsSelector() *types.Document { - return query.returnFieldsSelector + if query.returnFieldsSelector == nil { + return nil + } + + return must.NotFail(query.returnFieldsSelector.Convert()) } // String returns a string representation for logging. 
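+// A minimal caller-side sketch of the accessors above (error handling
+// elided): decoding work now happens in the accessors, not during
+// unmarshaling.
+//
+//	var q OpQuery
+//	_ = q.UnmarshalBinaryNocopy(body) // aliases body; deep checks run in debug builds only
+//	doc := q.Query()                  // converts the raw query on demand
+//	sel := q.ReturnFieldsSelector()   // nil when the selector is absent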
@@ -162,10 +161,22 @@ func (query *OpQuery) String() string { "FullCollectionName": query.FullCollectionName, "NumberToSkip": query.NumberToSkip, "NumberToReturn": query.NumberToReturn, - "Query": json.RawMessage(must.NotFail(fjson.Marshal(query.query))), } + + doc, err := query.query.Convert() + if err == nil { + m["Query"] = json.RawMessage(must.NotFail(fjson.Marshal(doc))) + } else { + m["QueryError"] = err.Error() + } + if query.returnFieldsSelector != nil { - m["ReturnFieldsSelector"] = json.RawMessage(must.NotFail(fjson.Marshal(query.returnFieldsSelector))) + doc, err = query.returnFieldsSelector.Convert() + if err == nil { + m["ReturnFieldsSelector"] = json.RawMessage(must.NotFail(fjson.Marshal(doc))) + } else { + m["ReturnFieldsSelectorError"] = err.Error() + } } return string(must.NotFail(json.MarshalIndent(m, "", " "))) diff --git a/internal/wire/op_query_test.go b/internal/wire/op_query_test.go index cd6ea3dbd1c5..07eb0f884b62 100644 --- a/internal/wire/op_query_test.go +++ b/internal/wire/op_query_test.go @@ -22,83 +22,86 @@ import ( "github.com/FerretDB/FerretDB/internal/util/testutil" ) -var queryTestCases = []testCase{{ - name: "handshake1", - headerB: testutil.MustParseDumpFile("testdata", "handshake1_header.hex"), - bodyB: testutil.MustParseDumpFile("testdata", "handshake1_body.hex"), - msgHeader: &MsgHeader{ - MessageLength: 372, - RequestID: 1, - ResponseTo: 0, - OpCode: OpCodeQuery, - }, - msgBody: &OpQuery{ - Flags: 0, - FullCollectionName: "admin.$cmd", - NumberToSkip: 0, - NumberToReturn: -1, - query: must.NotFail(types.NewDocument( - "ismaster", true, - "client", must.NotFail(types.NewDocument( - "driver", must.NotFail(types.NewDocument( - "name", "nodejs", - "version", "4.0.0-beta.6", - )), - "os", must.NotFail(types.NewDocument( - "type", "Darwin", - "name", "darwin", - "architecture", "x64", - "version", "20.6.0", +var queryTestCases = []testCase{ + { + name: "handshake1", + headerB: testutil.MustParseDumpFile("testdata", "handshake1_header.hex"), + bodyB: testutil.MustParseDumpFile("testdata", "handshake1_body.hex"), + msgHeader: &MsgHeader{ + MessageLength: 372, + RequestID: 1, + ResponseTo: 0, + OpCode: OpCodeQuery, + }, + msgBody: &OpQuery{ + Flags: 0, + FullCollectionName: "admin.$cmd", + NumberToSkip: 0, + NumberToReturn: -1, + query: convertDocument(must.NotFail(types.NewDocument( + "ismaster", true, + "client", must.NotFail(types.NewDocument( + "driver", must.NotFail(types.NewDocument( + "name", "nodejs", + "version", "4.0.0-beta.6", + )), + "os", must.NotFail(types.NewDocument( + "type", "Darwin", + "name", "darwin", + "architecture", "x64", + "version", "20.6.0", + )), + "platform", "Node.js v14.17.3, LE (unified)|Node.js v14.17.3, LE (unified)", + "application", must.NotFail(types.NewDocument( + "name", "mongosh 1.0.1", + )), )), - "platform", "Node.js v14.17.3, LE (unified)|Node.js v14.17.3, LE (unified)", - "application", must.NotFail(types.NewDocument( - "name", "mongosh 1.0.1", - )), - )), - "compression", must.NotFail(types.NewArray("none")), - "loadBalanced", false, - )), - returnFieldsSelector: nil, - }, -}, { - name: "handshake3", - headerB: testutil.MustParseDumpFile("testdata", "handshake3_header.hex"), - bodyB: testutil.MustParseDumpFile("testdata", "handshake3_body.hex"), - msgHeader: &MsgHeader{ - MessageLength: 372, - RequestID: 2, - ResponseTo: 0, - OpCode: OpCodeQuery, + "compression", must.NotFail(types.NewArray("none")), + "loadBalanced", false, + ))), + returnFieldsSelector: nil, + }, }, - msgBody: &OpQuery{ - Flags: 0, - 
FullCollectionName: "admin.$cmd", - NumberToSkip: 0, - NumberToReturn: -1, - query: must.NotFail(types.NewDocument( - "ismaster", true, - "client", must.NotFail(types.NewDocument( - "driver", must.NotFail(types.NewDocument( - "name", "nodejs", - "version", "4.0.0-beta.6", + { + name: "handshake3", + headerB: testutil.MustParseDumpFile("testdata", "handshake3_header.hex"), + bodyB: testutil.MustParseDumpFile("testdata", "handshake3_body.hex"), + msgHeader: &MsgHeader{ + MessageLength: 372, + RequestID: 2, + ResponseTo: 0, + OpCode: OpCodeQuery, + }, + msgBody: &OpQuery{ + Flags: 0, + FullCollectionName: "admin.$cmd", + NumberToSkip: 0, + NumberToReturn: -1, + query: convertDocument(must.NotFail(types.NewDocument( + "ismaster", true, + "client", must.NotFail(types.NewDocument( + "driver", must.NotFail(types.NewDocument( + "name", "nodejs", + "version", "4.0.0-beta.6", + )), + "os", must.NotFail(types.NewDocument( + "type", "Darwin", + "name", "darwin", + "architecture", "x64", + "version", "20.6.0", + )), + "platform", "Node.js v14.17.3, LE (unified)|Node.js v14.17.3, LE (unified)", + "application", must.NotFail(types.NewDocument( + "name", "mongosh 1.0.1", + )), )), - "os", must.NotFail(types.NewDocument( - "type", "Darwin", - "name", "darwin", - "architecture", "x64", - "version", "20.6.0", - )), - "platform", "Node.js v14.17.3, LE (unified)|Node.js v14.17.3, LE (unified)", - "application", must.NotFail(types.NewDocument( - "name", "mongosh 1.0.1", - )), - )), - "compression", must.NotFail(types.NewArray("none")), - "loadBalanced", false, - )), - returnFieldsSelector: nil, + "compression", must.NotFail(types.NewArray("none")), + "loadBalanced", false, + ))), + returnFieldsSelector: nil, + }, }, -}} +} func TestQuery(t *testing.T) { t.Parallel() diff --git a/internal/wire/op_reply.go b/internal/wire/op_reply.go index 659e4ad865f5..4fbab6257316 100644 --- a/internal/wire/op_reply.go +++ b/internal/wire/op_reply.go @@ -15,121 +15,110 @@ package wire import ( - "bufio" - "bytes" "encoding/binary" "encoding/json" - "io" - "github.com/FerretDB/FerretDB/internal/bson" + "github.com/FerretDB/FerretDB/internal/bson2" "github.com/FerretDB/FerretDB/internal/types" "github.com/FerretDB/FerretDB/internal/types/fjson" + "github.com/FerretDB/FerretDB/internal/util/debugbuild" "github.com/FerretDB/FerretDB/internal/util/lazyerrors" "github.com/FerretDB/FerretDB/internal/util/must" ) -const maxNumberReturned = 1000 - -// OpReply is a message sent by the MongoDB database in response to an OpQuery. +// OpReply is a deprecated response message type. +// +// Only up to one returned document is supported. 
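+//
+// An informal sketch of the body layout handled below (offsets are into b):
+//
+//	b[0:4]   responseFlags  (int32, little-endian)
+//	b[4:12]  cursorID       (int64)
+//	b[12:16] startingFrom   (int32)
+//	b[16:20] numberReturned (int32; only 0 or 1 is accepted)
+//	b[20:]   one BSON document, present iff numberReturned == 1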
type OpReply struct { - ResponseFlags OpReplyFlags - CursorID int64 - StartingFrom int32 - NumberReturned int32 - documents []*types.Document + ResponseFlags OpReplyFlags + CursorID int64 + StartingFrom int32 + document bson2.RawDocument } func (reply *OpReply) msgbody() {} -func (reply *OpReply) readFrom(bufr *bufio.Reader) error { - if err := binary.Read(bufr, binary.LittleEndian, &reply.ResponseFlags); err != nil { - return lazyerrors.Errorf("wire.OpReply.ReadFrom (binary.Read): %w", err) - } - if err := binary.Read(bufr, binary.LittleEndian, &reply.CursorID); err != nil { - return lazyerrors.Errorf("wire.OpReply.ReadFrom (binary.Read): %w", err) - } - if err := binary.Read(bufr, binary.LittleEndian, &reply.StartingFrom); err != nil { - return lazyerrors.Errorf("wire.OpReply.ReadFrom (binary.Read): %w", err) - } - if err := binary.Read(bufr, binary.LittleEndian, &reply.NumberReturned); err != nil { - return lazyerrors.Errorf("wire.OpReply.ReadFrom (binary.Read): %w", err) - } - - if n := reply.NumberReturned; n < 0 || n > maxNumberReturned { - return lazyerrors.Errorf("wire.OpReply.ReadFrom: invalid NumberReturned %d", n) +// check checks if the reply is valid. +func (reply *OpReply) check() error { + if !debugbuild.Enabled { + return nil } - reply.documents = make([]*types.Document, reply.NumberReturned) - for i := int32(0); i < reply.NumberReturned; i++ { - var doc bson.Document - if err := doc.ReadFrom(bufr); err != nil { - return lazyerrors.Errorf("wire.OpReply.ReadFrom: %w", err) + if d := reply.document; d != nil { + if _, err := d.DecodeDeep(); err != nil { + return lazyerrors.Error(err) } - reply.documents[i] = must.NotFail(types.ConvertDocument(&doc)) } return nil } -// UnmarshalBinary reads an OpReply from a byte array. -func (reply *OpReply) UnmarshalBinary(b []byte) error { - br := bytes.NewReader(b) - bufr := bufio.NewReader(br) +// UnmarshalBinaryNocopy implements [MsgBody] interface. +func (reply *OpReply) UnmarshalBinaryNocopy(b []byte) error { + if len(b) < 20 { + return lazyerrors.Errorf("len=%d", len(b)) + } + + reply.ResponseFlags = OpReplyFlags(binary.LittleEndian.Uint32(b[0:4])) + reply.CursorID = int64(binary.LittleEndian.Uint64(b[4:12])) + reply.StartingFrom = int32(binary.LittleEndian.Uint32(b[12:16])) + numberReturned := int32(binary.LittleEndian.Uint32(b[16:20])) + reply.document = b[20:] + + if numberReturned < 0 || numberReturned > 1 { + return lazyerrors.Errorf("numberReturned=%d", numberReturned) + } + + if len(reply.document) == 0 { + reply.document = nil + } - if err := reply.readFrom(bufr); err != nil { - return lazyerrors.Errorf("wire.OpReply.UnmarshalBinary: %w", err) + if (numberReturned == 0) != (reply.document == nil) { + return lazyerrors.Errorf("numberReturned=%d, document=%v", numberReturned, reply.document) } - if _, err := bufr.Peek(1); err != io.EOF { - return lazyerrors.Errorf("unexpected end of the OpReply: %v", err) + if err := reply.check(); err != nil { + return lazyerrors.Error(err) } return nil } -// MarshalBinary writes an OpReply to a byte array. +// MarshalBinary implements [MsgBody] interface. 
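+//
+// The output is always 20+len(document) bytes. As a worked example, the
+// 319-byte handshake2 message in the tests below is a 16-byte MsgHeader
+// plus this 20-byte fixed prefix plus a 283-byte document.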
func (reply *OpReply) MarshalBinary() ([]byte, error) { - if l := len(reply.documents); int32(l) != reply.NumberReturned { - return nil, lazyerrors.Errorf("wire.OpReply.MarshalBinary: len(Documents)=%d, NumberReturned=%d", l, reply.NumberReturned) + if err := reply.check(); err != nil { + return nil, lazyerrors.Error(err) } - var buf bytes.Buffer - bufw := bufio.NewWriter(&buf) + b := make([]byte, 20+len(reply.document)) - if err := binary.Write(bufw, binary.LittleEndian, reply.ResponseFlags); err != nil { - return nil, lazyerrors.Errorf("wire.OpReply.MarshalBinary (binary.Write): %w", err) - } - if err := binary.Write(bufw, binary.LittleEndian, reply.CursorID); err != nil { - return nil, lazyerrors.Errorf("wire.OpReply.MarshalBinary (binary.Write): %w", err) - } - if err := binary.Write(bufw, binary.LittleEndian, reply.StartingFrom); err != nil { - return nil, lazyerrors.Errorf("wire.OpReply.MarshalBinary (binary.Write): %w", err) - } - if err := binary.Write(bufw, binary.LittleEndian, reply.NumberReturned); err != nil { - return nil, lazyerrors.Errorf("wire.OpReply.UnmarshalBinary (binary.Write): %w", err) - } - - for _, doc := range reply.documents { - if err := must.NotFail(bson.ConvertDocument(doc)).WriteTo(bufw); err != nil { - return nil, lazyerrors.Errorf("wire.OpReply.MarshalBinary: %w", err) - } - } + binary.LittleEndian.PutUint32(b[0:4], uint32(reply.ResponseFlags)) + binary.LittleEndian.PutUint64(b[4:12], uint64(reply.CursorID)) + binary.LittleEndian.PutUint32(b[12:16], uint32(reply.StartingFrom)) - if err := bufw.Flush(); err != nil { - return nil, err + if reply.document == nil { + binary.LittleEndian.PutUint32(b[16:20], uint32(0)) + } else { + binary.LittleEndian.PutUint32(b[16:20], uint32(1)) + copy(b[20:], reply.document) } - return buf.Bytes(), nil + return b, nil } -// Documents returns reply documents. -func (reply *OpReply) Documents() []*types.Document { - return reply.documents +// Document returns reply document. +func (reply *OpReply) Document() (*types.Document, error) { + if reply.document == nil { + return nil, nil + } + + return reply.document.Convert() } // SetDocument sets reply document. func (reply *OpReply) SetDocument(doc *types.Document) { - reply.documents = []*types.Document{doc} + d := must.NotFail(bson2.ConvertDocument(doc)) + reply.document = must.NotFail(d.Encode()) } // String returns a string representation for logging. 
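+// A minimal round-trip sketch for the single-document API above (error
+// handling elided):
+//
+//	var reply OpReply
+//	reply.SetDocument(must.NotFail(types.NewDocument("ok", float64(1))))
+//	b, _ := reply.MarshalBinary() // 20-byte prefix + encoded document
+//	doc, _ := reply.Document()    // converts the raw document back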
@@ -139,18 +128,23 @@ func (reply *OpReply) String() string {
 	}
 
 	m := map[string]any{
-		"ResponseFlags":  reply.ResponseFlags,
-		"CursorID":       reply.CursorID,
-		"StartingFrom":   reply.StartingFrom,
-		"NumberReturned": reply.NumberReturned,
+		"ResponseFlags": reply.ResponseFlags,
+		"CursorID":      reply.CursorID,
+		"StartingFrom":  reply.StartingFrom,
 	}
 
-	docs := make([]json.RawMessage, len(reply.documents))
-	for i, d := range reply.documents {
-		docs[i] = json.RawMessage(must.NotFail(fjson.Marshal(d)))
-	}
+	if reply.document == nil {
+		m["NumberReturned"] = 0
+	} else {
+		m["NumberReturned"] = 1
 
-	m["Documents"] = docs
+		doc, err := reply.document.Convert()
+		if err == nil {
+			m["Document"] = json.RawMessage(must.NotFail(fjson.Marshal(doc)))
+		} else {
+			m["DocumentError"] = err.Error()
+		}
+	}
 
 	return string(must.NotFail(json.MarshalIndent(m, "", "  ")))
 }
diff --git a/internal/wire/op_reply_test.go b/internal/wire/op_reply_test.go
index 875ae4b198b7..43c8750926f7 100644
--- a/internal/wire/op_reply_test.go
+++ b/internal/wire/op_reply_test.go
@@ -23,73 +23,74 @@ import (
 
 	"github.com/FerretDB/FerretDB/internal/util/testutil"
 )
 
-var replyTestCases = []testCase{{
-	name:    "handshake2",
-	headerB: testutil.MustParseDumpFile("testdata", "handshake2_header.hex"),
-	bodyB:   testutil.MustParseDumpFile("testdata", "handshake2_body.hex"),
-	msgHeader: &MsgHeader{
-		MessageLength: 319,
-		RequestID:     290,
-		ResponseTo:    1,
-		OpCode:        OpCodeReply,
+var replyTestCases = []testCase{
+	{
+		name:    "handshake2",
+		headerB: testutil.MustParseDumpFile("testdata", "handshake2_header.hex"),
+		bodyB:   testutil.MustParseDumpFile("testdata", "handshake2_body.hex"),
+		msgHeader: &MsgHeader{
+			MessageLength: 319,
+			RequestID:     290,
+			ResponseTo:    1,
+			OpCode:        OpCodeReply,
+		},
+		msgBody: &OpReply{
+			ResponseFlags: OpReplyFlags(OpReplyAwaitCapable),
+			CursorID:      0,
+			StartingFrom:  0,
+			document: convertDocument(must.NotFail(types.NewDocument(
+				"ismaster", true,
+				"topologyVersion", must.NotFail(types.NewDocument(
+					"processId", types.ObjectID{0x60, 0xfb, 0xed, 0x53, 0x71, 0xfe, 0x1b, 0xae, 0x70, 0x33, 0x95, 0x05},
+					"counter", int64(0),
+				)),
+				"maxBsonObjectSize", int32(16777216),
+				"maxMessageSizeBytes", int32(48000000),
+				"maxWriteBatchSize", int32(100000),
+				"localTime", time.Date(2021, time.July, 24, 12, 54, 41, 571000000, time.UTC).Local(),
+				"logicalSessionTimeoutMinutes", int32(30),
+				"connectionId", int32(28),
+				"minWireVersion", int32(0),
+				"maxWireVersion", int32(13),
+				"readOnly", false,
+				"ok", float64(1),
+			))),
+		},
 	},
-	msgBody: &OpReply{
-		ResponseFlags:  OpReplyFlags(OpReplyAwaitCapable),
-		CursorID:       0,
-		StartingFrom:   0,
-		NumberReturned: 1,
-		documents: []*types.Document{must.NotFail(types.NewDocument(
-			"ismaster", true,
-			"topologyVersion", must.NotFail(types.NewDocument(
-				"processId", types.ObjectID{0x60, 0xfb, 0xed, 0x53, 0x71, 0xfe, 0x1b, 0xae, 0x70, 0x33, 0x95, 0x05},
-				"counter", int64(0),
-			)),
-			"maxBsonObjectSize", int32(16777216),
-			"maxMessageSizeBytes", int32(48000000),
-			"maxWriteBatchSize", int32(100000),
-			"localTime", time.Date(2021, time.July, 24, 12, 54, 41, 571000000, time.UTC).Local(),
-			"logicalSessionTimeoutMinutes", int32(30),
-			"connectionId", int32(28),
-			"minWireVersion", int32(0),
-			"maxWireVersion", int32(13),
-			"readOnly", false,
-			"ok", float64(1),
-		))},
+	{
+		name:    "handshake4",
+		headerB: testutil.MustParseDumpFile("testdata", "handshake4_header.hex"),
+		bodyB:   testutil.MustParseDumpFile("testdata", "handshake4_body.hex"),
+		msgHeader: &MsgHeader{
+			MessageLength: 319,
+
RequestID: 291, + ResponseTo: 2, + OpCode: OpCodeReply, + }, + msgBody: &OpReply{ + ResponseFlags: OpReplyFlags(OpReplyAwaitCapable), + CursorID: 0, + StartingFrom: 0, + document: convertDocument(must.NotFail(types.NewDocument( + "ismaster", true, + "topologyVersion", must.NotFail(types.NewDocument( + "processId", types.ObjectID{0x60, 0xfb, 0xed, 0x53, 0x71, 0xfe, 0x1b, 0xae, 0x70, 0x33, 0x95, 0x05}, + "counter", int64(0), + )), + "maxBsonObjectSize", int32(16777216), + "maxMessageSizeBytes", int32(48000000), + "maxWriteBatchSize", int32(100000), + "localTime", time.Date(2021, time.July, 24, 12, 54, 41, 592000000, time.UTC).Local(), + "logicalSessionTimeoutMinutes", int32(30), + "connectionId", int32(29), + "minWireVersion", int32(0), + "maxWireVersion", int32(13), + "readOnly", false, + "ok", float64(1), + ))), + }, }, -}, { - name: "handshake4", - headerB: testutil.MustParseDumpFile("testdata", "handshake4_header.hex"), - bodyB: testutil.MustParseDumpFile("testdata", "handshake4_body.hex"), - msgHeader: &MsgHeader{ - MessageLength: 319, - RequestID: 291, - ResponseTo: 2, - OpCode: OpCodeReply, - }, - msgBody: &OpReply{ - ResponseFlags: OpReplyFlags(OpReplyAwaitCapable), - CursorID: 0, - StartingFrom: 0, - NumberReturned: 1, - documents: []*types.Document{must.NotFail(types.NewDocument( - "ismaster", true, - "topologyVersion", must.NotFail(types.NewDocument( - "processId", types.ObjectID{0x60, 0xfb, 0xed, 0x53, 0x71, 0xfe, 0x1b, 0xae, 0x70, 0x33, 0x95, 0x05}, - "counter", int64(0), - )), - "maxBsonObjectSize", int32(16777216), - "maxMessageSizeBytes", int32(48000000), - "maxWriteBatchSize", int32(100000), - "localTime", time.Date(2021, time.July, 24, 12, 54, 41, 592000000, time.UTC).Local(), - "logicalSessionTimeoutMinutes", int32(30), - "connectionId", int32(29), - "minWireVersion", int32(0), - "maxWireVersion", int32(13), - "readOnly", false, - "ok", float64(1), - ))}, - }, -}} +} func TestReply(t *testing.T) { t.Parallel() diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index b3139b39d2ec..c33ba41f2db7 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -25,9 +25,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/FerretDB/FerretDB/internal/bson2" + "github.com/FerretDB/FerretDB/internal/types" + "github.com/FerretDB/FerretDB/internal/util/must" "github.com/FerretDB/FerretDB/internal/util/testutil/testtb" ) +// convertDocument converts [*types.Document] to [bson2.RawDocument]. +func convertDocument(doc *types.Document) bson2.RawDocument { + d := must.NotFail(bson2.ConvertDocument(doc)) + return must.NotFail(d.Encode()) +} + // lastErr returns the last error in error chain. 
func lastErr(err error) error { for { From 653adce5ca47f1fed1413111d07f7d99f8324153 Mon Sep 17 00:00:00 2001 From: Artyom Fadeyev <70910148+fadyat@users.noreply.github.com> Date: Thu, 22 Feb 2024 16:42:58 +0300 Subject: [PATCH 12/13] Make logger configurable in the embedded `ferretdb` package (#4028) Co-authored-by: noisersup --- ferretdb/ferretdb.go | 47 ++++++++++++++++++++------------ internal/bson2/slog.go | 2 ++ internal/util/logging/logging.go | 9 ++++-- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/ferretdb/ferretdb.go b/ferretdb/ferretdb.go index e7198f348ccf..9e8421604f12 100644 --- a/ferretdb/ferretdb.go +++ b/ferretdb/ferretdb.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "net/url" + "sync" "go.uber.org/zap" @@ -38,6 +39,9 @@ import ( type Config struct { Listener ListenerConfig + // Logger to use; if nil, it uses the default global logger. + Logger *zap.Logger + // Handler to use; one of `postgresql` or `sqlite`. Handler string @@ -113,8 +117,15 @@ func New(config *Config) (*FerretDB, error) { metrics := connmetrics.NewListenerMetrics() + log := config.Logger + if log == nil { + log = getGlobalLogger() + } else { + log = logging.WithHooks(log) + } + h, closeBackend, err := registry.NewHandler(config.Handler, ®istry.NewHandlerOpts{ - Logger: logger, + Logger: log, ConnMetrics: metrics.ConnMetrics, StateProvider: sp, TCPHost: config.Listener.TCP, @@ -143,7 +154,7 @@ func New(config *Config) (*FerretDB, error) { Mode: clientconn.NormalMode, Metrics: metrics, Handler: h, - Logger: logger, + Logger: log, }) return &FerretDB{ @@ -214,21 +225,23 @@ func (f *FerretDB) MongoDBURI() string { return u.String() } -// logger is a global logger used by FerretDB. -// -// TODO https://github.com/FerretDB/FerretDB/issues/4014 -var logger *zap.Logger +var ( + loggerOnce sync.Once + logger *zap.Logger +) -// Initialize the global logger there to avoid creating too many issues for zap users that initialize it in their -// `main()` functions. It is still not a full solution; eventually, we should remove the usage of the global logger. -// -// TODO https://github.com/FerretDB/FerretDB/issues/4014 -func init() { - l := zap.ErrorLevel - if version.Get().DebugBuild { - l = zap.DebugLevel - } +// getGlobalLogger retrieves or creates a global logger using +// a loggerOnce to ensure it is created only once. 
+func getGlobalLogger() *zap.Logger { + loggerOnce.Do(func() { + level := zap.ErrorLevel + if version.Get().DebugBuild { + level = zap.DebugLevel + } + + logging.Setup(level, "console", "") + logger = zap.L() + }) - logging.Setup(l, "console", "") - logger = zap.L() + return logger } diff --git a/internal/bson2/slog.go b/internal/bson2/slog.go index 7050cabe7d58..08458bde3768 100644 --- a/internal/bson2/slog.go +++ b/internal/bson2/slog.go @@ -41,6 +41,7 @@ func slogValue(v any) slog.Value { if v == nil { return slog.StringValue("RawDocument(nil)") } + return slog.StringValue("RawDocument(" + strconv.Itoa(len(v)) + " bytes)") case *Array: @@ -56,6 +57,7 @@ func slogValue(v any) slog.Value { if v == nil { return slog.StringValue("RawArray(nil)") } + return slog.StringValue("RawArray(" + strconv.Itoa(len(v)) + " bytes)") default: diff --git a/internal/util/logging/logging.go b/internal/util/logging/logging.go index 9b68eb527a45..76761348d1a9 100644 --- a/internal/util/logging/logging.go +++ b/internal/util/logging/logging.go @@ -80,12 +80,15 @@ func Setup(level zapcore.Level, encoding, uuid string) { log.Fatal(err) } - logger = logger.WithOptions(zap.Hooks(func(entry zapcore.Entry) error { + SetupWithZapLogger(WithHooks(logger)) +} + +// WithHooks returns a logger with recent entries hooks. +func WithHooks(logger *zap.Logger) *zap.Logger { + return logger.WithOptions(zap.Hooks(func(entry zapcore.Entry) error { RecentEntries.append(&entry) return nil })) - - SetupWithZapLogger(logger) } // setupSlog initializes slog logging with a given level. From 7e22bf03f6b6dcf79bbf0c944ba0c6e3f5a3a847 Mon Sep 17 00:00:00 2001 From: Henrique Vicente Date: Fri, 23 Feb 2024 09:09:23 +0100 Subject: [PATCH 13/13] wip --- cmd/envtool/tests_test.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cmd/envtool/tests_test.go b/cmd/envtool/tests_test.go index dc502bfbf66b..f0ad3f933592 100644 --- a/cmd/envtool/tests_test.go +++ b/cmd/envtool/tests_test.go @@ -106,7 +106,11 @@ func TestRunGoTest(t *testing.T) { logger, err := makeTestLogger(&actual) require.NoError(t, err) - err = runGoTest(context.TODO(), []string{"./testdata", "-count=1", "-run=TestWithSubtest/Third"}, 1, false, logger.Sugar()) + err = runGoTest(context.TODO(), []string{ + "./testdata", + "-count=1", + "-run=TestWithSubtest/Third", + }, 1, false, logger.Sugar()) require.NoError(t, err) expected := []string{ @@ -494,7 +498,7 @@ func TestListTestFuncsWithSkip(t *testing.T) { assert.Nil(t, err) lastRes, lastSkip, err := shardTestFuncs(3, 3, testFuncs) - assert.Equal(t, []string{"TestNormal1", "TestSubtest]"}, lastRes) + assert.Equal(t, []string{"TestNormal1", "TestWithSubtest"}, lastRes) assert.Equal(t, []string{"TestError1", "TestError2", "TestNormal2", "TestPanic1"}, lastSkip) require.NoError(t, err) }
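A minimal usage sketch for the new Config.Logger field. The listener address and handler value are illustrative, backend-specific settings and server start/stop are omitted, and only fields shown in the diff above are used:

package main

import (
	"log"

	"go.uber.org/zap"

	"github.com/FerretDB/FerretDB/ferretdb"
)

func main() {
	// Any zap logger works; ferretdb.New attaches the recent-entries hooks
	// via logging.WithHooks. Passing nil falls back to the global logger.
	logger, err := zap.NewDevelopment()
	if err != nil {
		log.Fatal(err)
	}

	f, err := ferretdb.New(&ferretdb.Config{
		Listener: ferretdb.ListenerConfig{TCP: "127.0.0.1:17027"}, // illustrative address
		Handler:  "postgresql",                                    // one of `postgresql` or `sqlite`
		Logger:   logger,
	})
	if err != nil {
		log.Fatal(err)
	}

	log.Println(f.MongoDBURI())
}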