Skip to content

Commit

Permalink
Move Service Bus Pubsub/Binding to common auth (dapr#1201)
Browse files Browse the repository at this point in the history
* Move Service Bus Pubsub/Binding to common auth

Both the pubsub and input/output binding for Azure Service Bus were
connecting via a connection string. This is still supported but will
now fallback to using AAD from the common auth library. This is also
the recommended auth pattern going forward.

* Move AMPQ specific auth and fix linter issues

* Make conn string and namespace mutually exclusive

* Move resourceName to a constant

* Update auth_amqp.go

* Update auth.go

Co-authored-by: Long Dai <long.dai@intel.com>
Co-authored-by: Simon Leet <31784195+CodeMonkeyLeet@users.noreply.github.com>
Co-authored-by: Artur Souza <artursouza.ms@outlook.com>
Co-authored-by: Dapr Bot <56698301+dapr-bot@users.noreply.github.com>
Signed-off-by: Amit Mor <amitm@at-bay.com>
  • Loading branch information
5 people authored and amimimor committed Dec 9, 2021
1 parent a134916 commit a8956d0
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 11 deletions.
2 changes: 2 additions & 0 deletions authentication/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ func NewEnvironmentSettings(resourceName string, values map[string]string) (Envi
case "cosmosdb":
// Azure Cosmos DB (data plane)
es.Resource = "https://" + azureEnv.CosmosDBDNSSuffix
case "servicebus":
es.Resource = azureEnv.ResourceIdentifiers.ServiceBus
default:
return es, errors.New("invalid resource name: " + resourceName)
}
Expand Down
25 changes: 25 additions & 0 deletions authentication/azure/auth_amqp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation and Dapr Contributors.
// Licensed under the MIT License.
// ------------------------------------------------------------

package azure

import "github.com/Azure/azure-amqp-common-go/v3/aad"

const (
AzureServiceBusResourceName string = "servicebus"
)

// GetTokenProvider creates a TokenProvider for AAD retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. MSI.
func (s EnvironmentSettings) GetAADTokenProvider() (*aad.TokenProvider, error) {
spt, err := s.GetServicePrincipalToken()
if err != nil {
return nil, err
}

return aad.NewJWTProvider(aad.JWTProviderWithAADToken(spt), aad.JWTProviderWithAzureEnvironment(s.AzureEnvironment))
}
41 changes: 37 additions & 4 deletions bindings/azure/servicebusqueues/servicebusqueues.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ package servicebusqueues
import (
"context"
"encoding/json"
"errors"
"strings"
"sync/atomic"
"time"

servicebus "github.com/Azure/azure-service-bus-go"
"github.com/cenkalti/backoff/v4"

azauth "github.com/dapr/components-contrib/authentication/azure"
"github.com/dapr/components-contrib/bindings"
contrib_metadata "github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
Expand Down Expand Up @@ -43,6 +45,7 @@ type AzureServiceBusQueues struct {

type serviceBusQueuesMetadata struct {
ConnectionString string `json:"connectionString"`
NamespaceName string `json:"namespaceName,omitempty"`
QueueName string `json:"queueName"`
ttl time.Duration
}
Expand All @@ -61,10 +64,36 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) error {
userAgent := "dapr-" + logger.DaprVersion
a.metadata = meta

ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
servicebus.NamespaceWithUserAgent(userAgent))
if err != nil {
return err
var ns *servicebus.Namespace
if a.metadata.ConnectionString != "" {
ns, err = servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
servicebus.NamespaceWithUserAgent(userAgent))
if err != nil {
return err
}
} else {
// Initialization code
settings, sErr := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties)
if sErr != nil {
return sErr
}

tokenProvider, tErr := settings.GetAADTokenProvider()
if tErr != nil {
return tErr
}

ns, err = servicebus.NewNamespace(servicebus.NamespaceWithTokenProvider(tokenProvider),
servicebus.NamespaceWithUserAgent(userAgent))
if err != nil {
return err
}

// We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options
// pattern unless you allow it to recreate the entire environment which seems wasteful.
ns.Name = a.metadata.NamespaceName
ns.Environment = *settings.AzureEnvironment
ns.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix
}
a.ns = ns

Expand Down Expand Up @@ -124,6 +153,10 @@ func (a *AzureServiceBusQueues) parseMetadata(metadata bindings.Metadata) (*serv
return nil, err
}

if m.ConnectionString != "" && m.NamespaceName != "" {
return nil, errors.New("connectionString and namespaceName are mutually exclusive")
}

ttl, ok, err := contrib_metadata.TryGetTTL(metadata.Properties)
if err != nil {
return nil, err
Expand Down
52 changes: 52 additions & 0 deletions bindings/azure/servicebusqueues/servicebusqueues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,55 @@ func TestParseMetadataWithInvalidTTL(t *testing.T) {
})
}
}

