Skip to content

Commit

Permalink
Allow configuring a custom marshaller (#10)
Browse files Browse the repository at this point in the history
* Allow configuring a custom marshaller

Fixes #9

* Add marshalling unit test.
  • Loading branch information
rubenv committed May 9, 2024
1 parent 0c5e554 commit d204388
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 20 deletions.
13 changes: 11 additions & 2 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ type amqpChannel struct {

// connectionType defines the connectionType.
connectionType connectionType

// marshaller defines the marshalling method used to encode messages.
marshaller Marshaller
}

// newConsumerChannel instantiates a new consumerChannel and amqpChannel for method inheritance.
Expand All @@ -87,13 +90,15 @@ type amqpChannel struct {
// - consumer is the MessageConsumer that will hold consumption information.
// - maxRetry is the retry header for each message.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newConsumerChannel(
ctx context.Context,
connection *amqp.Connection,
keepAlive bool,
retryDelay time.Duration,
consumer *MessageConsumer,
logger logger,
marshaller Marshaller,
) *amqpChannel {
channel := &amqpChannel{
ctx: ctx,
Expand All @@ -119,6 +124,7 @@ func newConsumerChannel(
connectionType: connectionTypeConsumer,
consumptionHealth: make(consumptionHealth),
consumer: consumer,
marshaller: marshaller,
}

// We open an initial channel.
Expand All @@ -141,6 +147,7 @@ func newConsumerChannel(
// - publishingCacheSize is the maximum cache size of failed publishing.
// - publishingCacheTTL defines the time to live for each failed publishing that was put in cache.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newPublishingChannel(
ctx context.Context,
connection *amqp.Connection,
Expand All @@ -150,6 +157,7 @@ func newPublishingChannel(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *amqpChannel {
channel := &amqpChannel{
ctx: ctx,
Expand All @@ -171,6 +179,7 @@ func newPublishingChannel(
connectionType: connectionTypePublisher,
publishingCache: newTTLMap[string, mqttPublishing](publishingCacheSize, publishingCacheTTL),
maxRetry: maxRetry,
marshaller: marshaller,
}

// We open an initial channel.
Expand Down Expand Up @@ -521,7 +530,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged

// We create a new publishing which is a copy of the old one but with a decremented xDeathCountHeader.
newPublishing := amqp.Publishing{
ContentType: "application/json",
ContentType: c.marshaller.ContentType(),
Body: delivery.Body,
Type: delivery.RoutingKey,
Priority: delivery.Priority,
Expand Down Expand Up @@ -554,7 +563,7 @@ func (c *amqpChannel) retryDelivery(delivery *amqp.Delivery, alreadyAcknowledged
// publish will publish a message with the given configuration.
func (c *amqpChannel) publish(exchange string, routingKey string, payload []byte, options *PublishingOptions) error {
publishing := &amqp.Publishing{
ContentType: "application/json",
ContentType: c.marshaller.ContentType(),
Body: payload,
Type: routingKey,
Priority: PriorityMedium.Uint8(),
Expand Down
5 changes: 5 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ func newClientFromOptions(options *ClientOptions) MQTTClient {
protocol = securedProtocol
}

if options.Marshaller == nil {
options.Marshaller = defaultMarshaller
}

dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, client.Username, client.Password, client.Host, client.Port, client.Vhost)

client.connectionManager = newConnectionManager(
Expand All @@ -163,6 +167,7 @@ func newClientFromOptions(options *ClientOptions) MQTTClient {
options.PublishingCacheSize,
options.PublishingCacheTTL,
client.logger,
options.Marshaller,
)

return client
Expand Down
11 changes: 11 additions & 0 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ type ClientOptions struct {

// Mode will specify whether logs are enabled or not.
Mode string

// Marshaller defines the content type used for messages and how they're marshalled (default: JSON).
Marshaller Marshaller
}

// DefaultClientOptions will return a ClientOptions with default values.
Expand All @@ -63,6 +66,7 @@ func DefaultClientOptions() *ClientOptions {
PublishingCacheTTL: defaultPublishingCacheTTL,
PublishingCacheSize: defaultPublishingCacheSize,
Mode: defaultMode,
Marshaller: defaultMarshaller,
}
}

Expand Down Expand Up @@ -195,3 +199,10 @@ func (c *ClientOptions) SetMode(mode string) *ClientOptions {

return c
}

// SetMarshaller will assign the Marshaller.
func (c *ClientOptions) SetMarshaller(marshaller Marshaller) *ClientOptions {
c.Marshaller = marshaller

return c
}
30 changes: 25 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type amqpConnection struct {

// connectionType defines the connectionType.
connectionType connectionType

// marshaller defines the marshalling method used to encode messages.
marshaller Marshaller
}

// newConsumerConnection initializes a new consumer amqpConnection with given arguments.
Expand All @@ -57,8 +60,17 @@ type amqpConnection struct {
// - keepAlive will keep the connection alive if true.
// - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true.
// - logger is the parent logger.
func newConsumerConnection(ctx context.Context, uri, connectionName string, keepAlive bool, retryDelay time.Duration, logger logger) *amqpConnection {
return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer)
// - marshaller is the Marshaller used for encoding messages.
func newConsumerConnection(
ctx context.Context,
uri string,
connectionName string,
keepAlive bool,
retryDelay time.Duration,
logger logger,
marshaller Marshaller,
) *amqpConnection {
return newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypeConsumer, marshaller)
}

// newPublishingConnection initializes a new publisher amqpConnection with given arguments.
Expand All @@ -71,6 +83,7 @@ func newConsumerConnection(ctx context.Context, uri, connectionName string, keep
// - publishingCacheSize defines the maximum length of failed publishing cache.
// - publishingCacheTTL defines the time to live for failed publishing in cache.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newPublishingConnection(
ctx context.Context,
uri string,
Expand All @@ -81,8 +94,9 @@ func newPublishingConnection(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *amqpConnection {
conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher)
conn := newConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger, connectionTypePublisher, marshaller)

conn.maxRetry = maxRetry
conn.publishingCacheSize = publishingCacheSize
Expand All @@ -98,6 +112,7 @@ func newPublishingConnection(
// - keepAlive will keep the connection alive if true.
// - retryDelay defines the delay between each re-connection, if the keepAlive flag is set to true.
// - logger is the parent logger.
// - marshaller is the Marshaller used for encoding messages.
func newConnection(
ctx context.Context,
uri string,
Expand All @@ -106,6 +121,7 @@ func newConnection(
retryDelay time.Duration,
logger logger,
connectionType connectionType,
marshaller Marshaller,
) *amqpConnection {
conn := &amqpConnection{
ctx: ctx,
Expand All @@ -119,6 +135,7 @@ func newConnection(
"type": connectionType,
}),
connectionType: connectionType,
marshaller: marshaller,
}

conn.logger.Debug("Initializing new amqp connection", logField{Key: "uri", Value: conn.uriForLog()})
Expand Down Expand Up @@ -303,7 +320,7 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error {
return err
}

channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger)
channel := newConsumerChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, &consumer, a.logger, a.marshaller)

a.channels = append(a.channels, channel)

Expand All @@ -315,7 +332,10 @@ func (a *amqpConnection) registerConsumer(consumer MessageConsumer) error {
func (a *amqpConnection) publish(exchange, routingKey string, payload []byte, options *PublishingOptions) error {
publishingChannel := a.channels.publishingChannel()
if publishingChannel == nil {
publishingChannel = newPublishingChannel(a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry, a.publishingCacheSize, a.publishingCacheTTL, a.logger)
publishingChannel = newPublishingChannel(
a.ctx, a.connection, a.keepAlive, a.retryDelay, a.maxRetry,
a.publishingCacheSize, a.publishingCacheTTL, a.logger, a.marshaller,
)

a.channels = append(a.channels, publishingChannel)
}
Expand Down
17 changes: 13 additions & 4 deletions connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gorabbit

import (
"context"
"encoding/json"
"time"
)

Expand All @@ -12,6 +11,9 @@ type connectionManager struct {

// publisherConnection holds the independent publishing connection.
publisherConnection *amqpConnection

// marshaller holds the marshaller used to encode messages.
marshaller Marshaller
}

// newConnectionManager instantiates a new connectionManager with given arguments.
Expand All @@ -25,10 +27,17 @@ func newConnectionManager(
publishingCacheSize uint64,
publishingCacheTTL time.Duration,
logger logger,
marshaller Marshaller,
) *connectionManager {
c := &connectionManager{
consumerConnection: newConsumerConnection(ctx, uri, connectionName, keepAlive, retryDelay, logger),
publisherConnection: newPublishingConnection(ctx, uri, connectionName, keepAlive, retryDelay, maxRetry, publishingCacheSize, publishingCacheTTL, logger),
consumerConnection: newConsumerConnection(
ctx, uri, connectionName, keepAlive, retryDelay, logger, marshaller,
),
publisherConnection: newPublishingConnection(
ctx, uri, connectionName, keepAlive, retryDelay, maxRetry,
publishingCacheSize, publishingCacheTTL, logger, marshaller,
),
marshaller: marshaller,
}

return c
Expand Down Expand Up @@ -75,7 +84,7 @@ func (c *connectionManager) publish(exchange, routingKey string, payload interfa
return errPublisherConnectionNotInitialized
}

payloadBytes, err := json.Marshal(payload)
payloadBytes, err := c.marshaller.Marshal(payload)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ const (
defaultMode = Release
)

var defaultMarshaller = NewJSONMarshaller()

// Default values for the amqp Config.
const (
defaultHeartbeat = 10 * time.Second
Expand Down
12 changes: 10 additions & 2 deletions manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ type mqttManager struct {

// channel holds the single channel from the connection.
channel *amqp.Channel

// marshaller holds the marshaller used to encode messages.
marshaller Marshaller
}

// NewManager will instantiate a new MQTTManager.
Expand Down Expand Up @@ -158,6 +161,11 @@ func newManagerFromOptions(options *ManagerOptions) (MQTTManager, error) {
protocol = securedProtocol
}

if options.Marshaller == nil {
options.Marshaller = defaultMarshaller
}
manager.marshaller = options.Marshaller

dialURL := fmt.Sprintf("%s://%s:%s@%s:%d/%s", protocol, manager.Username, manager.Password, manager.Host, manager.Port, manager.Vhost)

var err error
Expand Down Expand Up @@ -320,14 +328,14 @@ func (manager *mqttManager) PushMessageToExchange(exchange, routingKey string, p
}

// We convert the payload to a []byte.
payloadBytes, err := json.Marshal(payload)
payloadBytes, err := manager.marshaller.Marshal(payload)
if err != nil {
return err
}

// We build the amqp.Publishing object.
publishing := amqp.Publishing{
ContentType: "application/json",
ContentType: manager.marshaller.ContentType(),
Body: payloadBytes,
Type: routingKey,
Priority: PriorityMedium.Uint8(),
Expand Down
25 changes: 18 additions & 7 deletions manager_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ type ManagerOptions struct {

// Mode will specify whether logs are enabled or not.
Mode string

// Marshaller defines the content type used for messages and how they're marshalled (default: JSON).
Marshaller Marshaller
}

// DefaultManagerOptions will return a ManagerOptions with default values.
func DefaultManagerOptions() *ManagerOptions {
return &ManagerOptions{
Host: defaultHost,
Port: defaultPort,
Username: defaultUsername,
Password: defaultPassword,
Vhost: defaultVhost,
UseTLS: defaultUseTLS,
Mode: defaultMode,
Host: defaultHost,
Port: defaultPort,
Username: defaultUsername,
Password: defaultPassword,
Vhost: defaultVhost,
UseTLS: defaultUseTLS,
Mode: defaultMode,
Marshaller: defaultMarshaller,
}
}

Expand Down Expand Up @@ -126,3 +130,10 @@ func (m *ManagerOptions) SetMode(mode string) *ManagerOptions {

return m
}

// SetMarshaller will assign the Marshaller.
func (m *ManagerOptions) SetMarshaller(marshaller Marshaller) *ManagerOptions {
m.Marshaller = marshaller

return m
}
Loading

0 comments on commit d204388

Please sign in to comment.