/*
* Copyright (C) 2014 - 2016 Softwaremill <http://softwaremill.com>
* Copyright (C) 2016 - 2019 Lightbend Inc. <http://www.lightbend.com>
*/
package akka.kafka.internal
import akka.annotation.InternalApi
import akka.kafka.ConsumerMessage
import akka.kafka.ConsumerMessage.{GroupTopicPartition, PartitionOffset}
import akka.kafka.ProducerMessage.{Envelope, Results}
import akka.kafka.internal.ProducerStage.{MessageCallback, ProducerCompletionState}
import akka.stream.{Attributes, FlowShape}
import akka.stream.stage._
import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.clients.producer.Producer
import org.apache.kafka.common.TopicPartition
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.duration._
import scala.collection.JavaConverters._
/**
 * INTERNAL API
 *
 * A flow stage that produces messages to Kafka inside transactions, emitting a
 * [[Future]] of the send result for each incoming envelope. The actual
 * transaction bookkeeping lives in [[TransactionalProducerStageLogic]].
 */
@InternalApi
private[kafka] final class TransactionalProducerStage[K, V, P](
    val closeTimeout: FiniteDuration,
    val closeProducerOnStop: Boolean,
    val producerProvider: () => Producer[K, V],
    commitInterval: FiniteDuration
) extends GraphStage[FlowShape[Envelope[K, V, P], Future[Results[K, V, P]]]]
    with ProducerStage[K, V, P, Envelope[K, V, P], Results[K, V, P]] {

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = {
    // One producer instance per materialization.
    val transactionalProducer = producerProvider()
    new TransactionalProducerStageLogic(this, transactionalProducer, inheritedAttributes, commitInterval)
  }
}
/** Internal API */
private object TransactionalProducerStage {

  object TransactionBatch {
    /** A batch with no offsets recorded yet. */
    def empty: TransactionBatch = new EmptyTransactionBatch()
  }

  /** Accumulates the highest confirmed consumer offsets for the open transaction. */
  sealed trait TransactionBatch {
    def updated(partitionOffset: PartitionOffset): TransactionBatch
  }

  final class EmptyTransactionBatch extends TransactionBatch {
    override def updated(partitionOffset: PartitionOffset): TransactionBatch =
      new NonemptyTransactionBatch(partitionOffset)
  }

  final class NonemptyTransactionBatch(head: PartitionOffset,
                                       tail: Map[GroupTopicPartition, Long] = Map[GroupTopicPartition, Long]())
      extends TransactionBatch {
    // There is no guarantee that offset-adding callbacks will be called in any
    // particular order. Decreasing an offset stored for a group/topic/partition
    // would mean possible data duplication, so only the maximum offset seen is
    // kept. Since the `awaitingConfirmation` counter guarantees that all writes
    // finished, we can safely assume that all data up to the maximal offsets has
    // been written to Kafka.
    private val highestSoFar = tail.getOrElse(head.key, -1L)
    private val offsets = tail.updated(head.key, math.max(head.offset, highestSoFar))

    def group: String = head.key.groupId

    // Kafka expects the offset of the *next* record to consume, hence `offset + 1`.
    def offsetMap(): Map[TopicPartition, OffsetAndMetadata] =
      for ((gtp, offset) <- offsets)
        yield new TopicPartition(gtp.topic, gtp.partition) -> new OffsetAndMetadata(offset + 1)

    override def updated(partitionOffset: PartitionOffset): TransactionBatch = {
      require(
        group == partitionOffset.key.groupId,
        s"Transaction batch must contain messages from exactly 1 consumer group. $group != ${partitionOffset.key.groupId}"
      )
      new NonemptyTransactionBatch(partitionOffset, offsets)
    }
  }
}
/**
 * Internal API.
 *
 * Transaction (Exactly-Once) Producer State Logic
 *
 * Extends the default producer logic with Kafka transaction handling: a
 * transaction is opened at start, confirmed message offsets are accumulated in
 * `batchOffsets`, and a timer commits the batch (offsets + messages atomically)
 * every `commitInterval`, draining in-flight sends first.
 */