func TestParseMetadataConnectionStringAndNamespaceNameExclusivity(t *testing.T) {
testCases := []struct {
name string
properties map[string]string
expectedConnectionString string
expectedNamespaceName string
expectedQueueName string
expectedErr bool
}{
{
name: "ConnectionString and queue name",
properties: map[string]string{"connectionString": "connString", "queueName": "queue1"},
expectedConnectionString: "connString",
expectedNamespaceName: "",
expectedQueueName: "queue1",
expectedErr: false,
},
{
name: "Empty TTL",
properties: map[string]string{"namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: ""},
expectedConnectionString: "",
expectedNamespaceName: "testNamespace",
expectedQueueName: "queue1",
expectedErr: false,
},
{
name: "With TTL",
properties: map[string]string{"connectionString": "connString", "namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: "1"},
expectedConnectionString: "",
expectedNamespaceName: "",
expectedQueueName: "queue1",
expectedErr: true,
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties
a := NewAzureServiceBusQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
if tt.expectedErr {
assert.NotNil(t, err)
} else {
assert.Equal(t, tt.expectedConnectionString, meta.ConnectionString)
assert.Equal(t, tt.expectedQueueName, meta.QueueName)
assert.Equal(t, tt.expectedNamespaceName, meta.NamespaceName)
}
})
}
}
1 change: 1 addition & 0 deletions pubsub/azure/servicebus/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ type metadata struct {
PrefetchCount *int `json:"prefetchCount"`
PublishMaxRetries int `json:"publishMaxRetries"`
PublishInitialRetryIntervalInMs int `json:"publishInitialRetryInternalInMs"`
NamespaceName string `json:"namespaceName,omitempty"`
}
46 changes: 40 additions & 6 deletions pubsub/azure/servicebus/servicebus.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

azservicebus "github.com/Azure/azure-service-bus-go"

azauth "github.com/dapr/components-contrib/authentication/azure"
"github.com/dapr/components-contrib/pubsub"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/retry"
Expand All @@ -43,6 +44,7 @@ const (
connectionRecoveryInSec = "connectionRecoveryInSec"
publishMaxRetries = "publishMaxRetries"
publishInitialRetryInternalInMs = "publishInitialRetryInternalInMs"
namespaceName = "namespaceName"
errorMessagePrefix = "azure service bus error:"

// Defaults.
Expand Down Expand Up @@ -93,8 +95,15 @@ func parseAzureServiceBusMetadata(meta pubsub.Metadata) (metadata, error) {
/* Required configuration settings - no defaults. */
if val, ok := meta.Properties[connectionString]; ok && val != "" {
m.ConnectionString = val

// The connection string and the namespace cannot both be present.
if namespace, present := meta.Properties[namespaceName]; present && namespace != "" {
return m, fmt.Errorf("%s connectionString and namespaceName cannot both be specified", errorMessagePrefix)
}
} else if val, ok := meta.Properties[namespaceName]; ok && val != "" {
m.NamespaceName = val
} else {
return m, fmt.Errorf("%s missing connection string", errorMessagePrefix)
return m, fmt.Errorf("%s missing connection string and namespace name", errorMessagePrefix)
}

if val, ok := meta.Properties[consumerID]; ok && val != "" {
Expand Down Expand Up @@ -258,12 +267,37 @@ func (a *azureServiceBus) Init(metadata pubsub.Metadata) error {

userAgent := "dapr-" + logger.DaprVersion
a.metadata = m
a.namespace, err = azservicebus.NewNamespace(
azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
azservicebus.NamespaceWithUserAgent(userAgent))
if a.metadata.ConnectionString != "" {
a.namespace, err = azservicebus.NewNamespace(
azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString),
azservicebus.NamespaceWithUserAgent(userAgent))

if err != nil {
return err
if err != nil {
return err
}
} else {
// Initialization code
settings, err := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties)
if err != nil {
return err
}

tokenProvider, err := settings.GetAADTokenProvider()
if err != nil {
return err
}

a.namespace, err = azservicebus.NewNamespace(azservicebus.NamespaceWithTokenProvider(tokenProvider),
azservicebus.NamespaceWithUserAgent(userAgent))
if err != nil {
return err
}

// We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options
// pattern unless you allow it to recreate the entire environment which seems wasteful.
a.namespace.Name = a.metadata.NamespaceName
a.namespace.Environment = *settings.AzureEnvironment
a.namespace.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix
}

a.topicManager = a.namespace.NewTopicManager()
Expand Down
54 changes: 53 additions & 1 deletion pubsub/azure/servicebus/servicebus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
func getFakeProperties() map[string]string {
return map[string]string{
connectionString: "fakeConnectionString",
namespaceName: "",
consumerID: "fakeConId",
disableEntityManagement: "true",
timeoutInSec: "90",
Expand Down Expand Up @@ -82,13 +83,14 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Equal(t, 10, *m.PrefetchCount)
})

t.Run("missing required connectionString", func(t *testing.T) {
t.Run("missing required connectionString and namespaceName", func(t *testing.T) {
fakeProperties := getFakeProperties()

fakeMetaData := pubsub.Metadata{
Properties: fakeProperties,
}
fakeMetaData.Properties[connectionString] = ""
fakeMetaData.Properties[namespaceName] = ""

// act.
m, err := parseAzureServiceBusMetadata(fakeMetaData)
Expand All @@ -99,6 +101,56 @@ func TestParseServiceBusMetadata(t *testing.T) {
assert.Empty(t, m.ConnectionString)
})

t.Run("connectionString makes namespace optional", func(t *testing.T) {
fakeProperties := getFakeProperties()

fakeMetaData := pubsub.Metadata{
Properties: fakeProperties,
}
fakeMetaData.Properties[namespaceName] = ""

// act.
m, err := parseAzureServiceBusMetadata(fakeMetaData)

// assert.
assert.NoError(t, err)
assert.Equal(t, "fakeConnectionString", m.ConnectionString)
})

t.Run("namespace makes conectionString optional", func(t *testing.T) {
fakeProperties := getFakeProperties()

fakeMetaData := pubsub.Metadata{
Properties: fakeProperties,
}
fakeMetaData.Properties[namespaceName] = "fakeNamespace"
fakeMetaData.Properties[connectionString] = ""

// act.
m, err := parseAzureServiceBusMetadata(fakeMetaData)

// assert.
assert.NoError(t, err)
assert.Equal(t, "fakeNamespace", m.NamespaceName)
})

t.Run("connectionString and namespace are mutually exclusive", func(t *testing.T) {
fakeProperties := getFakeProperties()

fakeMetaData := pubsub.Metadata{
Properties: fakeProperties,
}

fakeMetaData.Properties[namespaceName] = "fakeNamespace"

// act.
_, err := parseAzureServiceBusMetadata(fakeMetaData)

// assert.
assert.Error(t, err)
assertValidErrorMessage(t, err)
})

t.Run("missing required consumerID", func(t *testing.T) {
fakeProperties := getFakeProperties()

Expand Down

0 comments on commit a8956d0

Please sign in to comment.