diff --git a/builder.go b/builder.go index f38db51..3c25702 100644 --- a/builder.go +++ b/builder.go @@ -1,7 +1,7 @@ package cloudmap import ( - "sync" + "context" "time" "google.golang.org/grpc/grpclog" @@ -37,11 +37,11 @@ type builder struct { // so you don't need to call this function to register the default builder. // // Default Options: +// // Session: session.NewSession() // HealthStatusFilter: HealthStatusFilterHealthy // MaxResults: 100 // RefreshInterval: 30s -// func Register(opts ...Opt) { b := &builder{ healthStatusFilter: HealthStatusFilterHealthy, @@ -72,23 +72,26 @@ func (b *builder) Build(t grpcresolver.Target, cc grpcresolver.ClientConn, _ grp } } + ctx, cancel := context.WithCancel(context.Background()) r := &resolver{ - mu: &sync.RWMutex{}, - logger: grpclog.Component(b.Scheme()), cc: cc, - ticker: time.NewTicker(b.refreshInterval), - sd: servicediscovery.New(sess), namespace: cmT.namespace, service: cmT.service, healthStatusFilter: b.healthStatusFilter, maxResults: b.maxResults, + + ctx: ctx, + cancel: cancel, + ticker: time.NewTicker(b.refreshInterval), + resolveCmd: make(chan struct{}, 1), } - go r.watch() + r.wg.Add(1) + go r.watcher() return r, nil } diff --git a/resolver.go b/resolver.go index 1c1c5d1..1fbcba7 100644 --- a/resolver.go +++ b/resolver.go @@ -1,6 +1,7 @@ package cloudmap import ( + "context" "fmt" "sync" "time" @@ -14,34 +15,35 @@ import ( "github.com/aws/aws-sdk-go/service/servicediscovery" ) -type resolver struct { - mu *sync.RWMutex - isClosed bool +type serviceDiscovery interface { + DiscoverInstances(input *servicediscovery.DiscoverInstancesInput) (*servicediscovery.DiscoverInstancesOutput, error) +} +type resolver struct { logger grpclog.LoggerV2 + cc grpcresolver.ClientConn - cc grpcresolver.ClientConn - - ticker *time.Ticker - - sd *servicediscovery.ServiceDiscovery + sd serviceDiscovery namespace string service string healthStatusFilter string maxResults int64 + + ctx context.Context + cancel context.CancelFunc + ticker *time.Ticker + resolveCmd chan struct{} + wg sync.WaitGroup } func (c *resolver) ResolveNow(grpcresolver.ResolveNowOptions) { - locked := c.mu.TryLock() - if !locked { // already resolving - return - } - defer c.mu.Unlock() - - if c.isClosed { - return + select { + case c.resolveCmd <- struct{}{}: + default: } +} +func (c *resolver) lookupCloudmap() (*grpcresolver.State, error) { output, err := c.sd.DiscoverInstances(&servicediscovery.DiscoverInstancesInput{ NamespaceName: aws.String(c.namespace), ServiceName: aws.String(c.service), @@ -65,8 +67,7 @@ func (c *resolver) ResolveNow(grpcresolver.ResolveNowOptions) { } else { c.logger.Errorln(err.Error()) } - c.cc.ReportError(err) - return + return nil, err } addrs := make([]grpcresolver.Address, len(output.Instances)) @@ -74,25 +75,37 @@ func (c *resolver) ResolveNow(grpcresolver.ResolveNowOptions) { addrs[i] = httpInstanceSummaryToAddr(instance) } - c.cc.UpdateState(grpcresolver.State{Addresses: addrs}) + return &grpcresolver.State{Addresses: addrs}, nil } func (c *resolver) Close() { - c.mu.Lock() - defer c.mu.Unlock() - - if c.isClosed { - return - } - - c.isClosed = true + c.cancel() c.ticker.Stop() + c.wg.Wait() } -func (c *resolver) watch() { +func (c *resolver) watcher() { + defer c.wg.Done() + for { - c.ResolveNow(grpcresolver.ResolveNowOptions{}) - <-c.ticker.C + state, err := c.lookupCloudmap() + if err != nil { + c.cc.ReportError(err) + } else { + err = c.cc.UpdateState(*state) + } + + if err != nil { + c.logger.Errorln(err) + // wait for next iteration + } + + select { + case <-c.ctx.Done(): + return + case <-c.ticker.C: + case <-c.resolveCmd: + } } } diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 0000000..b845d4c --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,70 @@ +package cloudmap + +import ( + "context" + "fmt" + "github.com/aws/aws-sdk-go/service/servicediscovery" + "google.golang.org/grpc/grpclog" + grpcresolver "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + "testing" + "time" +) + +type mockCC struct{} + +func (m mockCC) UpdateState(state grpcresolver.State) error { return nil } + +func (m mockCC) ReportError(err error) {} + +func (m mockCC) NewAddress(addresses []grpcresolver.Address) {} + +func (m mockCC) NewServiceConfig(serviceConfig string) {} + +func (m mockCC) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult { + return nil +} + +type mockDiscovery struct{} + +func (m mockDiscovery) DiscoverInstances(input *servicediscovery.DiscoverInstancesInput) (*servicediscovery.DiscoverInstancesOutput, error) { + time.Sleep(1 * time.Second) + fmt.Println("DiscoverInstances called") + return &servicediscovery.DiscoverInstancesOutput{ + Instances: make([]*servicediscovery.HttpInstanceSummary, 0), + }, nil +} + +func Test_resolver(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + r := &resolver{ + logger: grpclog.Component("test"), + + cc: mockCC{}, + sd: mockDiscovery{}, + + ctx: ctx, + cancel: cancel, + ticker: time.NewTicker(10 * time.Second), + resolveCmd: make(chan struct{}, 1), + } + + r.wg.Add(1) + go r.watcher() + + timeout := time.After(100 * time.Millisecond) + done := make(chan bool) + go func() { + for i := 0; i < 10; i++ { + r.ResolveNow(grpcresolver.ResolveNowOptions{}) + } + done <- true + }() + select { + case <-timeout: + t.Error("timeout") + case <-done: + t.Log("done") + } + r.Close() +}