private final class TransactionalProducerStageLogic[K, V, P](stage: TransactionalProducerStage[K, V, P],
                                                             producer: Producer[K, V],
                                                             inheritedAttributes: Attributes,
                                                             commitInterval: FiniteDuration)
    extends DefaultProducerStageLogic[K, V, P, Envelope[K, V, P], Results[K, V, P]](stage,
                                                                                    producer,
                                                                                    inheritedAttributes)
    with StageLogging
    with MessageCallback[K, V, P]
    with ProducerCompletionState {

  import TransactionalProducerStage._

  // Timer key for the periodic commit tick.
  private val commitSchedulerKey = "commit"
  // Short re-check interval used while waiting for in-flight sends to be confirmed.
  private val messageDrainInterval = 10.milliseconds

  // Offsets confirmed so far for the currently open transaction.
  private var batchOffsets = TransactionBatch.empty

  override def preStart(): Unit = {
    // Order matters: transactions must be initialized and one begun before any
    // message can be produced; only then is downstream demand wired up.
    initTransactions()
    beginTransaction()
    // tryToPull = false: the first pull is driven by downstream's own onPull.
    resumeDemand(tryToPull = false)
    scheduleOnce(commitSchedulerKey, commitInterval)
  }

  /** Re-install the out-handler that forwards downstream demand upstream. */
  private def resumeDemand(tryToPull: Boolean = true): Unit = {
    setHandler(stage.out, new OutHandler {
      override def onPull(): Unit = tryPull(stage.in)
    })
    // kick off demand for more messages if we're resuming demand
    if (tryToPull && isAvailable(stage.out) && !hasBeenPulled(stage.in)) {
      tryPull(stage.in)
    }
  }

  /** Swap in a no-op out-handler so no new messages enter while committing. */
  private def suspendDemand(): Unit =
    setHandler(
      stage.out,
      new OutHandler {
        // suspend demand while a commit is in process so we can drain any outstanding message acknowledgements
        override def onPull(): Unit = ()
      }
    )

  override protected def onTimer(timerKey: Any): Unit =
    if (timerKey == commitSchedulerKey) {
      maybeCommitTransaction()
    }

  /**
   * Commit only when there is something to commit and no sends are awaiting
   * confirmation; otherwise suspend demand and retry on a short drain interval,
   * or (empty batch) just reschedule the regular commit tick.
   */
  private def maybeCommitTransaction(beginNewTransaction: Boolean = true): Unit = {
    val awaitingConf = awaitingConfirmation.get
    batchOffsets match {
      case batch: NonemptyTransactionBatch if awaitingConf == 0 =>
        commitTransaction(batch, beginNewTransaction)
      case _ if awaitingConf > 0 =>
        suspendDemand()
        scheduleOnce(commitSchedulerKey, messageDrainInterval)
      case _ =>
        scheduleOnce(commitSchedulerKey, commitInterval)
    }
  }

  // Invoked (thread-safely, via async callback) for each acknowledged message;
  // records its consumer offset so it is committed with the transaction.
  // Envelopes without a PartitionOffset pass-through contribute no offset.
  override val onMessageAckCb: AsyncCallback[Envelope[K, V, P]] =
    getAsyncCallback[Envelope[K, V, P]](_.passThrough match {
      case o: ConsumerMessage.PartitionOffset => batchOffsets = batchOffsets.updated(o)
      case _ =>
    })

  override def onCompletionSuccess(): Unit = {
    log.debug("Committing final transaction before shutdown")
    cancelTimer(commitSchedulerKey)
    // Final commit: do not begin a new transaction, the stage is shutting down.
    maybeCommitTransaction(beginNewTransaction = false)
    super.onCompletionSuccess()
  }

  override def onCompletionFailure(ex: Throwable): Unit = {
    log.debug("Aborting transaction due to stage failure")
    abortTransaction()
    super.onCompletionFailure(ex)
  }

  /**
   * Atomically commit the batch's offsets together with the produced messages,
   * then reset the batch and — unless shutting down — open the next transaction,
   * restore demand and reschedule the commit tick.
   */
  private def commitTransaction(batch: NonemptyTransactionBatch, beginNewTransaction: Boolean): Unit = {
    val group = batch.group
    log.debug("Committing transaction for consumer group '{}' with offsets: {}", group, batch.offsetMap())
    val offsetMap = batch.offsetMap().asJava
    producer.sendOffsetsToTransaction(offsetMap, group)
    producer.commitTransaction()
    batchOffsets = TransactionBatch.empty
    if (beginNewTransaction) {
      beginTransaction()
      resumeDemand()
      scheduleOnce(commitSchedulerKey, commitInterval)
    }
  }

  private def initTransactions(): Unit = {
    log.debug("Initializing transactions")
    producer.initTransactions()
  }

  private def beginTransaction(): Unit = {
    log.debug("Beginning new transaction")
    producer.beginTransaction()
  }

  private def abortTransaction(): Unit = {
    log.debug("Aborting transaction")
    producer.abortTransaction()
  }
}