From b83a5b13421e9824755230b9707ac662c4a7b6ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Benkovsk=C3=BD?= Date: Mon, 20 Nov 2023 22:52:28 +0100 Subject: [PATCH] add more cases to cover IO errors --- cmd/benchmark.go | 12 +- cmd/benchmark_test.go | 287 +++++++++++++++++++++++++++++------------- cmd/doq_test.go | 72 ++++++++--- cmd/result.go | 8 +- 4 files changed, 264 insertions(+), 115 deletions(-) diff --git a/cmd/benchmark.go b/cmd/benchmark.go index ee9220e..38eeb60 100644 --- a/cmd/benchmark.go +++ b/cmd/benchmark.go @@ -451,14 +451,10 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) { start := time.Now() reqTimeoutCtx, cancel := context.WithTimeout(ctx, b.RequestTimeout) - if resp, err = query(reqTimeoutCtx, b.Server, &m); err != nil { - cancel() - st.Counters.IOError++ - st.Errors = append(st.Errors, err) - } else { - cancel() - st.record(&m, resp, start, time.Since(start)) - } + resp, err = query(reqTimeoutCtx, b.Server, &m) + cancel() + st.record(&m, resp, err, start, time.Since(start)) + if incrementBar { bar.Add(1) } diff --git a/cmd/benchmark_test.go b/cmd/benchmark_test.go index 2abba6f..ff6d98d 100644 --- a/cmd/benchmark_test.go +++ b/cmd/benchmark_test.go @@ -69,7 +69,7 @@ func Test_do_classic_dns(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) }) } @@ -97,7 +97,7 @@ func Test_do_classic_dns_dnssec(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) for _, r := range rs { assert.Equal(t, r.AuthenticatedDomains, map[string]struct{}{"example.org.": {}}) @@ -130,7 +130,7 @@ func Test_do_classic_dns_edns0(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -175,7 +175,7 @@ func Test_do_classic_dns_edns0_ednsopt(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -216,7 +216,7 @@ func Test_do_doh_post(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -259,7 +259,7 @@ func Test_do_doh_get(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -282,12 +282,10 @@ func Test_do_probability(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") - assert.Len(t, rs, 2, "Run(ctx) rstats") - rs0 := rs[0] - rs1 := rs[1] - assert.Equal(t, int64(0), rs0.Counters.Total, "Run(ctx) total counter") - assert.Equal(t, int64(0), rs1.Counters.Total, "Run(ctx) total counter") + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + assert.Equal(t, int64(0), rs[0].Counters.Total, "Run(ctx) total counter") + assert.Equal(t, int64(0), rs[1].Counters.Total, "Run(ctx) total counter") } func Test_download_external_datasource_using_http(t *testing.T) { @@ -331,7 +329,7 @@ func Test_download_external_datasource_using_http(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -361,7 +359,7 @@ func Test_download_external_datasource_using_http_not_available(t *testing.T) { defer cancel() _, err := bench.Run(ctx) - assert.Error(t, err, "expected error from benchmark run") + require.Error(t, err, "expected error from benchmark run") } func Test_download_external_datasource_using_http_wrong_response(t *testing.T) { @@ -390,7 +388,7 @@ func Test_download_external_datasource_using_http_wrong_response(t *testing.T) { defer cancel() _, err := bench.Run(ctx) - assert.Error(t, err, "expected error from benchmark run") + require.Error(t, err, "expected error from benchmark run") } func Test_do_classic_dns_with_duration(t *testing.T) { @@ -421,7 +419,7 @@ func Test_do_classic_dns_with_duration(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assert.GreaterOrEqual(t, rs[0].Counters.Total, int64(1), "there should be atleast one execution") } @@ -446,7 +444,7 @@ func Test_duration_and_count_specified_at_once(t *testing.T) { defer cancel() _, err := bench.Run(ctx) - assert.Error(t, err, "expected error from benchmark run") + require.Error(t, err, "expected error from benchmark run") } func Test_do_classic_dns_default_count(t *testing.T) { @@ -480,14 +478,22 @@ func Test_do_classic_dns_default_count(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") - assert.Len(t, rs, 1, "Run(ctx) rstats") + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 1, "expected results from one worker") assert.Equal(t, int64(1), rs[0].Counters.Total) assert.Equal(t, int64(1), rs[0].Counters.Success) } func Test_do_doq(t *testing.T) { - server := doqServer{} + server := newDoQServer(func(r *dns.Msg) *dns.Msg { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1")) + + // wait some time to actually have some observable duration + time.Sleep(time.Millisecond * 500) + return ret + }) server.start() defer server.stop() @@ -498,7 +504,7 @@ func Test_do_doq(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } @@ -540,72 +546,10 @@ func Test_do_dot(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") + require.NoError(t, err, "expected no error from benchmark run") assertResult(t, rs) } -func assertResult(t *testing.T, rs []*ResultStats) { - if assert.Len(t, rs, 2, "Run(ctx) rstats") { - rs0 := rs[0] - rs1 := rs[1] - assertResultStats(t, rs0) - assertResultStats(t, rs1) - assertTimings(t, rs0) - assertTimings(t, rs1) - } -} - -func assertResultStats(t *testing.T, rs *ResultStats) { - assert.NotNil(t, rs.Hist, "Run(ctx) rstats histogram") - - if assert.NotNil(t, rs.Codes, "Run(ctx) rstats codes") { - assert.Equal(t, int64(2), rs.Codes[0], "Run(ctx) rstats codes NOERROR, state:"+fmt.Sprint(rs.Codes)) - } - - if assert.NotNil(t, rs.Qtypes, "Run(ctx) rstats qtypes") { - assert.Equal(t, int64(1), rs.Qtypes[dns.TypeToString[dns.TypeA]], "Run(ctx) rstats qtypes A, state:"+fmt.Sprint(rs.Codes)) - assert.Equal(t, int64(1), rs.Qtypes[dns.TypeToString[dns.TypeAAAA]], "Run(ctx) rstats qtypes AAAA, state:"+fmt.Sprint(rs.Codes)) - } - - assert.Equal(t, int64(2), rs.Counters.Total, "Run(ctx) total counter") - assert.Zero(t, rs.Counters.IOError, "error counter") - assert.Equal(t, int64(2), rs.Counters.Success, "Run(ctx) success counter") - assert.Zero(t, rs.Counters.IDmismatch, "Run(ctx) mismatch counter") - assert.Zero(t, rs.Counters.Truncated, "Run(ctx) truncated counter") -} - -func assertTimings(t *testing.T, rs *ResultStats) { - if assert.Len(t, rs.Timings, 2, "Run(ctx) rstats timings") { - t0 := rs.Timings[0] - t1 := rs.Timings[1] - assert.NotZero(t, t0.Duration, "Run(ctx) rstats timings duration") - assert.NotZero(t, t0.Start, "Run(ctx) rstats timings start") - assert.NotZero(t, t1.Duration, "Run(ctx) rstats timings duration") - assert.NotZero(t, t1.Start, "Run(ctx) rstats timings start") - } -} - -func createBenchmark(server string, tcp bool, prob float64) Benchmark { - return Benchmark{ - Queries: []string{"example.org"}, - Types: []string{"A", "AAAA"}, - Server: server, - TCP: tcp, - Concurrency: 2, - Count: 1, - Probability: prob, - WriteTimeout: 1 * time.Second, - ReadTimeout: 3 * time.Second, - ConnectTimeout: 1 * time.Second, - RequestTimeout: 5 * time.Second, - Rcodes: true, - Recurse: true, - } -} - -// A returns an A record from rr. It panics on errors. -func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } - func TestBenchmark_prepare(t *testing.T) { tests := []struct { name string @@ -760,8 +704,8 @@ func Test_global_ratelimit(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") - assert.Len(t, rs, 2) + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") // assert that total queries is 5 with +-1 precision, because benchmark cancellation based on duration is not that precise // and one worker can start the resolution before cancelling assert.InDelta(t, int64(5), rs[0].Counters.Total+rs[1].Counters.Total, 1.0) @@ -801,10 +745,177 @@ func Test_worker_ratelimit(t *testing.T) { defer cancel() rs, err := bench.Run(ctx) - assert.NoError(t, err, "expected no error from benchmark run") - assert.Len(t, rs, 2) + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + // assert that total queries is 10 with +-2 precision, // because benchmark cancellation based on duration is not that precise // and each worker can start the resolution before cancelling assert.InDelta(t, int64(10), rs[0].Counters.Total+rs[1].Counters.Total, 2.0) } + +func Test_do_classic_dns_error(t *testing.T) { + s := NewServer("udp", nil, func(w dns.ResponseWriter, r *dns.Msg) { + }) + defer s.Close() + + bench := createBenchmark(s.Addr, false, 1) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + + assert.Equal(t, rs[0].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[0].Counters.IOError, int64(2), "there should be errors") + assert.Equal(t, rs[1].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[1].Counters.IOError, int64(2), "there should be errors") +} + +func Test_do_dot_error(t *testing.T) { + cert, err := tls.LoadX509KeyPair("test.crt", "test.key") + require.NoError(t, err) + + certs, err := os.ReadFile("test.crt") + require.NoError(t, err) + + pool, err := x509.SystemCertPool() + require.NoError(t, err) + + pool.AppendCertsFromPEM(certs) + config := tls.Config{ + ServerName: "localhost", + RootCAs: pool, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + server := NewServer("tcp-tls", &config, func(w dns.ResponseWriter, r *dns.Msg) { + }) + defer server.Close() + + bench := createBenchmark(server.Addr, false, 1) + bench.Insecure = true + bench.DOT = true + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + + assert.Equal(t, rs[0].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[0].Counters.IOError, int64(2), "there should be errors") + assert.Equal(t, rs[1].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[1].Counters.IOError, int64(2), "there should be errors") +} + +func Test_do_doh_error(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer ts.Close() + + bench := createBenchmark(ts.URL, true, 1) + bench.DohMethod = post + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + + assert.Equal(t, rs[0].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[0].Counters.IOError, int64(2), "there should be errors") + assert.Equal(t, rs[1].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[1].Counters.IOError, int64(2), "there should be errors") +} + +func Test_do_doq_error(t *testing.T) { + server := newDoQServer(func(r *dns.Msg) *dns.Msg { + return nil + }) + server.start() + defer server.stop() + + bench := createBenchmark("quic://"+server.addr, true, 1) + bench.Insecure = true + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + require.NoError(t, err, "expected no error from benchmark run") + require.Len(t, rs, 2, "expected results from two workers") + + assert.Equal(t, rs[0].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[0].Counters.IOError, int64(2), "there should be errors") + assert.Equal(t, rs[1].Counters.Total, int64(2), "there should be executions") + assert.Equal(t, rs[1].Counters.IOError, int64(2), "there should be errors") +} + +func assertResult(t *testing.T, rs []*ResultStats) { + if assert.Len(t, rs, 2, "Run(ctx) rstats") { + rs0 := rs[0] + rs1 := rs[1] + assertResultStats(t, rs0) + assertResultStats(t, rs1) + assertTimings(t, rs0) + assertTimings(t, rs1) + } +} + +func assertResultStats(t *testing.T, rs *ResultStats) { + assert.NotNil(t, rs.Hist, "Run(ctx) rstats histogram") + + if assert.NotNil(t, rs.Codes, "Run(ctx) rstats codes") { + assert.Equal(t, int64(2), rs.Codes[0], "Run(ctx) rstats codes NOERROR, state:"+fmt.Sprint(rs.Codes)) + } + + if assert.NotNil(t, rs.Qtypes, "Run(ctx) rstats qtypes") { + assert.Equal(t, int64(1), rs.Qtypes[dns.TypeToString[dns.TypeA]], "Run(ctx) rstats qtypes A, state:"+fmt.Sprint(rs.Codes)) + assert.Equal(t, int64(1), rs.Qtypes[dns.TypeToString[dns.TypeAAAA]], "Run(ctx) rstats qtypes AAAA, state:"+fmt.Sprint(rs.Codes)) + } + + assert.Equal(t, int64(2), rs.Counters.Total, "Run(ctx) total counter") + assert.Zero(t, rs.Counters.IOError, "error counter") + assert.Equal(t, int64(2), rs.Counters.Success, "Run(ctx) success counter") + assert.Zero(t, rs.Counters.IDmismatch, "Run(ctx) mismatch counter") + assert.Zero(t, rs.Counters.Truncated, "Run(ctx) truncated counter") +} + +func assertTimings(t *testing.T, rs *ResultStats) { + if assert.Len(t, rs.Timings, 2, "Run(ctx) rstats timings") { + t0 := rs.Timings[0] + t1 := rs.Timings[1] + assert.NotZero(t, t0.Duration, "Run(ctx) rstats timings duration") + assert.NotZero(t, t0.Start, "Run(ctx) rstats timings start") + assert.NotZero(t, t1.Duration, "Run(ctx) rstats timings duration") + assert.NotZero(t, t1.Start, "Run(ctx) rstats timings start") + } +} + +func createBenchmark(server string, tcp bool, prob float64) Benchmark { + return Benchmark{ + Queries: []string{"example.org"}, + Types: []string{"A", "AAAA"}, + Server: server, + TCP: tcp, + Concurrency: 2, + Count: 1, + Probability: prob, + WriteTimeout: 1 * time.Second, + ReadTimeout: 3 * time.Second, + ConnectTimeout: 1 * time.Second, + RequestTimeout: 5 * time.Second, + Rcodes: true, + Recurse: true, + } +} + +// A returns an A record from rr. It panics on errors. +func A(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } diff --git a/cmd/doq_test.go b/cmd/doq_test.go index b8c742c..7410a85 100644 --- a/cmd/doq_test.go +++ b/cmd/doq_test.go @@ -5,20 +5,28 @@ import ( "crypto/tls" "crypto/x509" "encoding/binary" - "net" + "fmt" + "io" "os" "sync/atomic" - "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" ) +type doqHandler func(req *dns.Msg) *dns.Msg + // doqServer is a DoQ test DNS server. type doqServer struct { addr string listener *quic.Listener closed atomic.Bool + handler doqHandler +} + +func newDoQServer(f doqHandler) *doqServer { + server := doqServer{handler: f} + return &server } func (d *doqServer) start() { @@ -44,25 +52,20 @@ func (d *doqServer) start() { if err != nil { return } - // sleep to have some not zero duration - time.Sleep(100 * time.Millisecond) - - resp := dns.Msg{ - MsgHdr: dns.MsgHdr{Rcode: dns.RcodeSuccess}, - Question: []dns.Question{{Name: "example.org.", Qtype: dns.TypeA}}, - Answer: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{ - Name: "example.org.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 10, - }, - A: net.ParseIP("127.0.0.1"), - }}, + + req, err := readDOQMessage(stream) + if err != nil { + return + } + + resp := d.handler(req) + if resp == nil { + // this should cause timeout + return } pack, err := resp.Pack() if err != nil { - panic(err) + return } packWithPrefix := make([]byte, 2+len(pack)) binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack))) @@ -106,3 +109,36 @@ func generateTLSConfig() *tls.Config { MinVersion: tls.VersionTLS12, } } + +func readDOQMessage(r io.Reader) (*dns.Msg, error) { + // All DNS messages (queries and responses) sent over DoQ connections MUST + // be encoded as a 2-octet length field followed by the message content as + // specified in [RFC1035]. + // See https://www.rfc-editor.org/rfc/rfc9250.html#section-4.2-4 + sizeBuf := make([]byte, 2) + _, err := io.ReadFull(r, sizeBuf) + if err != nil { + return nil, err + } + + size := binary.BigEndian.Uint16(sizeBuf) + + if size == 0 { + return nil, fmt.Errorf("message size is 0: probably unsupported DoQ version") + } + + buf := make([]byte, size) + _, err = io.ReadFull(r, buf) + + // A client or server receives a STREAM FIN before receiving all the bytes + // for a message indicated in the 2-octet length field. + // See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-2.2 + if size != uint16(len(buf)) { + return nil, fmt.Errorf("message size does not match 2-byte prefix") + } + + msg := &dns.Msg{} + err = msg.Unpack(buf) + + return msg, err +} diff --git a/cmd/result.go b/cmd/result.go index 4f8b385..243a3c1 100644 --- a/cmd/result.go +++ b/cmd/result.go @@ -33,7 +33,13 @@ type ResultStats struct { AuthenticatedDomains map[string]struct{} } -func (rs *ResultStats) record(req *dns.Msg, resp *dns.Msg, time time.Time, timing time.Duration) { +func (rs *ResultStats) record(req *dns.Msg, resp *dns.Msg, err error, time time.Time, timing time.Duration) { + if err != nil { + rs.Counters.IOError++ + rs.Errors = append(rs.Errors, err) + return + } + if resp.Truncated { rs.Counters.Truncated++ }