Skip to content

Commit

Permalink
dynamictls: add dial error tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abursavich committed Mar 27, 2020
1 parent 2b622c9 commit 75bb235
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 6 deletions.
95 changes: 95 additions & 0 deletions dynamic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
package dynamictls

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
Expand Down Expand Up @@ -447,6 +449,73 @@ func TestMTLS(t *testing.T) {
}
}

func TestDialErrors(t *testing.T) {
// create temp dir
dir, err := ioutil.TempDir("", "")
check(t, "Failed to create directory", err)
defer os.RemoveAll(dir)

// create certificates
ca, caCertPEM, _, err := tlstest.GenerateCert(nil)
check(t, "Failed to create CA", err)
caFile := createFile(t, dir, "roots.pem", caCertPEM)
_, certPEM, keyPEM, err := tlstest.GenerateCert(&tlstest.CertOptions{
Template: &x509.Certificate{
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
},
Parent: ca,
})
check(t, "Failed to create certificate", err)
certFile := createFile(t, dir, "cert.pem", certPEM)
keyFile := createFile(t, dir, "key.pem", keyPEM)

// create config
cfg, err := NewConfig(
WithBase(&tls.Config{
MinVersion: tls.VersionTLS12,
}),
WithHTTP2(),
WithCertificate(certFile, keyFile),
WithRootCAs(caFile),
WithErrorLogger(t),
)
check(t, "Failed to create dynamic TLS config", err)
defer cfg.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

dialErr := fmt.Errorf("dial test error")
cfg.dialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, dialErr
}
if _, err := cfg.Dial(ctx, "tcp", "localhost"); err != dialErr {
t.Fatalf("Dial error; want: %v; got: %v", dialErr, err)
}

cfg.dialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
return errConn{}, nil
}
if _, err = cfg.Dial(ctx, "tcp", "localhost"); err == nil {
t.Fatal("Expected a handshake error")
}

doneCtx, cancel := context.WithCancel(context.Background())
cancel()

cfg.dialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
return &ctxWaitConn{ctx: ctx}, nil
}
_, err = cfg.Dial(doneCtx, "tcp", "localhost")
if want := context.Canceled; err != want {
t.Fatalf("Dial error; want: %v; got: %v", want, err)
}
}

func createDir(t *testing.T, dir string, files map[string][]byte) {
t.Helper()
check(t, "Failed to make directory", os.Mkdir(dir, os.ModePerm))
Expand All @@ -468,3 +537,29 @@ func check(t *testing.T, msg string, err error) {
t.Fatalf("%s: %v", msg, err)
}
}

type ctxWaitConn struct {
ctx context.Context
errConn
}

func (c *ctxWaitConn) Read(b []byte) (n int, err error) {
<-c.ctx.Done()
return c.errConn.Read(b)
}

func (c *ctxWaitConn) Write(b []byte) (n int, err error) {
<-c.ctx.Done()
return c.errConn.Write(b)
}

type errConn struct{}

func (errConn) Read(b []byte) (n int, err error) { return 0, io.ErrClosedPipe }
func (errConn) Write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe }
func (errConn) Close() error { return nil }
func (errConn) LocalAddr() net.Addr { return &net.UnixAddr{Net: "unix", Name: "/tmp/fake"} }
func (errConn) RemoteAddr() net.Addr { return &net.UnixAddr{Net: "unix", Name: "/tmp/fake"} }
func (errConn) SetDeadline(t time.Time) error { return nil }
func (errConn) SetReadDeadline(t time.Time) error { return nil }
func (errConn) SetWriteDeadline(t time.Time) error { return nil }
19 changes: 13 additions & 6 deletions dynamictls.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ type keyPair struct {
certFile, keyFile string
}

type dialFunc func(ctx context.Context, network, address string) (net.Conn, error)

var defaultDialFunc = (&net.Dialer{}).DialContext

// A Config is used to configure a TLS client or server.
type Config struct {
latest atomic.Value
Expand All @@ -191,6 +195,8 @@ type Config struct {
watcher *fsnotify.Watcher
close chan struct{} // signals watch goroutine to end
done chan struct{} // signals watch goroutine has ended

dialFunc dialFunc // used by tests
}

// NewConfig returns a new Config with the given options.
Expand All @@ -206,11 +212,12 @@ func NewConfig(options ...Option) (cfg *Config, err error) {
}
}()
cfg = &Config{
base: &tls.Config{},
errLog: noopLogger{},
watcher: w,
close: make(chan struct{}, 1),
done: make(chan struct{}),
base: &tls.Config{},
errLog: noopLogger{},
watcher: w,
close: make(chan struct{}, 1),
done: make(chan struct{}),
dialFunc: defaultDialFunc,
}
for _, o := range sortedOptions(options) {
if err := o.apply(cfg); err != nil {
Expand Down Expand Up @@ -259,7 +266,7 @@ func (cfg *Config) Listen(ctx context.Context, network, address string) (net.Lis
// Dial connects to the given network address and initiates a TLS handshake,
// returning the resulting TLS connection.
func (cfg *Config) Dial(ctx context.Context, network, address string) (net.Conn, error) {
rawConn, err := (&net.Dialer{}).DialContext(ctx, network, address)
rawConn, err := cfg.dialFunc(ctx, network, address)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 75bb235

Please sign in to comment.