/
streaming_batch.go
109 lines (84 loc) · 2.83 KB
/
streaming_batch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package shared
import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/messaging/azservicebus"
)
// internalBatchSender is an interface for an *azservicebus.Sender
type internalBatchSender interface {
SendMessageBatch(ctx context.Context, batch internalBatch) error
NewMessageBatch(ctx context.Context, options *azservicebus.MessageBatchOptions) (internalBatch, error)
}
// internalBatch is the part of a message batch that StreamingMessageBatch
// uses: adding a message and reporting how many messages are held.
type internalBatch interface {
	// AddMessage tries to append m to the batch. StreamingMessageBatch.Add
	// treats azservicebus.ErrMessageTooLarge from this call as "batch is full".
	AddMessage(m *azservicebus.Message, options *azservicebus.AddMessageOptions) error

	// NumMessages reports the number of messages currently in the batch.
	NumMessages() int32
}
// senderWrapper adapts a *TrackingSender to the internalBatchSender
// interface by forwarding both methods to the wrapped sender.
type senderWrapper struct {
	// inner is the sender every call is delegated to.
	inner *TrackingSender
}
// SendMessageBatch forwards batch to the wrapped TrackingSender, passing nil
// send options.
//
// The wrapped sender needs a concrete *azservicebus.MessageBatch, so batch
// must hold that type. A nil batch or any other implementation of
// internalBatch is reported as an error instead of panicking on the type
// assertion.
func (sw *senderWrapper) SendMessageBatch(ctx context.Context, batch internalBatch) error {
	mb, ok := batch.(*azservicebus.MessageBatch)
	if !ok {
		return errors.New("senderWrapper: batch is not an *azservicebus.MessageBatch")
	}

	return sw.inner.SendMessageBatch(ctx, mb, nil)
}
// NewMessageBatch asks the wrapped TrackingSender for a new batch and returns
// it as an internalBatch.
//
// On failure a literal nil interface is returned alongside the error.
// NOTE(review): this presumes the wrapped call returns a concrete pointer
// type; returning that value directly would wrap a nil pointer in a non-nil
// interface on the error path. Confirm against TrackingSender's signature.
func (sw *senderWrapper) NewMessageBatch(ctx context.Context, options *azservicebus.MessageBatchOptions) (internalBatch, error) {
	batch, err := sw.inner.NewMessageBatch(ctx, options)
	if err != nil {
		return nil, err
	}

	return batch, nil
}
// NewStreamingMessageBatch builds a StreamingMessageBatch that sends through
// sender. It allocates the first batch up front (with nil batch options) and
// returns the allocation error, if any, without creating the streamer.
//
// expectedTotal is stored and only used in the progress log line emitted by
// Add.
func NewStreamingMessageBatch(ctx context.Context, sender internalBatchSender, expectedTotal int) (*StreamingMessageBatch, error) {
	first, err := sender.NewMessageBatch(ctx, nil)
	if err != nil {
		return nil, err
	}

	streamer := &StreamingMessageBatch{
		sender:        sender,
		expectedTotal: expectedTotal,
		currentBatch:  first,
	}

	return streamer, nil
}
// StreamingMessageBatch accumulates messages into a batch and sends the batch
// whenever it fills up, so callers can add an arbitrary number of messages
// through Add and flush the remainder with Close.
type StreamingMessageBatch struct {
	// sender creates and sends batches.
	sender internalBatchSender
	// currentBatch is the batch new messages are added to.
	currentBatch internalBatch
	// expectedTotal is the caller-supplied message count shown in progress logs.
	expectedTotal int
	// total is the running count of messages in batches sent so far.
	total int
}
// Add appends msg to the current batch. If the batch rejects it with
// azservicebus.ErrMessageTooLarge (treated here as "the batch is full"), the
// current batch is sent, replaced by a fresh one, and msg is added to the
// fresh batch using the same options.
//
// Errors returned:
//   - any AddMessage error other than ErrMessageTooLarge, unchanged;
//   - ErrMessageTooLarge itself when the current batch is empty, or when msg
//     does not fit into the fresh batch either (msg cannot fit on its own);
//   - errors from allocating the replacement batch or sending the full one;
//   - ctx.Err() if ctx is done during the one-second throttle pause.
//
// When an error is returned msg has not been added. The streamer never keeps
// a batch that was already sent, so a later Add or Close will not hand the
// same batch to the sender twice.
func (sb *StreamingMessageBatch) Add(ctx context.Context, msg *azservicebus.Message, options *azservicebus.AddMessageOptions) error {
	err := sb.currentBatch.AddMessage(msg, options)
	if err == nil {
		return nil
	}

	if !errors.Is(err, azservicebus.ErrMessageTooLarge) {
		// must be a fatal error, just give up.
		return err
	}

	pending := sb.currentBatch.NumMessages()
	if pending == 0 {
		// Nothing to flush: the message was rejected by an empty batch, so
		// sending and retrying against another empty batch cannot help.
		return err
	}

	// Allocate the replacement before sending so that, once the send
	// succeeds, the sent batch can be swapped out immediately. If allocation
	// fails the unsent batch is still intact and the caller may retry.
	next, err := sb.sender.NewMessageBatch(ctx, nil)
	if err != nil {
		return err
	}

	log.Printf("Sending message batch with %d messages. %d/%d messages sent so far.", pending, sb.total, sb.expectedTotal)

	if err := sb.sender.SendMessageBatch(ctx, sb.currentBatch); err != nil {
		return err
	}

	sb.total += int(pending)
	sb.currentBatch = next

	// throttle a teeny bit, but stop waiting as soon as ctx is done.
	timer := time.NewTimer(time.Second)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
	}

	// if we can't add this message here (ie, by itself) into the batch then
	// we'll just error out.
	return sb.currentBatch.AddMessage(msg, options)
}
// Close sends any messages currently held in the batch. It is a no-op when
// the batch is empty. After a successful send the messages are added to the
// running total, matching the bookkeeping done by Add.
//
// Close does not replace the current batch, so calling it again afterwards
// hands the same batch to the sender a second time.
func (sb *StreamingMessageBatch) Close(ctx context.Context) error {
	pending := int(sb.currentBatch.NumMessages())
	if pending == 0 {
		return nil
	}

	log.Printf("Sending final message batch")

	if err := sb.sender.SendMessageBatch(ctx, sb.currentBatch); err != nil {
		return err
	}

	sb.total += pending
	return nil
}