-
Notifications
You must be signed in to change notification settings - Fork 648
/
AmqpSourceStage.scala
181 lines (150 loc) · 6.71 KB
/
AmqpSourceStage.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
/*
* Copyright (C) 2016-2018 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.stream.alpakka.amqp
import akka.Done
import akka.stream._
import akka.stream.alpakka.amqp.scaladsl.CommittableIncomingMessage
import akka.stream.stage._
import akka.util.ByteString
import com.rabbitmq.client.AMQP.BasicProperties
import com.rabbitmq.client._
import scala.collection.mutable
import scala.concurrent.Promise
import scala.util.Try
/**
 * A single message as delivered by the broker to the source stage.
 *
 * @param bytes      the message body
 * @param envelope   the delivery envelope that accompanied the message; its delivery tag is
 *                   what the stage uses when the message is later acked or nacked
 * @param properties the AMQP properties that accompanied the message
 */
final case class IncomingMessage(
    bytes: ByteString,
    envelope: Envelope,
    properties: BasicProperties
)
/**
 * A request handed to the source stage asking it to settle a delivery with the broker.
 */
trait CommitCallback

/**
 * Request to acknowledge a delivery.
 *
 * @param deliveryTag tag of the delivery to acknowledge
 * @param multiple    passed through unchanged to the channel's `basicAck`
 * @param promise     completed with `Done` once the stage has issued the ack, or failed with
 *                    the exception raised while doing so
 */
final case class AckArguments(
    deliveryTag: Long,
    multiple: Boolean,
    promise: Promise[Done]
) extends CommitCallback

/**
 * Request to negatively acknowledge a delivery.
 *
 * @param deliveryTag tag of the delivery to reject
 * @param multiple    passed through unchanged to the channel's `basicNack`
 * @param requeue     passed through unchanged to the channel's `basicNack`
 * @param promise     completed with `Done` once the stage has issued the nack, or failed with
 *                    the exception raised while doing so
 */
final case class NackArguments(
    deliveryTag: Long,
    multiple: Boolean,
    requeue: Boolean,
    promise: Promise[Done]
) extends CommitCallback
object AmqpSourceStage {

  // Attributes the stage starts out with; they only carry the stage name "AmqpSource".
  private val defaultAttributes: Attributes =
    Attributes.name("AmqpSource")
}
/**
 * Connects to an AMQP server upon materialization and consumes messages from it emitting them
 * into the stream. Each materialized source will create one connection to the broker.
 *
 * Messages are never acknowledged automatically: every emitted `CommittableIncomingMessage`
 * has to be settled explicitly through its `ack` or `nack` method, which the stage then
 * forwards to the broker. When downstream finishes while emitted messages are still
 * unsettled, completion of the stage is deferred until the outstanding count reaches zero.
 *
 * @param settings   Source settings; named-queue and temporary-queue variants are handled.
 * @param bufferSize The max number of elements to prefetch and buffer at any given time.
 */
