Skip to content

Commit

Permalink
add DoT tests and support insecure flag with DoT
Browse files Browse the repository at this point in the history
  • Loading branch information
Tantalor93 committed Aug 12, 2023
1 parent 4ef7b7a commit 7d1ae79
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 34 deletions.
8 changes: 5 additions & 3 deletions cmd/benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,20 @@ func (b *Benchmark) getDNSClient() *dns.Client {
network := "udp"
if b.TCP {
network = "tcp"
} else if b.DOT {
}
if b.DOT {
network = "tcp-tls"
}

dnsClient := dns.Client{
return &dns.Client{
Net: network,
DialTimeout: b.ConnectTimeout,
WriteTimeout: b.WriteTimeout,
ReadTimeout: b.ReadTimeout,
Timeout: b.RequestTimeout,
// nolint:gosec
TLSConfig: &tls.Config{InsecureSkipVerify: b.Insecure},
}
return &dnsClient
}

func (b *Benchmark) prepareQuestions() ([]string, error) {
Expand Down
55 changes: 50 additions & 5 deletions cmd/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package cmd

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

Expand Down Expand Up @@ -46,7 +49,7 @@ func Test_do_classic_dns(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := NewServer(tt.args.protocol, func(w dns.ResponseWriter, r *dns.Msg) {
s := NewServer(tt.args.protocol, nil, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -155,7 +158,7 @@ func Test_do_doh_get(t *testing.T) {
}

func Test_do_probability(t *testing.T) {
s := NewServer(udp, func(w dns.ResponseWriter, r *dns.Msg) {
s := NewServer(udp, nil, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand All @@ -182,7 +185,7 @@ func Test_do_probability(t *testing.T) {
}

func Test_download_external_datasource_using_http(t *testing.T) {
s := NewServer("udp", func(w dns.ResponseWriter, r *dns.Msg) {
s := NewServer("udp", nil, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -285,7 +288,7 @@ func Test_download_external_datasource_using_http_wrong_response(t *testing.T) {
}

func Test_do_classic_dns_with_duration(t *testing.T) {
s := NewServer("udp", func(w dns.ResponseWriter, r *dns.Msg) {
s := NewServer("udp", nil, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -341,7 +344,7 @@ func Test_duration_and_count_specified_at_once(t *testing.T) {
}

func Test_do_classic_dns_default_count(t *testing.T) {
s := NewServer("udp", func(w dns.ResponseWriter, r *dns.Msg) {
s := NewServer("udp", nil, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))
Expand Down Expand Up @@ -393,6 +396,48 @@ func Test_do_doq(t *testing.T) {
assertResult(t, rs)
}

func Test_do_dot(t *testing.T) {
cert, err := tls.LoadX509KeyPair("test.crt", "test.key")
require.NoError(t, err)

certs, err := os.ReadFile("test.crt")
require.NoError(t, err)

pool, err := x509.SystemCertPool()
require.NoError(t, err)

pool.AppendCertsFromPEM(certs)
config := tls.Config{
ServerName: "localhost",
RootCAs: pool,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12,
}

server := NewServer("tcp-tls", &config, func(w dns.ResponseWriter, r *dns.Msg) {
ret := new(dns.Msg)
ret.SetReply(r)
ret.Answer = append(ret.Answer, A("example.org. IN A 127.0.0.1"))

// wait some time to actually have some observable duration
time.Sleep(time.Millisecond * 500)

w.WriteMsg(ret)
})
defer server.Close()

bench := createBenchmark(server.Addr, false, 1)
bench.Insecure = true
bench.DOT = true

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
rs, err := bench.Run(ctx)

assert.NoError(t, err, "expected no error from benchmark run")
assertResult(t, rs)
}

func assertResult(t *testing.T, rs []*ResultStats) {
if assert.Len(t, rs, 2, "Run(ctx) rstats") {
rs0 := rs[0]
Expand Down
39 changes: 13 additions & 26 deletions cmd/server_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cmd

import (
"net"
"crypto/tls"

"github.com/miekg/dns"
)
Expand All @@ -18,35 +18,22 @@ func (s *Server) Close() {
}

// NewServer creates and starts new DNS server instance.
func NewServer(network string, f dns.HandlerFunc) *Server {
func NewServer(network string, tlsConfig *tls.Config, f dns.HandlerFunc) *Server {
ch := make(chan bool)
s := &dns.Server{}
s.Handler = f
s := &dns.Server{Net: network, Addr: "127.0.0.1:0", TLSConfig: tlsConfig, NotifyStartedFunc: func() { close(ch) }, Handler: f}

for i := 0; i < 10; i++ {
s.Listener, _ = net.Listen("tcp", "127.0.0.1:0")
if network == "udp" {
if s.Listener == nil {
continue
}
s.PacketConn, _ = net.ListenPacket("udp", s.Listener.Addr().String())
if s.PacketConn != nil {
break
}
}
if s.Listener != nil {
break
}
}
if s.Listener == nil {
panic("failed to create new client")
}

s.NotifyStartedFunc = func() { close(ch) }
go func() {
s.ActivateAndServe()
if err := s.ListenAndServe(); err != nil {
panic(err)
}
}()

<-ch
return &Server{inner: s, Addr: s.Listener.Addr().String()}
server := Server{inner: s}
if network == "udp" {
server.Addr = s.PacketConn.LocalAddr().String()
} else {
server.Addr = s.Listener.Addr().String()
}
return &server
}

0 comments on commit 7d1ae79

Please sign in to comment.