From d204388a87fb2a8bfc2c103080a66a2223393b44 Mon Sep 17 00:00:00 2001 From: Ruben Vermeersch Date: Thu, 9 May 2024 17:35:53 +0200 Subject: [PATCH] Allow configuring a custom marshaller (#10) * Allow configuring a custom marshaller Fixes #9 * Add marshalling unit test. --- channel.go | 13 ++++++++++-- client.go | 5 +++++ client_options.go | 11 ++++++++++ connection.go | 30 ++++++++++++++++++++++----- connection_manager.go | 17 ++++++++++++---- constants.go | 2 ++ manager.go | 12 +++++++++-- manager_options.go | 25 ++++++++++++++++------- marshalling.go | 47 +++++++++++++++++++++++++++++++++++++++++++ marshalling_test.go | 31 ++++++++++++++++++++++++++++ 10 files changed, 173 insertions(+), 20 deletions(-) create mode 100644 marshalling.go create mode 100644 marshalling_test.go diff --git a/channel.go b/channel.go index 8cb4d8b..b5732f9 100644 --- a/channel.go +++ b/channel.go @@ -77,6 +77,9 @@ type amqpChannel struct { // connectionType defines the connectionType. connectionType connectionType + + // marshaller defines the marshalling method used to encode messages. + marshaller Marshaller } // newConsumerChannel instantiates a new consumerChannel and amqpChannel for method inheritance. @@ -87,6 +90,7 @@ type amqpChannel struct { // - consumer is the MessageConsumer that will hold consumption information. // - maxRetry is the retry header for each message. // - logger is the parent logger. +// - marshaller is the Marshaller used for encoding messages. func newConsumerChannel( ctx context.Context, connection *amqp.Connection, @@ -94,6 +98,7 @@ func newConsumerChannel( retryDelay time.Duration, consumer *MessageConsumer, logger logger, + marshaller Marshaller, ) *amqpChannel { channel := &amqpChannel{ ctx: ctx, @@ -119,6 +124,7 @@ func newConsumerChannel( connectionType: connectionTypeConsumer, consumptionHealth: make(consumptionHealth), consumer: consumer, + marshaller: marshaller, } // We open an initial channel. @@ -141,6 +147,7 @@ func newConsumerChannel( // - publishingCacheSize is the maximum cache size of failed publishing. // - publishingCacheTTL defines the time to live for each failed publishing that was put in cache. // - logger is the parent logger. +// - marshaller is the Marshaller used for encoding messages. func newPublishingChannel( ctx context.Context, connection *amqp.Connection, @@ -150,6 +157,7 @@ func newPublishingChannel( publishingCacheSize uint64, publishingCacheTTL time.Duration, logger logger, + marshaller Marshaller, ) *amqpChannel { channel := &amqpChannel{ ctx: ctx, @@ -171,6 +179,7 @@ func newPublishingChannel( connectionType: connectionTypePublisher, publishingCache: newTTLMap[string, mqttPublishing](publishingCacheSize, publishingCacheTTL), maxRetry: maxRetry, + marshaller: marshaller, } // We open an initial channel. @@ -521,7 +530,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged // We create a new publishing which is a copy of the old one but with a decremented xDeathCountHeader. newPublishing := amqp.Publishing{ - ContentType: "application/json", + ContentType: c.marshaller.ContentType(), Body: delivery.Body, Type: delivery.RoutingKey, Priority: delivery.Priority, @@ -554,7 +563,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged // publish will publish a message with the given configuration. func (c *amqpChannel) publish(exchange string, routingKey string, payload []byte, options *PublishingOptions) error { publishing := &amqp.Publishing{ - ContentType: "application/json", + ContentType: c.marshaller.ContentType(), Body: payload, Type: routingKey, Priority: PriorityMedium.Uint8(), diff --git a/client.go b/client.go index 8a0baef..58408b9 100644 --- a/client.go +++ b/client.go @@ -151,6 +151,10 @@ func newClientFromOptions(options *ClientOptions) MQTTClient { protocol = securedProtocol } + if options.Marshaller == nil { + options.Marshaller = defaultMarshaller + } + dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, client.Username, client.Password, client.Host, client.Port, client.Vhost) client.connectionManager = newConnectionManager( @@ -163,6 +167,7 @@ func newClientFromOptions(options *ClientOptions) MQTTClient { options.PublishingCacheSize, options.PublishingCacheTTL, client.logger, + options.Marshaller, ) return client diff --git a/client_options.go b/client_options.go index 7b67544..ba8ce10 100644 --- a/client_options.go +++ b/client_options.go @@ -46,6 +46,9 @@ type ClientOptions struct { // Mode will specify whether logs are enabled or not. Mode string + + // Marshaller defines the content type used for messages and how they're marshalled (default: JSON). + Marshaller Marshaller } // DefaultClientOptions will return a ClientOptions with default values. @@ -63,6 +66,7 @@ func DefaultClientOptions() *ClientOptions { PublishingCacheTTL: defaultPublishingCacheTTL, PublishingCacheSize: defaultPublishingCacheSize, Mode: defaultMode, + Marshaller: defaultMarshaller, } } @@ -195,3 +199,10 @@ func (c *ClientOptions) SetMode(mode string) *ClientOptions { return c } + +// SetMarshaller will assign the Marshaller. +func (c *ClientOptions) SetMarshaller(marshaller Marshaller) *ClientOptions { + c.Marshaller = marshaller + + return c +} diff --git a/connection.go b/connection.go index 626d45c..f7c1edd 100644 --- a/connection.go +++ b/connection.go @@ -48,6 +48,9 @@ type amqpConnection struct { // connectionType defines the connectionType. connectionType connectionType + + // marshaller defines the marshalling method used to encode messages. + marshaller Marshaller } // newConsumerConnection initializes a new consumer amqpConnection with given arguments. @@ -57,8 +60,17 @@ type amqpConnection struct { // - keepAlive will keep the connection alive if true. // - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true. // - logger is the parent logger. -func newConsumerConnection(ctx context.Context, uri, connectionName string, keepAlive bool, retryDelay time.Duration, logger logger) *amqpConnection { - return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer) +// - marshaller is the Marshaller used for encoding messages. +func newConsumerConnection( + ctx context.Context, + uri string, + connectionName string, + keepAlive bool, + retryDelay time.Duration, + logger logger, + marshaller Marshaller, +) *amqpConnection { + return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer, marshaller) } // newPublishingConnection initializes a new publisher amqpConnection with given arguments. @@ -71,6 +83,7 @@ func newConsumerConnection(ctx context.Context, uri, connectionName string, keep // - publishingCacheSize defines the maximum length of failed publishing cache. // - publishingCacheTTL defines the time to live for failed publishing in cache. // - logger is the parent logger. +// - marshaller is the Marshaller used for encoding messages. func newPublishingConnection( ctx context.Context, uri string, @@ -81,8 +94,9 @@ func newPublishingConnection( publishingCacheSize uint64, publishingCacheTTL time.Duration, logger logger, + marshaller Marshaller, ) *amqpConnection { - conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher) + conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher, marshaller) conn.maxRetry = maxRetry conn.publishingCacheSize = publishingCacheSize @@ -98,6 +112,7 @@ func newPublishingConnection( // - keepAlive will keep the connection alive if true. // - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true. // - logger is the parent logger. +// - marshaller is the Marshaller used for encoding messages. func newConnection( ctx context.Context, uri string, @@ -106,6 +121,7 @@ func newConnection( retryDelay time.Duration, logger logger, connectionType connectionType, + marshaller Marshaller, ) *amqpConnection { conn := &amqpConnection{ ctx: ctx, @@ -119,6 +135,7 @@ func newConnection( "type": connectionType, }), connectionType: connectionType, + marshaller: marshaller, } conn.logger.Debug("Initializing new amqp connection", logField{Key: "uri", Value: conn.uriForLog()}) @@ -303,7 +320,7 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error { return err } - channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger) + channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger, a.marshaller) a.channels = append(a.channels, channel) @@ -315,7 +332,10 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error { func (a *amqpConnection) publish(exchange, routingKey string, payload []byte, options *PublishingOptions) error { publishingChannel := a.channels.publishingChannel() if publishingChannel == nil { - publishingChannel = newPublishingChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry, a.publishingCacheSize, a.publishingCacheTTL, a.logger) + publishingChannel = newPublishingChannel( + a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry, + a.publishingCacheSize, a.publishingCacheTTL, a.logger, a.marshaller, + ) a.channels = append(a.channels, publishingChannel) } diff --git a/connection_manager.go b/connection_manager.go index bd71413..cfc7b71 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -2,7 +2,6 @@ package gorabbit import ( "context" - "encoding/json" "time" ) @@ -12,6 +11,9 @@ type connectionManager struct { // publisherConnection holds the independent publishing connection. publisherConnection *amqpConnection + + // marshaller holds the marshaller used to encode messages. + marshaller Marshaller } // newConnectionManager instantiates a new connectionManager with given arguments. @@ -25,10 +27,17 @@ func newConnectionManager( publishingCacheSize uint64, publishingCacheTTL time.Duration, logger logger, + marshaller Marshaller, ) *connectionManager { c := &connectionManager{ - consumerConnection: newConsumerConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger), - publisherConnection: newPublishingConnection(ctx, uri, connectionName, keepAlive, retryDelay, maxRetry, publishingCacheSize, publishingCacheTTL, logger), + consumerConnection: newConsumerConnection( + ctx, uri, connectionName, keepAlive, retryDelay, logger, marshaller, + ), + publisherConnection: newPublishingConnection( + ctx, uri, connectionName, keepAlive, retryDelay, maxRetry, + publishingCacheSize, publishingCacheTTL, logger, marshaller, + ), + marshaller: marshaller, } return c @@ -75,7 +84,7 @@ func (c *connectionManager) publish(exchange, routingKey string, payload interfa return errPublisherConnectionNotInitialized } - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := c.marshaller.Marshal(payload) if err != nil { return err } diff --git a/constants.go b/constants.go index 95861cd..8b19cd4 100644 --- a/constants.go +++ b/constants.go @@ -30,6 +30,8 @@ const ( defaultMode = Release ) +var defaultMarshaller = NewJSONMarshaller() + // Default values for the amqp Config. const ( defaultHeartbeat = 10 * time.Second diff --git a/manager.go b/manager.go index 375f7b4..54e9a8e 100644 --- a/manager.go +++ b/manager.go @@ -101,6 +101,9 @@ type mqttManager struct { // channel holds the single channel from the connection. channel *amqp.Channel + + // marshaller holds the marshaller used to encode messages. + marshaller Marshaller } // NewManager will instantiate a new MQTTManager. @@ -158,6 +161,11 @@ func newManagerFromOptions(options *ManagerOptions) (MQTTManager, error) { protocol = securedProtocol } + if options.Marshaller == nil { + options.Marshaller = defaultMarshaller + } + manager.marshaller = options.Marshaller + dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, manager.Username, manager.Password, manager.Host, manager.Port, manager.Vhost) var err error @@ -320,14 +328,14 @@ func (manager *mqttManager) PushMessageToExchange(exchange, routingKey string, p } // We convert the payload to a []byte. - payloadBytes, err := json.Marshal(payload) + payloadBytes, err := manager.marshaller.Marshal(payload) if err != nil { return err } // We build the amqp.Publishing object. publishing := amqp.Publishing{ - ContentType: "application/json", + ContentType: manager.marshaller.ContentType(), Body: payloadBytes, Type: routingKey, Priority: PriorityMedium.Uint8(), diff --git a/manager_options.go b/manager_options.go index 79f4bd0..7fc10cb 100644 --- a/manager_options.go +++ b/manager_options.go @@ -24,18 +24,22 @@ type ManagerOptions struct { // Mode will specify whether logs are enabled or not. Mode string + + // Marshaller defines the content type used for messages and how they're marshalled (default: JSON). + Marshaller Marshaller } // DefaultManagerOptions will return a ManagerOptions with default values. func DefaultManagerOptions() *ManagerOptions { return &ManagerOptions{ - Host: defaultHost, - Port: defaultPort, - Username: defaultUsername, - Password: defaultPassword, - Vhost: defaultVhost, - UseTLS: defaultUseTLS, - Mode: defaultMode, + Host: defaultHost, + Port: defaultPort, + Username: defaultUsername, + Password: defaultPassword, + Vhost: defaultVhost, + UseTLS: defaultUseTLS, + Mode: defaultMode, + Marshaller: defaultMarshaller, } } @@ -126,3 +130,10 @@ func (m *ManagerOptions) SetMode(mode string) *ManagerOptions { return m } + +// SetMarshaller will assign the Marshaller. +func (m *ManagerOptions) SetMarshaller(marshaller Marshaller) *ManagerOptions { + m.Marshaller = marshaller + + return m +} diff --git a/marshalling.go b/marshalling.go new file mode 100644 index 0000000..15bae3b --- /dev/null +++ b/marshalling.go @@ -0,0 +1,47 @@ +package gorabbit + +import ( + "encoding/json" + "fmt" +) + +type Marshaller interface { + ContentType() string + Marshal(data any) ([]byte, error) +} + +type marshaller struct { + contentType string + marshal func(data any) ([]byte, error) +} + +func (m *marshaller) ContentType() string { + return m.contentType +} + +func (m *marshaller) Marshal(data any) ([]byte, error) { + return m.marshal(data) +} + +func NewJSONMarshaller() Marshaller { + return &marshaller{ + contentType: "application/json", + marshal: json.Marshal, + } +} + +func NewTextMarshaller() Marshaller { + return &marshaller{ + contentType: "text/plain", + marshal: func(data any) ([]byte, error) { + switch s := data.(type) { + case string: + return []byte(s), nil + case []byte: + return s, nil + default: + return nil, fmt.Errorf("cannot marshal %T as text", data) + } + }, + } +} diff --git a/marshalling_test.go b/marshalling_test.go new file mode 100644 index 0000000..497a5dd --- /dev/null +++ b/marshalling_test.go @@ -0,0 +1,31 @@ +package gorabbit_test + +import ( + "testing" + + "github.com/KardinalAI/gorabbit" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONMarshaller(t *testing.T) { + m := gorabbit.NewJSONMarshaller() + assert.NotNil(t, m) + + assert.Equal(t, "application/json", m.ContentType()) + + data, err := m.Marshal("test") + require.NoError(t, err) + assert.Equal(t, []byte(`"test"`), data) +} + +func TestTextMarshaller(t *testing.T) { + m := gorabbit.NewTextMarshaller() + assert.NotNil(t, m) + + assert.Equal(t, "text/plain", m.ContentType()) + + data, err := m.Marshal("test") + require.NoError(t, err) + assert.Equal(t, []byte(`test`), data) +}