diff --git a/authentication/azure/auth.go b/authentication/azure/auth.go index 23ba58a435..296b1a7727 100644 --- a/authentication/azure/auth.go +++ b/authentication/azure/auth.go @@ -44,6 +44,8 @@ func NewEnvironmentSettings(resourceName string, values map[string]string) (Envi case "cosmosdb": // Azure Cosmos DB (data plane) es.Resource = "https://" + azureEnv.CosmosDBDNSSuffix + case "servicebus": + es.Resource = azureEnv.ResourceIdentifiers.ServiceBus default: return es, errors.New("invalid resource name: " + resourceName) } diff --git a/authentication/azure/auth_amqp.go b/authentication/azure/auth_amqp.go new file mode 100644 index 0000000000..85d4db6460 --- /dev/null +++ b/authentication/azure/auth_amqp.go @@ -0,0 +1,25 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation and Dapr Contributors. +// Licensed under the MIT License. +// ------------------------------------------------------------ + +package azure + +import "github.com/Azure/azure-amqp-common-go/v3/aad" + +const ( + AzureServiceBusResourceName string = "servicebus" +) + +// GetTokenProvider creates a TokenProvider for AAD retrieved from, in order: +// 1. Client credentials +// 2. Client certificate +// 3. MSI. +func (s EnvironmentSettings) GetAADTokenProvider() (*aad.TokenProvider, error) { + spt, err := s.GetServicePrincipalToken() + if err != nil { + return nil, err + } + + return aad.NewJWTProvider(aad.JWTProviderWithAADToken(spt), aad.JWTProviderWithAzureEnvironment(s.AzureEnvironment)) +} diff --git a/bindings/azure/servicebusqueues/servicebusqueues.go b/bindings/azure/servicebusqueues/servicebusqueues.go index afd3eb04dc..04e0556beb 100644 --- a/bindings/azure/servicebusqueues/servicebusqueues.go +++ b/bindings/azure/servicebusqueues/servicebusqueues.go @@ -8,6 +8,7 @@ package servicebusqueues import ( "context" "encoding/json" + "errors" "strings" "sync/atomic" "time" @@ -15,6 +16,7 @@ import ( servicebus "github.com/Azure/azure-service-bus-go" "github.com/cenkalti/backoff/v4" + azauth "github.com/dapr/components-contrib/authentication/azure" "github.com/dapr/components-contrib/bindings" contrib_metadata "github.com/dapr/components-contrib/metadata" "github.com/dapr/kit/logger" @@ -43,6 +45,7 @@ type AzureServiceBusQueues struct { type serviceBusQueuesMetadata struct { ConnectionString string `json:"connectionString"` + NamespaceName string `json:"namespaceName,omitempty"` QueueName string `json:"queueName"` ttl time.Duration } @@ -61,10 +64,36 @@ func (a *AzureServiceBusQueues) Init(metadata bindings.Metadata) error { userAgent := "dapr-" + logger.DaprVersion a.metadata = meta - ns, err := servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString), - servicebus.NamespaceWithUserAgent(userAgent)) - if err != nil { - return err + var ns *servicebus.Namespace + if a.metadata.ConnectionString != "" { + ns, err = servicebus.NewNamespace(servicebus.NamespaceWithConnectionString(a.metadata.ConnectionString), + servicebus.NamespaceWithUserAgent(userAgent)) + if err != nil { + return err + } + } else { + // Initialization code + settings, sErr := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties) + if sErr != nil { + return sErr + } + + tokenProvider, tErr := settings.GetAADTokenProvider() + if tErr != nil { + return tErr + } + + ns, err = servicebus.NewNamespace(servicebus.NamespaceWithTokenProvider(tokenProvider), + servicebus.NamespaceWithUserAgent(userAgent)) + if err != nil { + return err + } + + // We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options + // pattern unless you allow it to recreate the entire environment which seems wasteful. + ns.Name = a.metadata.NamespaceName + ns.Environment = *settings.AzureEnvironment + ns.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix } a.ns = ns @@ -124,6 +153,10 @@ func (a *AzureServiceBusQueues) parseMetadata(metadata bindings.Metadata) (*serv return nil, err } + if m.ConnectionString != "" && m.NamespaceName != "" { + return nil, errors.New("connectionString and namespaceName are mutually exclusive") + } + ttl, ok, err := contrib_metadata.TryGetTTL(metadata.Properties) if err != nil { return nil, err diff --git a/bindings/azure/servicebusqueues/servicebusqueues_test.go b/bindings/azure/servicebusqueues/servicebusqueues_test.go index e1ca065d36..abdf3a34aa 100644 --- a/bindings/azure/servicebusqueues/servicebusqueues_test.go +++ b/bindings/azure/servicebusqueues/servicebusqueues_test.go @@ -93,3 +93,55 @@ func TestParseMetadataWithInvalidTTL(t *testing.T) { }) } } + +func TestParseMetadataConnectionStringAndNamespaceNameExclusivity(t *testing.T) { + testCases := []struct { + name string + properties map[string]string + expectedConnectionString string + expectedNamespaceName string + expectedQueueName string + expectedErr bool + }{ + { + name: "ConnectionString and queue name", + properties: map[string]string{"connectionString": "connString", "queueName": "queue1"}, + expectedConnectionString: "connString", + expectedNamespaceName: "", + expectedQueueName: "queue1", + expectedErr: false, + }, + { + name: "Empty TTL", + properties: map[string]string{"namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: ""}, + expectedConnectionString: "", + expectedNamespaceName: "testNamespace", + expectedQueueName: "queue1", + expectedErr: false, + }, + { + name: "With TTL", + properties: map[string]string{"connectionString": "connString", "namespaceName": "testNamespace", "queueName": "queue1", metadata.TTLMetadataKey: "1"}, + expectedConnectionString: "", + expectedNamespaceName: "", + expectedQueueName: "queue1", + expectedErr: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + m := bindings.Metadata{} + m.Properties = tt.properties + a := NewAzureServiceBusQueues(logger.NewLogger("test")) + meta, err := a.parseMetadata(m) + if tt.expectedErr { + assert.NotNil(t, err) + } else { + assert.Equal(t, tt.expectedConnectionString, meta.ConnectionString) + assert.Equal(t, tt.expectedQueueName, meta.QueueName) + assert.Equal(t, tt.expectedNamespaceName, meta.NamespaceName) + } + }) + } +} diff --git a/pubsub/azure/servicebus/metadata.go b/pubsub/azure/servicebus/metadata.go index 304df323f2..ab5093da2a 100644 --- a/pubsub/azure/servicebus/metadata.go +++ b/pubsub/azure/servicebus/metadata.go @@ -26,4 +26,5 @@ type metadata struct { PrefetchCount *int `json:"prefetchCount"` PublishMaxRetries int `json:"publishMaxRetries"` PublishInitialRetryIntervalInMs int `json:"publishInitialRetryInternalInMs"` + NamespaceName string `json:"namespaceName,omitempty"` } diff --git a/pubsub/azure/servicebus/servicebus.go b/pubsub/azure/servicebus/servicebus.go index fb30f9c299..33b389368d 100644 --- a/pubsub/azure/servicebus/servicebus.go +++ b/pubsub/azure/servicebus/servicebus.go @@ -18,6 +18,7 @@ import ( azservicebus "github.com/Azure/azure-service-bus-go" + azauth "github.com/dapr/components-contrib/authentication/azure" "github.com/dapr/components-contrib/pubsub" "github.com/dapr/kit/logger" "github.com/dapr/kit/retry" @@ -43,6 +44,7 @@ const ( connectionRecoveryInSec = "connectionRecoveryInSec" publishMaxRetries = "publishMaxRetries" publishInitialRetryInternalInMs = "publishInitialRetryInternalInMs" + namespaceName = "namespaceName" errorMessagePrefix = "azure service bus error:" // Defaults. @@ -93,8 +95,15 @@ func parseAzureServiceBusMetadata(meta pubsub.Metadata) (metadata, error) { /* Required configuration settings - no defaults. */ if val, ok := meta.Properties[connectionString]; ok && val != "" { m.ConnectionString = val + + // The connection string and the namespace cannot both be present. + if namespace, present := meta.Properties[namespaceName]; present && namespace != "" { + return m, fmt.Errorf("%s connectionString and namespaceName cannot both be specified", errorMessagePrefix) + } + } else if val, ok := meta.Properties[namespaceName]; ok && val != "" { + m.NamespaceName = val } else { - return m, fmt.Errorf("%s missing connection string", errorMessagePrefix) + return m, fmt.Errorf("%s missing connection string and namespace name", errorMessagePrefix) } if val, ok := meta.Properties[consumerID]; ok && val != "" { @@ -258,12 +267,37 @@ func (a *azureServiceBus) Init(metadata pubsub.Metadata) error { userAgent := "dapr-" + logger.DaprVersion a.metadata = m - a.namespace, err = azservicebus.NewNamespace( - azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString), - azservicebus.NamespaceWithUserAgent(userAgent)) + if a.metadata.ConnectionString != "" { + a.namespace, err = azservicebus.NewNamespace( + azservicebus.NamespaceWithConnectionString(a.metadata.ConnectionString), + azservicebus.NamespaceWithUserAgent(userAgent)) - if err != nil { - return err + if err != nil { + return err + } + } else { + // Initialization code + settings, err := azauth.NewEnvironmentSettings(azauth.AzureServiceBusResourceName, metadata.Properties) + if err != nil { + return err + } + + tokenProvider, err := settings.GetAADTokenProvider() + if err != nil { + return err + } + + a.namespace, err = azservicebus.NewNamespace(azservicebus.NamespaceWithTokenProvider(tokenProvider), + azservicebus.NamespaceWithUserAgent(userAgent)) + if err != nil { + return err + } + + // We set these separately as the ServiceBus SDK does not provide a way to pass the environment via the options + // pattern unless you allow it to recreate the entire environment which seems wasteful. + a.namespace.Name = a.metadata.NamespaceName + a.namespace.Environment = *settings.AzureEnvironment + a.namespace.Suffix = settings.AzureEnvironment.ServiceBusEndpointSuffix } a.topicManager = a.namespace.NewTopicManager() diff --git a/pubsub/azure/servicebus/servicebus_test.go b/pubsub/azure/servicebus/servicebus_test.go index cac7f6bb84..6fb082672e 100644 --- a/pubsub/azure/servicebus/servicebus_test.go +++ b/pubsub/azure/servicebus/servicebus_test.go @@ -20,6 +20,7 @@ const ( func getFakeProperties() map[string]string { return map[string]string{ connectionString: "fakeConnectionString", + namespaceName: "", consumerID: "fakeConId", disableEntityManagement: "true", timeoutInSec: "90", @@ -82,13 +83,14 @@ func TestParseServiceBusMetadata(t *testing.T) { assert.Equal(t, 10, *m.PrefetchCount) }) - t.Run("missing required connectionString", func(t *testing.T) { + t.Run("missing required connectionString and namespaceName", func(t *testing.T) { fakeProperties := getFakeProperties() fakeMetaData := pubsub.Metadata{ Properties: fakeProperties, } fakeMetaData.Properties[connectionString] = "" + fakeMetaData.Properties[namespaceName] = "" // act. m, err := parseAzureServiceBusMetadata(fakeMetaData) @@ -99,6 +101,56 @@ func TestParseServiceBusMetadata(t *testing.T) { assert.Empty(t, m.ConnectionString) }) + t.Run("connectionString makes namespace optional", func(t *testing.T) { + fakeProperties := getFakeProperties() + + fakeMetaData := pubsub.Metadata{ + Properties: fakeProperties, + } + fakeMetaData.Properties[namespaceName] = "" + + // act. + m, err := parseAzureServiceBusMetadata(fakeMetaData) + + // assert. + assert.NoError(t, err) + assert.Equal(t, "fakeConnectionString", m.ConnectionString) + }) + + t.Run("namespace makes conectionString optional", func(t *testing.T) { + fakeProperties := getFakeProperties() + + fakeMetaData := pubsub.Metadata{ + Properties: fakeProperties, + } + fakeMetaData.Properties[namespaceName] = "fakeNamespace" + fakeMetaData.Properties[connectionString] = "" + + // act. + m, err := parseAzureServiceBusMetadata(fakeMetaData) + + // assert. + assert.NoError(t, err) + assert.Equal(t, "fakeNamespace", m.NamespaceName) + }) + + t.Run("connectionString and namespace are mutually exclusive", func(t *testing.T) { + fakeProperties := getFakeProperties() + + fakeMetaData := pubsub.Metadata{ + Properties: fakeProperties, + } + + fakeMetaData.Properties[namespaceName] = "fakeNamespace" + + // act. + _, err := parseAzureServiceBusMetadata(fakeMetaData) + + // assert. + assert.Error(t, err) + assertValidErrorMessage(t, err) + }) + t.Run("missing required consumerID", func(t *testing.T) { fakeProperties := getFakeProperties()