Skip to content

Commit

Permalink
BREAKING: switch from Update callback to Observer
Browse files Browse the repository at this point in the history
  • Loading branch information
abursavich committed Feb 10, 2022
1 parent a269473 commit 30ba0c3
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 114 deletions.
27 changes: 14 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# DynamicTLS

[![License](https://img.shields.io/badge/license-mit-blue.svg?style=for-the-badge)](https://raw.githubusercontent.com/abursavich/dynamictls/master/LICENSE)
[![GoDev Reference](https://img.shields.io/static/v1?logo=go&logoColor=white&color=00ADD8&label=dev&message=reference&style=for-the-badge)](https://pkg.go.dev/bursavich.dev/dynamictls)
[![Go Report Card](https://goreportcard.com/badge/bursavich.dev/dynamictls?style=for-the-badge)](https://goreportcard.com/report/bursavich.dev/dynamictls)
Expand All @@ -15,16 +16,16 @@ It provides simple integrations with HTTP/1.1, HTTP/2, gRPC, and Prometheus.

```go
// create metrics
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithHTTP(),
tlsprom.WithServer(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

// create TLS config
cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithCertificate(primaryCertFile, primaryKeyFile),
dynamictls.WithCertificate(secondaryCertFile, secondaryKeyFile),
dynamictls.WithRootCAs(caFile),
Expand All @@ -43,16 +44,16 @@ check(http.Serve(lis, http.DefaultServeMux))

```go
// create metrics
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithHTTP(),
tlsprom.WithClient(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

// create TLS config
cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithBase(&tls.Config{
MinVersion: tls.VersionTLS12,
}),
Expand All @@ -77,16 +78,16 @@ defer client.CloseIdleConnections()

```go
// create metrics
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithGRPC(),
tlsprom.WithServer(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

// create TLS config
cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithBase(&tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS13,
Expand Down Expand Up @@ -115,16 +116,16 @@ check(srv.Serve(lis))

```go
// create metrics
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithGRPC(),
tlsprom.WithClient(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

// create TLS config
cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithBase(&tls.Config{
MinVersion: tls.VersionTLS13,
}),
Expand All @@ -146,4 +147,4 @@ conn, err := grpc.Dial(
check(err)
defer conn.Close()
client := pb.NewTestServiceClient(conn)
```
```
30 changes: 17 additions & 13 deletions dynamictls.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@ import (

const hashSize = 16 // 128-bit

// NotifyFunc is a function that is called when new config data
// is loaded or an error occurs loading new config data.
type NotifyFunc func(cfg *tls.Config, err error)
// An Observer observes when new config data is loaded or an error occurs loading new config data.
type Observer interface {
ObserveConfig(cfg *tls.Config)
ObserveReadError(err error)
}

type noopObserver struct{}

func (noopObserver) ObserveConfig(cfg *tls.Config) {}
func (noopObserver) ObserveReadError(err error) {}

// An Option applies optional configuration.
type Option interface {
Expand Down Expand Up @@ -110,10 +117,10 @@ func WithCertificate(certFile, keyFile string) Option {
})
}

// WithNotifyFunc returns an Option that registers the notify function.
func WithNotifyFunc(notify NotifyFunc) Option {
// WithObserver returns an Option that registers the Observer.
func WithObserver(observer Observer) Option {
return optionFunc(func(c *Config) error {
c.notifyFns = append(c.notifyFns, notify)
c.observer = observer
return nil
})
}
Expand Down Expand Up @@ -183,7 +190,7 @@ type Config struct {
rootCAs []string
clientCAs []string
certs []keyPair
notifyFns []NotifyFunc
observer Observer
log logr.Logger

watcher *fsnotify.Watcher
Expand All @@ -207,6 +214,7 @@ func NewConfig(options ...Option) (cfg *Config, err error) {
}()
cfg = &Config{
base: &tls.Config{},
observer: noopObserver{},
log: logr.Discard(),
watcher: w,
close: make(chan struct{}),
Expand Down Expand Up @@ -322,9 +330,7 @@ func (cfg *Config) read() error {
}

cfg.latest.Store(config)
for _, fn := range cfg.notifyFns {
fn(config, nil)
}
cfg.observer.ObserveConfig(config)
return nil
}

Expand All @@ -337,9 +343,7 @@ func (cfg *Config) watch() {
// TODO: ignore unrelated events
if err := cfg.read(); err != nil {
cfg.log.Error(err, "Read failure") // errors already decorated
for _, fn := range cfg.notifyFns {
fn(nil, err)
}
cfg.observer.ObserveReadError(err)
}
case err := <-cfg.watcher.Errors:
cfg.log.Error(err, "Watch failure")
Expand Down
81 changes: 55 additions & 26 deletions dynamictls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,36 @@ func certPoolEqual(x, y *x509.CertPool) bool {
return reflect.DeepEqual(xs, ys)
}

type testObserver struct {
configCh chan *tls.Config
errCh chan error
}

func newTestObserver() *testObserver {
return &testObserver{
configCh: make(chan *tls.Config, 1),
errCh: make(chan error, 1),
}
}

func (o *testObserver) ObserveConfig(cfg *tls.Config) {
timeout := time.NewTimer(10 * time.Second)
defer timeout.Stop()
select {
case <-timeout.C:
case o.configCh <- cfg:
}
}

func (o *testObserver) ObserveReadError(err error) {
timeout := time.NewTimer(10 * time.Second)
defer timeout.Stop()
select {
case <-timeout.C:
case o.errCh <- err:
}
}

func TestNotifyError(t *testing.T) {
// create temp dir
dir, err := ioutil.TempDir("", "")
Expand All @@ -234,10 +264,10 @@ func TestNotifyError(t *testing.T) {
keyFile := createFile(t, dir, "key.pem", keyPEMBlock)

// create config
errCh := make(chan error, 1)
obs := newTestObserver()
cfg, err := NewConfig(
WithCertificate(certFile, keyFile),
WithNotifyFunc(func(_ *tls.Config, err error) { errCh <- err }),
WithObserver(obs),
)
check(t, "Failed to initialize config", err)
defer cfg.Close()
Expand All @@ -246,7 +276,11 @@ func TestNotifyError(t *testing.T) {
defer timeout.Stop()

select {
case err := <-errCh:
case cfg := <-obs.configCh:
if cfg == nil {
t.Fatalf("Unexpected nil config")
}
case err := <-obs.errCh:
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
Expand All @@ -257,7 +291,11 @@ func TestNotifyError(t *testing.T) {
check(t, "Failed to remove cert file", os.Remove(certFile))

select {
case err := <-errCh:
case cfg := <-obs.configCh:
if cfg != nil {
t.Fatalf("Unexpected config")
}
case err := <-obs.errCh:
if err == nil {
t.Fatal("Expected an error after deleting certs")
}
Expand Down Expand Up @@ -307,38 +345,29 @@ func TestKubernetes(t *testing.T) {
check(t, "Failed to create symlink", os.Symlink(data0, data))

// create config
ch := make(chan result, 1)
notifyFn := func(config *tls.Config, err error) {
select {
case <-ch:
default:
}
ch <- result{config: config, err: err}
}
obs := newTestObserver()
wantCert := func(want *tls.Certificate) {
t.Helper()
timeout := time.NewTimer(5 * time.Second)
defer timeout.Stop()
var err error
for {
select {
case res := <-ch:
if res.err != nil {
// An error can occur if a filesystem event triggers a reload and a
// symlink flip happens between reading the public and private keys.
// The keys won't match due to this race, but a subsequent reload
// will also be triggered and they will match the next time.
t.Logf("Unexpected error, may be transient: %v", res.err)
err = res.err
continue
}
if res.config == nil {
case err = <-obs.errCh:
// An error can occur if a filesystem event triggers a reload and a
// symlink flip happens between reading the public and private keys.
// The keys won't match due to this race, but a subsequent reload
// will also be triggered and they will match the next time.
t.Logf("Unexpected error, may be transient: %v", err)
continue
case cfg := <-obs.configCh:
if cfg == nil {
t.Fatal("Config missing")
}
if len(res.config.Certificates) == 0 {
if len(cfg.Certificates) == 0 {
t.Fatal("Config missing certs")
}
got := res.config.Certificates[0]
got := cfg.Certificates[0]
if !reflect.DeepEqual(got.Certificate, want.Certificate) {
t.Fatal("Unexpected cert")
}
Expand All @@ -358,7 +387,7 @@ func TestKubernetes(t *testing.T) {
cfg, err := NewConfig(
WithCertificate(certFile, keyFile),
WithRootCAs(caFile),
WithNotifyFunc(notifyFn),
WithObserver(obs),
)
check(t, "Failed to initialize config", err)
defer cfg.Close()
Expand Down
13 changes: 7 additions & 6 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
// Use of this source code is governed by The MIT License
// which can be found in the LICENSE file.

//go:build go1.14
// +build go1.14

package dynamictls_test
Expand All @@ -19,15 +20,15 @@ import (
)

func ExampleConfig_Listen() {
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithHTTP(),
tlsprom.WithServer(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithCertificate(primaryCertFile, primaryKeyFile),
dynamictls.WithCertificate(secondaryCertFile, secondaryKeyFile),
dynamictls.WithRootCAs(caFile),
Expand All @@ -42,15 +43,15 @@ func ExampleConfig_Listen() {
}

func ExampleConfig_Dial() {
metrics, err := tlsprom.NewMetrics(
observer, err := tlsprom.NewObserver(
tlsprom.WithHTTP(),
tlsprom.WithClient(),
)
check(err)
prometheus.MustRegister(metrics)
prometheus.MustRegister(observer)

cfg, err := dynamictls.NewConfig(
dynamictls.WithNotifyFunc(metrics.Update),
dynamictls.WithObserver(observer),
dynamictls.WithBase(&tls.Config{
MinVersion: tls.VersionTLS12,
}),
Expand Down
Loading

0 comments on commit 30ba0c3

Please sign in to comment.