diff --git a/.golangci.yml b/.golangci.yml index 2587e98..98e45a3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -39,7 +39,8 @@ linters-settings: - '^ @.*' # swaggo comments like // @title - '^ (\d+)(\.|\)).*' # enumeration comments like // 1. or // 1) gosec: - global: - audit: true + config: + global: + audit: true excludes: - G104 diff --git a/cmd/root.go b/cmd/root.go index 8a71914..5366c79 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -152,6 +152,10 @@ func init() { "If the file does not exist, the file will be created."). Default(dnsbench.DefaultRequestLogPath).StringVar(&benchmark.RequestLogPath) + pApp.Flag("separate-worker-connections", "Controls whether the concurrent workers will try to share connections to the server or not. When enabled "+ + "the workers will use separate connections. Disabled by default."). + Default("false").BoolVar(&benchmark.SeparateWorkerConnections) + pApp.Arg("queries", "Queries to issue. It can be a local file referenced using @, for example @data/2-domains. "+ "It can also be resource accessible using HTTP, like https://raw.githubusercontent.com/Tantalor93/dnspyre/master/data/1000-domains, in that "+ "case, the file will be downloaded and saved in-memory. "+ diff --git a/pkg/dnsbench/benchmark.go b/pkg/dnsbench/benchmark.go index feb4234..28935c3 100644 --- a/pkg/dnsbench/benchmark.go +++ b/pkg/dnsbench/benchmark.go @@ -190,6 +190,10 @@ type Benchmark struct { // If it exists, the request logs are appended to the file. RequestLogPath string + // SeparateWorkerConnections controls whether the concurrent workers will try to share connections to the server or not. When set true, + // the workers will NOT share connections and each worker will have separate connection. + SeparateWorkerConnections bool + // Writer used for writing benchmark execution logs and results. Default is os.Stdout. Writer io.Writer @@ -200,8 +204,8 @@ type Benchmark struct { type queryFunc func(context.Context, string, *dns.Msg) (*dns.Msg, error) -// prepare validates and normalizes Benchmark settings. -func (b *Benchmark) prepare() error { +// init validates and normalizes Benchmark settings. +func (b *Benchmark) init() error { if b.Writer == nil { b.Writer = os.Stdout } @@ -259,7 +263,7 @@ func (b *Benchmark) prepare() error { } } - if b.RequestLogEnabled { + if b.RequestLogEnabled && len(b.RequestLogPath) == 0 { b.RequestLogPath = DefaultRequestLogPath } @@ -268,6 +272,12 @@ func (b *Benchmark) prepare() error { // Run executes benchmark, if benchmark is unable to start the error is returned, otherwise array of results from parallel benchmark goroutines is returned. func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) { + color.NoColor = !b.Color + + if err := b.init(); err != nil { + return nil, err + } + if b.RequestLogEnabled { file, err := os.OpenFile(b.RequestLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) if err != nil { @@ -277,12 +287,6 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) { log.SetOutput(file) } - color.NoColor = !b.Color - - if err := b.prepare(); err != nil { - return nil, err - } - questions, err := b.prepareQuestions() if err != nil { return nil, err @@ -499,13 +503,32 @@ func (b *Benchmark) queryFactory() func() queryFunc { // and granular control of the connection switch { case b.useDoH: + if b.SeparateWorkerConnections { + return func() queryFunc { + dohQuery := b.dohQuery() + return dohQuery + } + } dohQuery := b.dohQuery() - queryFactory := func() queryFunc { + return func() queryFunc { return dohQuery } - return queryFactory case b.useQuic: h, _, _ := net.SplitHostPort(b.Server) + if b.SeparateWorkerConnections { + return func() queryFunc { + // nolint:gosec + quicClient := doq.NewClient(b.Server, doq.Options{ + TLSConfig: &tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure}, + ReadTimeout: b.ReadTimeout, + WriteTimeout: b.WriteTimeout, + ConnectTimeout: b.ConnectTimeout, + }) + return func(ctx context.Context, _ string, msg *dns.Msg) (*dns.Msg, error) { + return quicClient.Send(ctx, msg) + } + } + } // nolint:gosec quicClient := doq.NewClient(b.Server, doq.Options{ TLSConfig: &tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure}, diff --git a/pkg/dnsbench/benchmark_api_test.go b/pkg/dnsbench/benchmark_api_test.go index 16d72ab..3980713 100644 --- a/pkg/dnsbench/benchmark_api_test.go +++ b/pkg/dnsbench/benchmark_api_test.go @@ -15,10 +15,12 @@ import ( "os" "regexp" "strconv" + "sync" "testing" "time" "github.com/miekg/dns" + "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tantalor93/dnspyre/v3/pkg/dnsbench" @@ -726,7 +728,7 @@ func TestBenchmark_Run_PlainDNS_default_count(t *testing.T) { } func TestBenchmark_Run_DoQ(t *testing.T) { - server := newDoQServer(func(r *dns.Msg) *dns.Msg { + server := newDoQServer(func(_ quic.Connection, r *dns.Msg) *dns.Msg { ret := new(dns.Msg) ret.SetReply(r) ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1")) @@ -1033,7 +1035,7 @@ func TestBenchmark_Run_DoH_error(t *testing.T) { } func TestBenchmark_Run_DoQ_error(t *testing.T) { - server := newDoQServer(func(_ *dns.Msg) *dns.Msg { + server := newDoQServer(func(_ quic.Connection, _ *dns.Msg) *dns.Msg { return nil }) server.start() @@ -1238,7 +1240,7 @@ func TestBenchmark_Run_DoH_truncated(t *testing.T) { } func TestBenchmark_Run_DoQ_truncated(t *testing.T) { - server := newDoQServer(func(r *dns.Msg) *dns.Msg { + server := newDoQServer(func(_ quic.Connection, r *dns.Msg) *dns.Msg { ret := new(dns.Msg) ret.SetReply(r) ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1")) @@ -1359,6 +1361,192 @@ func TestBenchmark_Requestlog(t *testing.T) { assert.Equal(t, map[string]int{"AAAA": 2, "A": 2}, qtypes) } +func TestBenchmark_Run_DoH_separate_connections(t *testing.T) { + tests := []struct { + name string + separateConnections bool + wantNumberOfConnections int + }{ + { + name: "separate connections", + separateConnections: true, + wantNumberOfConnections: 5, + }, + { + name: "shared connections", + separateConnections: false, + wantNumberOfConnections: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cert, err := tls.LoadX509KeyPair("testdata/test.crt", "testdata/test.key") + require.NoError(t, err) + + certs, err := os.ReadFile("testdata/test.crt") + require.NoError(t, err) + + pool, err := x509.SystemCertPool() + require.NoError(t, err) + + pool.AppendCertsFromPEM(certs) + config := tls.Config{ + ServerName: "localhost", + RootCAs: pool, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + mutex := sync.Mutex{} + remoteAddrs := make(map[string]int) + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mutex.Lock() + remoteAddrs[r.RemoteAddr]++ + mutex.Unlock() + + bd, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + + msg := dns.Msg{} + err = msg.Unpack(bd) + if err != nil { + panic(err) + } + + msg.Answer = append(msg.Answer, A("example.org. IN A 127.0.0.1")) + + pack, err := msg.Pack() + if err != nil { + panic(err) + } + + _, err = w.Write(pack) + if err != nil { + panic(err) + } + })) + ts.EnableHTTP2 = true + ts.TLS = &config + ts.StartTLS() + defer ts.Close() + + buf := bytes.Buffer{} + bench := dnsbench.Benchmark{ + Queries: []string{"example.org"}, + Types: []string{"A"}, + Server: ts.URL, + DohProtocol: "2", + TCP: true, + Concurrency: 5, + Count: 2, + Probability: 1, + WriteTimeout: 1 * time.Second, + ReadTimeout: 3 * time.Second, + ConnectTimeout: 1 * time.Second, + RequestTimeout: 5 * time.Second, + Rcodes: true, + Recurse: true, + DohMethod: dnsbench.PostHTTPMethod, + Writer: &buf, + SeparateWorkerConnections: tt.separateConnections, + Insecure: true, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + // close right away to mitigate race detector failures + ts.Close() + + require.NoError(t, err, "expected no error from benchmark run") + assert.Len(t, rs, 5) + for _, v := range rs { + assert.Empty(t, v.Errors) + } + assert.Len(t, remoteAddrs, tt.wantNumberOfConnections) + assert.Equal(t, fmt.Sprintf("Using 1 hostnames\nBenchmarking %s/dns-query via https/2 (POST) with 5 concurrent requests \n", ts.URL), buf.String()) + }) + } +} + +func TestBenchmark_Run_DoQ_separate_connections(t *testing.T) { + tests := []struct { + name string + separateConnections bool + wantNumberOfConnections int + }{ + { + name: "separate connections", + separateConnections: true, + wantNumberOfConnections: 5, + }, + { + name: "shared connections", + separateConnections: false, + wantNumberOfConnections: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mutex := sync.Mutex{} + remoteAddrs := make(map[string]int) + + server := newDoQServer(func(c quic.Connection, r *dns.Msg) *dns.Msg { + mutex.Lock() + remoteAddrs[c.RemoteAddr().String()]++ + mutex.Unlock() + + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1")) + return ret + }) + server.start() + defer server.stop() + + buf := bytes.Buffer{} + bench := dnsbench.Benchmark{ + Queries: []string{"example.org"}, + Types: []string{"A"}, + Server: "quic://" + server.addr, + TCP: true, + Concurrency: 5, + Count: 2, + Probability: 1, + WriteTimeout: 1 * time.Second, + ReadTimeout: 3 * time.Second, + ConnectTimeout: 1 * time.Second, + RequestTimeout: 5 * time.Second, + Rcodes: true, + Recurse: true, + DohMethod: dnsbench.PostHTTPMethod, + Writer: &buf, + SeparateWorkerConnections: tt.separateConnections, + Insecure: true, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rs, err := bench.Run(ctx) + + // stop right away to mitigate race detector failures + server.stop() + + require.NoError(t, err, "expected no error from benchmark run") + assert.Len(t, rs, 5) + for _, v := range rs { + assert.Empty(t, v.Errors) + } + assert.Len(t, remoteAddrs, tt.wantNumberOfConnections) + assert.Equal(t, fmt.Sprintf("Using 1 hostnames\nBenchmarking %s via quic with 5 concurrent requests \n", server.addr), buf.String()) + }) + } +} + func parseRequestLogs(t *testing.T, reader io.Reader) []requestLog { pattern := `.*worker:\[(.*)\] reqid:\[(.*)\] qname:\[(.*)\] qtype:\[(.*)\] respid:\[(.*)\] rcode:\[(.*)\] respflags:\[(.*)\] err:\[(.*)\] duration:\[(.*)\]$` regex := regexp.MustCompile(pattern) diff --git a/pkg/dnsbench/benchmark_test.go b/pkg/dnsbench/benchmark_test.go index 2e531de..a7b7eca 100644 --- a/pkg/dnsbench/benchmark_test.go +++ b/pkg/dnsbench/benchmark_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestBenchmark_prepare(t *testing.T) { +func TestBenchmark_init(t *testing.T) { tests := []struct { name string benchmark Benchmark @@ -110,7 +110,7 @@ func TestBenchmark_prepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.benchmark.prepare() + err := tt.benchmark.init() require.Equal(t, tt.wantErr, err != nil) if !tt.wantErr { diff --git a/pkg/dnsbench/doq_test.go b/pkg/dnsbench/doq_test.go index 828f95c..985483f 100644 --- a/pkg/dnsbench/doq_test.go +++ b/pkg/dnsbench/doq_test.go @@ -14,7 +14,7 @@ import ( "github.com/quic-go/quic-go" ) -type doqHandler func(req *dns.Msg) *dns.Msg +type doqHandler func(conn quic.Connection, req *dns.Msg) *dns.Msg // doqServer is a DoQ test DNS server. type doqServer struct { @@ -58,7 +58,7 @@ func (d *doqServer) start() { return } - resp := d.handler(req) + resp := d.handler(conn, req) if resp == nil { // this should cause timeout return