final class AmqpSourceStage(settings: AmqpSourceSettings, bufferSize: Int)
    extends GraphStage[SourceShape[CommittableIncomingMessage]] { stage =>

  val out = Outlet[CommittableIncomingMessage]("AmqpSource.out")

  override val shape: SourceShape[CommittableIncomingMessage] = SourceShape.of(out)

  override protected def initialAttributes: Attributes = AmqpSourceStage.defaultAttributes

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with AmqpConnectorLogic {

      override val settings = stage.settings

      // Deliveries that arrived while downstream had no demand.
      private val queue = mutable.Queue[CommittableIncomingMessage]()

      // Number of messages pushed downstream that have not been acked/nacked yet.
      private var unackedMessages = 0

      override def whenConnected(): Unit = {
        import scala.collection.JavaConverters._
        import scala.util.control.NonFatal

        // we have only one consumer per connection so global is ok
        channel.basicQos(bufferSize, true)

        val consumerCallback = getAsyncCallback(handleDelivery)

        val shutdownCallback = getAsyncCallback[Option[ShutdownSignalException]] {
          case Some(ex) => failStage(ex)
          case None => if (unackedMessages == 0) completeStage()
        }

        // Shared ack/nack handling: run the broker call, account for the settled message,
        // complete the stage if it was only waiting for this one, and report the outcome
        // through `promise`. Only non-fatal failures are routed into the promise so that
        // fatal errors are not hidden inside a Future.
        // NOTE(review): the counter drops by one even when `multiple` is true, although such
        // a call may settle several deliveries at once -- confirm this is intended.
        def settle(promise: Promise[Done])(brokerCall: => Unit): Unit =
          try {
            brokerCall
            unackedMessages -= 1
            if (unackedMessages == 0 && isClosed(out)) completeStage()
            promise.complete(Try(Done))
          } catch {
            case NonFatal(e) => promise.failure(e)
          }

        val commitCallback = getAsyncCallback[CommitCallback] {
          case AckArguments(deliveryTag, multiple, promise) =>
            settle(promise)(channel.basicAck(deliveryTag, multiple))
          case NackArguments(deliveryTag, multiple, requeue, promise) =>
            settle(promise)(channel.basicNack(deliveryTag, multiple, requeue))
        }

        val amqpSourceConsumer = new DefaultConsumer(channel) {
          // Wraps each delivery in a committable message whose ack/nack are routed back
          // into the stage through `commitCallback`, then hands it to the stage.
          override def handleDelivery(consumerTag: String,
                                      envelope: Envelope,
                                      properties: BasicProperties,
                                      body: Array[Byte]): Unit =
            consumerCallback.invoke(
              new CommittableIncomingMessage {
                override val message = IncomingMessage(ByteString(body), envelope, properties)

                override def ack(multiple: Boolean) = {
                  val promise = Promise[Done]()
                  commitCallback.invoke(AckArguments(message.envelope.getDeliveryTag, multiple, promise))
                  promise.future
                }

                override def nack(multiple: Boolean, requeue: Boolean) = {
                  val promise = Promise[Done]()
                  commitCallback.invoke(NackArguments(message.envelope.getDeliveryTag, multiple, requeue, promise))
                  promise.future
                }
              }
            )

          override def handleCancel(consumerTag: String): Unit =
            // non consumer initiated cancel, for example happens when the queue has been deleted.
            shutdownCallback.invoke(None)

          override def handleShutdownSignal(consumerTag: String, sig: ShutdownSignalException): Unit =
            // "Called when either the channel or the underlying connection has been shut down."
            shutdownCallback.invoke(Option(sig))
        }

        def setupNamedQueue(settings: NamedQueueSourceSettings): Unit =
          channel.basicConsume(
            settings.queue,
            false, // never auto-ack
            settings.consumerTag, // consumer tag
            settings.noLocal,
            settings.exclusive,
            settings.arguments.asJava,
            amqpSourceConsumer
          )

        def setupTemporaryQueue(settings: TemporaryQueueSourceSettings): Unit = {
          // this is a weird case that required dynamic declaration, the queue name is not known
          // up front, it is only useful for sources, so that's why it's not placed in the AmqpConnectorLogic
          val queueName = channel.queueDeclare().getQueue
          channel.queueBind(queueName, settings.exchange, settings.routingKey.getOrElse(""))
          channel.basicConsume(
            queueName,
            amqpSourceConsumer
          )
        }

        settings match {
          case settings: NamedQueueSourceSettings => setupNamedQueue(settings)
          case settings: TemporaryQueueSourceSettings => setupTemporaryQueue(settings)
        }
      }

      /**
       * Handles one delivery inside the stage: pushes it straight away when downstream has
       * demand, buffers it otherwise, and fails the stage when buffering would exceed
       * `bufferSize`.
       */
      def handleDelivery(message: CommittableIncomingMessage): Unit =
        if (isAvailable(out)) {
          pushMessage(message)
        } else if (queue.size + 1 > bufferSize) {
          failStage(new RuntimeException(s"Reached maximum buffer size $bufferSize"))
        } else {
          queue.enqueue(message)
        }

      setHandler(
        out,
        new OutHandler {
          // Serve demand from the buffer; with an empty buffer the next delivery is pushed
          // directly by `handleDelivery`.
          override def onPull(): Unit =
            if (queue.nonEmpty) {
              pushMessage(queue.dequeue())
            }

          // Keep the stage going so that acks/nacks for already emitted messages can still
          // be processed; finish right away only when nothing is outstanding.
          override def onDownstreamFinish(): Unit = {
            setKeepGoing(true)
            if (unackedMessages == 0) super.onDownstreamFinish()
          }
        }
      )

      /** Emits `message` downstream and counts it as awaiting ack/nack. */
      def pushMessage(message: CommittableIncomingMessage): Unit = {
        push(out, message)
        unackedMessages += 1
      }

      // Intentionally a no-op in this stage.
      // NOTE(review): when and why AmqpConnectorLogic calls this is not visible here -- confirm
      // that ignoring the failure is correct.
      override def onFailure(ex: Throwable): Unit = {}
    }
}