Skip to content

Commit

Permalink
add option for forcing each worker to have separate connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Tantalor93 committed May 9, 2024
1 parent b6d8e62 commit a1e561e
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 20 deletions.
5 changes: 3 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ linters-settings:
- '^ @.*' # swaggo comments like // @title
- '^ (\d+)(\.|\)).*' # enumeration comments like // 1. or // 1)
gosec:
global:
audit: true
config:
global:
audit: true
excludes:
- G104
4 changes: 4 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ func init() {
"If the file does not exist, the file will be created.").
Default(dnsbench.DefaultRequestLogPath).StringVar(&benchmark.RequestLogPath)

pApp.Flag("separate-worker-connections", "Controls whether the concurrent workers will try to share connections to the server or not. When enabled "+
"the workers will use separate connections. Disabled by default.").
Default("false").BoolVar(&benchmark.SeparateWorkerConnections)

pApp.Arg("queries", "Queries to issue. It can be a local file referenced using @<file-path>, for example @data/2-domains. "+
"It can also be resource accessible using HTTP, like https://raw.githubusercontent.com/Tantalor93/dnspyre/master/data/1000-domains, in that "+
"case, the file will be downloaded and saved in-memory. "+
Expand Down
45 changes: 34 additions & 11 deletions pkg/dnsbench/benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ type Benchmark struct {
// If it exists, the request logs are appended to the file.
RequestLogPath string

// SeparateWorkerConnections controls whether the concurrent workers will try to share connections to the server or not. When set true,
// the workers will NOT share connections and each worker will have separate connection.
SeparateWorkerConnections bool

// Writer used for writing benchmark execution logs and results. Default is os.Stdout.
Writer io.Writer

Expand All @@ -200,8 +204,8 @@ type Benchmark struct {

type queryFunc func(context.Context, string, *dns.Msg) (*dns.Msg, error)

// prepare validates and normalizes Benchmark settings.
func (b *Benchmark) prepare() error {
// init validates and normalizes Benchmark settings.
func (b *Benchmark) init() error {
if b.Writer == nil {
b.Writer = os.Stdout
}
Expand Down Expand Up @@ -259,7 +263,7 @@ func (b *Benchmark) prepare() error {
}
}

if b.RequestLogEnabled {
if b.RequestLogEnabled && len(b.RequestLogPath) == 0 {
b.RequestLogPath = DefaultRequestLogPath
}

Expand All @@ -268,6 +272,12 @@ func (b *Benchmark) prepare() error {

// Run executes benchmark, if benchmark is unable to start the error is returned, otherwise array of results from parallel benchmark goroutines is returned.
func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
color.NoColor = !b.Color

if err := b.init(); err != nil {
return nil, err
}

if b.RequestLogEnabled {
file, err := os.OpenFile(b.RequestLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
Expand All @@ -277,12 +287,6 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
log.SetOutput(file)
}

color.NoColor = !b.Color

if err := b.prepare(); err != nil {
return nil, err
}

questions, err := b.prepareQuestions()
if err != nil {
return nil, err
Expand Down Expand Up @@ -499,13 +503,32 @@ func (b *Benchmark) queryFactory() func() queryFunc {
// and granular control of the connection
switch {
case b.useDoH:
if b.SeparateWorkerConnections {
return func() queryFunc {
dohQuery := b.dohQuery()
return dohQuery
}
}
dohQuery := b.dohQuery()
queryFactory := func() queryFunc {
return func() queryFunc {
return dohQuery
}
return queryFactory
case b.useQuic:
h, _, _ := net.SplitHostPort(b.Server)
if b.SeparateWorkerConnections {
return func() queryFunc {
// nolint:gosec
quicClient := doq.NewClient(b.Server, doq.Options{
TLSConfig: &tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure},
ReadTimeout: b.ReadTimeout,
WriteTimeout: b.WriteTimeout,
ConnectTimeout: b.ConnectTimeout,
})
return func(ctx context.Context, _ string, msg *dns.Msg) (*dns.Msg, error) {
return quicClient.Send(ctx, msg)
}
}
}
// nolint:gosec
quicClient := doq.NewClient(b.Server, doq.Options{
TLSConfig: &tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure},
Expand Down
194 changes: 191 additions & 3 deletions pkg/dnsbench/benchmark_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ import (
"os"
"regexp"
"strconv"
"sync"
"testing"
"time"

"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tantalor93/dnspyre/v3/pkg/dnsbench"
Expand Down Expand Up @@ -726,7 +728,7 @@ func TestBenchmark_Run_PlainDNS_default_count(t *testing.T) {
}

func TestBenchmark_Run_DoQ(t *testing.T) {
server := newDoQServer(func(r *dns.Msg) *dns.Msg {
server := newDoQServer(func(_ quic.Connection, r *dns.Msg) *dns.Msg {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -1033,7 +1035,7 @@ func TestBenchmark_Run_DoH_error(t *testing.T) {
}

func TestBenchmark_Run_DoQ_error(t *testing.T) {
server := newDoQServer(func(_ *dns.Msg) *dns.Msg {
server := newDoQServer(func(_ quic.Connection, _ *dns.Msg) *dns.Msg {
return nil
})
server.start()
Expand Down Expand Up @@ -1238,7 +1240,7 @@ func TestBenchmark_Run_DoH_truncated(t *testing.T) {
}

func TestBenchmark_Run_DoQ_truncated(t *testing.T) {
server := newDoQServer(func(r *dns.Msg) *dns.Msg {
server := newDoQServer(func(_ quic.Connection, r *dns.Msg) *dns.Msg {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -1359,6 +1361,192 @@ func TestBenchmark_Requestlog(t *testing.T) {
assert.Equal(t, map[string]int{"AAAA": 2, "A": 2}, qtypes)
}

func TestBenchmark_Run_DoH_separate_connections(t *testing.T) {
tests := []struct {
name string
separateConnections bool
wantNumberOfConnections int
}{
{
name: "separate connections",
separateConnections: true,
wantNumberOfConnections: 5,
},
{
name: "shared connections",
separateConnections: false,
wantNumberOfConnections: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert, err := tls.LoadX509KeyPair("testdata/test.crt", "testdata/test.key")
require.NoError(t, err)

certs, err := os.ReadFile("testdata/test.crt")
require.NoError(t, err)

pool, err := x509.SystemCertPool()
require.NoError(t, err)

pool.AppendCertsFromPEM(certs)
config := tls.Config{
ServerName: "localhost",
RootCAs: pool,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}

mutex := sync.Mutex{}
remoteAddrs := make(map[string]int)

ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mutex.Lock()
remoteAddrs[r.RemoteAddr]++
mutex.Unlock()

bd, err := io.ReadAll(r.Body)
if err != nil {
panic(err)
}

msg := dns.Msg{}
err = msg.Unpack(bd)
if err != nil {
panic(err)
}

msg.Answer = append(msg.Answer, A("example.org. IN A 127.0.0.1"))

pack, err := msg.Pack()
if err != nil {
panic(err)
}

_, err = w.Write(pack)
if err != nil {
panic(err)
}
}))
ts.EnableHTTP2 = true
ts.TLS = &config
ts.StartTLS()
defer ts.Close()

buf := bytes.Buffer{}
bench := dnsbench.Benchmark{
Queries: []string{"example.org"},
Types: []string{"A"},
Server: ts.URL,
DohProtocol: "2",
TCP: true,
Concurrency: 5,
Count: 2,
Probability: 1,
WriteTimeout: 1 * time.Second,
ReadTimeout: 3 * time.Second,
ConnectTimeout: 1 * time.Second,
RequestTimeout: 5 * time.Second,
Rcodes: true,
Recurse: true,
DohMethod: dnsbench.PostHTTPMethod,
Writer: &buf,
SeparateWorkerConnections: tt.separateConnections,
Insecure: true,
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rs, err := bench.Run(ctx)

// close right away to mitigate race detector failures
ts.Close()

require.NoError(t, err, "expected no error from benchmark run")
assert.Len(t, rs, 5)
for _, v := range rs {
assert.Empty(t, v.Errors)
}
assert.Len(t, remoteAddrs, tt.wantNumberOfConnections)
assert.Equal(t, fmt.Sprintf("Using 1 hostnames\nBenchmarking %s/dns-query via https/2 (POST) with 5 concurrent requests \n", ts.URL), buf.String())
})
}
}

func TestBenchmark_Run_DoQ_separate_connections(t *testing.T) {
tests := []struct {
name string
separateConnections bool
wantNumberOfConnections int
}{
{
name: "separate connections",
separateConnections: true,
wantNumberOfConnections: 5,
},
{
name: "shared connections",
separateConnections: false,
wantNumberOfConnections: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mutex := sync.Mutex{}
remoteAddrs := make(map[string]int)

server := newDoQServer(func(c quic.Connection, r *dns.Msg) *dns.Msg {
mutex.Lock()
remoteAddrs[c.RemoteAddr().String()]++
mutex.Unlock()

ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
return ret
})
server.start()
defer server.stop()

buf := bytes.Buffer{}
bench := dnsbench.Benchmark{
Queries: []string{"example.org"},
Types: []string{"A"},
Server: "quic://" + server.addr,
TCP: true,
Concurrency: 5,
Count: 2,
Probability: 1,
WriteTimeout: 1 * time.Second,
ReadTimeout: 3 * time.Second,
ConnectTimeout: 1 * time.Second,
RequestTimeout: 5 * time.Second,
Rcodes: true,
Recurse: true,
DohMethod: dnsbench.PostHTTPMethod,
Writer: &buf,
SeparateWorkerConnections: tt.separateConnections,
Insecure: true,
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rs, err := bench.Run(ctx)

// stop right away to mitigate race detector failures
server.stop()

require.NoError(t, err, "expected no error from benchmark run")
assert.Len(t, rs, 5)
for _, v := range rs {
assert.Empty(t, v.Errors)
}
assert.Len(t, remoteAddrs, tt.wantNumberOfConnections)
assert.Equal(t, fmt.Sprintf("Using 1 hostnames\nBenchmarking %s via quic with 5 concurrent requests \n", server.addr), buf.String())
})
}
}

func parseRequestLogs(t *testing.T, reader io.Reader) []requestLog {
pattern := `.*worker:\[(.*)\] reqid:\[(.*)\] qname:\[(.*)\] qtype:\[(.*)\] respid:\[(.*)\] rcode:\[(.*)\] respflags:\[(.*)\] err:\[(.*)\] duration:\[(.*)\]$`
regex := regexp.MustCompile(pattern)
Expand Down
4 changes: 2 additions & 2 deletions pkg/dnsbench/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestBenchmark_prepare(t *testing.T) {
func TestBenchmark_init(t *testing.T) {
tests := []struct {
name string
benchmark Benchmark
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestBenchmark_prepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.benchmark.prepare()
err := tt.benchmark.init()

require.Equal(t, tt.wantErr, err != nil)
if !tt.wantErr {
Expand Down
4 changes: 2 additions & 2 deletions pkg/dnsbench/doq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/quic-go/quic-go"
)

type doqHandler func(req *dns.Msg) *dns.Msg
type doqHandler func(conn quic.Connection, req *dns.Msg) *dns.Msg

// doqServer is a DoQ test DNS server.
type doqServer struct {
Expand Down Expand Up @@ -58,7 +58,7 @@ func (d *doqServer) start() {
return
}

resp := d.handler(req)
resp := d.handler(conn, req)
if resp == nil {
// this should cause timeout
return
Expand Down

0 comments on commit a1e561e

Please sign in to comment.