/
Handshake.scala
333 lines (287 loc) · 12.3 KB
/
Handshake.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
/*
* Copyright (C) 2016-2023 Lightbend Inc. <https://www.lightbend.com>
*/
package akka.remote.artery
import scala.concurrent.Future
import scala.concurrent.duration._
import scala.util.control.NoStackTrace
import akka.Done
import akka.actor.ActorSystem
import akka.actor.Address
import akka.dispatch.ExecutionContexts
import akka.remote.UniqueAddress
import akka.stream.Attributes
import akka.stream.FlowShape
import akka.stream.Inlet
import akka.stream.Outlet
import akka.stream.stage._
import akka.util.OptionVal
import akka.util.unused
/**
 * INTERNAL API
 *
 * Messages, failure type, state markers and timer keys used by the handshake stages.
 */
private[remote] object OutboundHandshake {

  // --- messages exchanged between the two systems ---

  /**
   * Handshake request. `from` is the unique address of the requesting system and
   * `to` is the address the request is directed at.
   */
  final case class HandshakeReq(from: UniqueAddress, to: Address) extends ControlMessage

  /** Reply to a [[HandshakeReq]]; `from` is the unique address of the replying system. */
  final case class HandshakeRsp(from: UniqueAddress) extends Reply

  // --- failure ---

  /**
   * Stream is failed with this exception if the handshake is not completed
   * within the handshake timeout.
   */
  class HandshakeTimeoutException(msg: String) extends RuntimeException(msg) with NoStackTrace

  // --- progress of the outbound handshake ---

  private sealed trait HandshakeState
  private case object Start extends HandshakeState
  private case object ReqInProgress extends HandshakeState
  private case object Completed extends HandshakeState

  // --- timer keys of the outbound stage ---

  private case object HandshakeTimeout
  private case object HandshakeRetryTick
  private case object InjectHandshakeTick
  private case object LivenessProbeTick
}
/**
 * INTERNAL API
 *
 * Outbound flow stage that performs the handshake before ordinary messages are let through.
 *
 *  - A `HandshakeReq` is always pushed as the first element.
 *  - Upstream is not pulled until the handshake is completed; while waiting, the
 *    `HandshakeReq` is re-sent every `retryInterval`.
 *  - The stage is failed with `HandshakeTimeoutException` when the handshake has not
 *    completed within `timeout`.
 *  - After completion a `HandshakeReq` is injected ahead of an ordinary message at most
 *    once per `injectHandshakeInterval`.
 *  - When `livenessProbeInterval` is finite, a `HandshakeReq` is sent as a liveness probe
 *    once the association has been idle for at least that long.
 */
private[remote] class OutboundHandshake(
    @unused system: ActorSystem,
    outboundContext: OutboundContext,
    outboundEnvelopePool: ObjectPool[ReusableOutboundEnvelope],
    timeout: FiniteDuration,
    retryInterval: FiniteDuration,
    injectHandshakeInterval: FiniteDuration,
    livenessProbeInterval: Duration)
    extends GraphStage[FlowShape[OutboundEnvelope, OutboundEnvelope]] {

  val in: Inlet[OutboundEnvelope] = Inlet("OutboundHandshake.in")
  val out: Outlet[OutboundEnvelope] = Outlet("OutboundHandshake.out")
  override val shape: FlowShape[OutboundEnvelope, OutboundEnvelope] = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new TimerGraphStageLogic(shape) with InHandler with OutHandler with StageLogging {
      import OutboundHandshake._

      private var handshakeState: HandshakeState = Start
      // ordinary message that was grabbed but could not be pushed yet
      private var pendingMessage: OptionVal[OutboundEnvelope] = OptionVal.None
      // true from the moment a HandshakeReq is injected until the InjectHandshakeTick fires
      private var injectTickPending = false

      // Runs inside the stage when the unique address of the remote system becomes known.
      private val remoteAddressKnown = getAsyncCallback[UniqueAddress] { _ =>
        handshakeState match {
          case Completed => // nothing to do, already completed
          case _ =>
            handshakeCompleted()
            if (isAvailable(out)) pull(in)
        }
      }

      // this must be a `val` because function equality is used when removing in postStop
      private val uniqueRemoteAddressListener: UniqueAddress => Unit = remoteAddressKnown.invoke(_)

      override protected def logSource: Class[_] = classOf[OutboundHandshake]

      override def preStart(): Unit = {
        scheduleOnce(HandshakeTimeout, timeout)
        // the liveness probe is only used in control stream; other streams pass a non-finite interval
        Some(livenessProbeInterval)
          .collect { case finite: FiniteDuration => finite }
          .foreach(every => scheduleWithFixedDelay(LivenessProbeTick, every, every))
      }

      override def postStop(): Unit = {
        val associationState = outboundContext.associationState
        associationState.removeUniqueRemoteAddressListener(uniqueRemoteAddressListener)
        super.postStop()
      }

      // InHandler
      override def onPush(): Unit = {
        if (handshakeState != Completed)
          throw new IllegalStateException(s"onPush before handshake completed, was [$handshakeState].")

        if (!injectTickPending) {
          // inject a HandshakeReq once in a while to trigger a new handshake when destination
          // system has been restarted; the grabbed message is delivered on the next pull
          pushHandshakeReq()
          pendingMessage = OptionVal.Some(grab(in))
        } else if (isAvailable(out)) {
          push(out, grab(in))
        } else {
          // out is only unavailable here when a liveness HandshakeReq was just pushed
          if (pendingMessage.isDefined)
            throw new IllegalStateException("pendingMessage expected to be empty")
          pendingMessage = OptionVal.Some(grab(in))
        }
      }

      // OutHandler
      override def onPull(): Unit =
        handshakeState match {
          case ReqInProgress => // will pull when handshake reply is received

          case Start =>
            if (outboundContext.associationState.uniqueRemoteAddress().isDefined) {
              handshakeCompleted()
            } else {
              awaitHandshakeRsp()
            }
            // always push a HandshakeReq as the first message
            pushHandshakeReq()

          case Completed =>
            pendingMessage match {
              case OptionVal.Some(held) =>
                push(out, held)
                pendingMessage = OptionVal.None
              case _ =>
                if (!hasBeenPulled(in)) pull(in)
            }
        }

      /**
       * Enters the `ReqInProgress` state: starts the retry timer and registers the listener
       * that is notified when the unique remote address is populated. Upstream is pulled
       * from that notification, not from here.
       */
      private def awaitHandshakeRsp(): Unit = {
        handshakeState = ReqInProgress
        scheduleWithFixedDelay(HandshakeRetryTick, retryInterval, retryInterval)
        // The InboundHandshake stage will complete the AssociationState.uniqueRemoteAddress
        // when it receives the HandshakeRsp reply
        outboundContext.associationState.addUniqueRemoteAddressListener(uniqueRemoteAddressListener)
      }

      /**
       * Marks a HandshakeReq as injected, (re)starts the injection timer, touches the
       * last-used timestamp and pushes the request if downstream currently has demand.
       */
      private def pushHandshakeReq(): Unit = {
        injectTickPending = true
        scheduleOnce(InjectHandshakeTick, injectHandshakeInterval)
        outboundContext.associationState.lastUsedTimestamp.set(System.nanoTime())
        if (isAvailable(out)) {
          push(out, createHandshakeReqEnvelope())
        }
      }

      private def pushLivenessProbeReq(): Unit = {
        // The associationState.lastUsedTimestamp will be updated when the HandshakeRsp is received
        // and that is the confirmation that the other system is alive, and will not be quarantined
        // by the quarantine-idle-outbound-after even though no real messages have been sent.
        val canProbe = handshakeState == Completed && isAvailable(out) && pendingMessage.isEmpty
        if (canProbe) {
          val lastUsedNanos = outboundContext.associationState.lastUsedTimestamp.get()
          val idleFor = (System.nanoTime() - lastUsedNanos).nanos
          if (idleFor >= livenessProbeInterval) {
            log.info(
              "Association to [{}] has been idle for [{}] seconds, sending HandshakeReq to validate liveness",
              outboundContext.remoteAddress,
              idleFor.toSeconds)
            push(out, createHandshakeReqEnvelope())
          }
        }
      }

      /** Takes an envelope from the pool and fills it with a HandshakeReq, without sender or recipient. */
      private def createHandshakeReqEnvelope(): OutboundEnvelope = {
        val envelope = outboundEnvelopePool.acquire()
        val request = HandshakeReq(outboundContext.localAddress, outboundContext.remoteAddress)
        envelope.init(recipient = OptionVal.None, message = request, sender = OptionVal.None)
      }

      /** Switches to `Completed` and stops the timers that only matter while handshaking. */
      private def handshakeCompleted(): Unit = {
        handshakeState = Completed
        cancelTimer(HandshakeTimeout)
        cancelTimer(HandshakeRetryTick)
      }

      override protected def onTimer(timerKey: Any): Unit =
        timerKey match {
          case HandshakeTimeout =>
            val failure = new HandshakeTimeoutException(
              s"Handshake with [${outboundContext.remoteAddress}] did not complete within ${timeout.toMillis} ms")
            failStage(failure)
          case HandshakeRetryTick =>
            if (isAvailable(out)) pushHandshakeReq()
          case InjectHandshakeTick =>
            // next onPush message will trigger sending of HandshakeReq
            injectTickPending = false
          case LivenessProbeTick =>
            pushLivenessProbeReq()
          case unknown =>
            throw new IllegalArgumentException(s"Unknown timer key: $unknown")
        }

      setHandlers(in, out, this)
    }
}
/**
 * INTERNAL API
 *
 * Inbound flow stage that answers handshake requests and filters ordinary messages.
 *
 *  - A `HandshakeReq` addressed to this system completes the handshake and, on success,
 *    is answered with a `HandshakeRsp`; one addressed to another address is dropped
 *    with a warning.
 *  - In the control stream (`inControlStream = true`) a `HandshakeRsp` is also consumed
 *    here and completes the handshake.
 *  - Any other message is passed downstream only when its origin is known, otherwise it
 *    is published as dropped.
 */
private[remote] class InboundHandshake(inboundContext: InboundContext, inControlStream: Boolean)
    extends GraphStage[FlowShape[InboundEnvelope, InboundEnvelope]] {

  val in: Inlet[InboundEnvelope] = Inlet("InboundHandshake.in")
  val out: Outlet[InboundEnvelope] = Outlet("InboundHandshake.out")
  override val shape: FlowShape[InboundEnvelope, InboundEnvelope] = FlowShape(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new TimerGraphStageLogic(shape) with OutHandler with StageLogging {
      import OutboundHandshake._

      // Executes the given task inside the stage; used to continue after an asynchronous completion.
      private val runInStage = getAsyncCallback[() => Unit](task => task())

      // InHandler for the control stream: handles both HandshakeReq and HandshakeRsp
      private def controlStreamInHandler: InHandler =
        new InHandler {
          override def onPush(): Unit = {
            val envelope = grab(in)
            envelope.message match {
              case HandshakeReq(from, to) => onHandshakeReq(from, to)
              case HandshakeRsp(from)     => onHandshakeRsp(from)
              case _                      => onMessage(envelope)
            }
          }
        }

      // InHandler for the other streams: only HandshakeReq is treated specially
      private def ordinaryStreamInHandler: InHandler =
        new InHandler {
          override def onPush(): Unit = {
            val envelope = grab(in)
            envelope.message match {
              case HandshakeReq(from, to) => onHandshakeReq(from, to)
              case _                      => onMessage(envelope)
            }
          }
        }

      private def onHandshakeRsp(from: UniqueAddress): Unit = {
        // Touch the lastUsedTimestamp here also, because the low frequency liveness HandshakeReq
        // does not update the timestamp when it is sent. It is updated when the reply is
        // received, which confirms that the other system is alive.
        inboundContext.association(from.address).associationState.lastUsedTimestamp.set(System.nanoTime())
        after(inboundContext.completeHandshake(from))(_ => pull(in))
      }

      private def onHandshakeReq(from: UniqueAddress, to: Address): Unit = {
        val localAddress = inboundContext.localAddress.address
        if (to != localAddress) {
          log.warning(
            "Dropping Handshake Request from [{}] addressed to unknown local address [{}]. " +
            "Local address is [{}]. Check that the sending system uses the same " +
            "address to contact recipient system as defined in the " +
            "'akka.remote.artery.canonical.hostname' of the recipient system. " +
            "The name of the ActorSystem must also match.",
            from,
            to,
            inboundContext.localAddress.address)
          pull(in)
        } else {
          after(inboundContext.completeHandshake(from)) { completed =>
            // only reply when the handshake was completed successfully
            if (completed) {
              inboundContext.sendControl(from.address, HandshakeRsp(inboundContext.localAddress))
            }
            pull(in)
          }
        }
      }

      /**
       * Runs `thenInside` with the success flag of `first`: directly when `first` is already
       * completed, otherwise via the stage's async callback once it completes.
       */
      private def after(first: Future[Done])(thenInside: Boolean => Unit): Unit =
        first.value match {
          case None =>
            first.onComplete { outcome =>
              runInStage.invoke(() => thenInside(outcome.isSuccess))
            }(ExecutionContexts.parasitic)
          case Some(outcome) =>
            // This is the normal case (all but the first). The future will be completed
            // because handshake was already completed. Note that we send those HandshakeReq
            // periodically.
            thenInside(outcome.isSuccess)
        }

      private def onMessage(env: InboundEnvelope): Unit = {
        if (!isKnownOrigin(env)) {
          val dropReason = s"Unknown system with UID [${env.originUid}]. " +
            s"This system with UID [${inboundContext.localAddress.uid}] was probably restarted. " +
            "Messages will be accepted when new handshake has been completed."
          inboundContext.publishDropped(env, dropReason)
          pull(in)
        } else {
          push(out, env)
        }
      }

      private def isKnownOrigin(env: InboundEnvelope): Boolean = {
        // the association is passed in the envelope from the Decoder stage to avoid
        // additional lookup. The lookup is the fallback because if we didn't use fusing it
        // would be possible that it was not found by Decoder (handshake not completed yet)
        val attachedToEnvelope = env.association.isDefined
        attachedToEnvelope || inboundContext.association(env.originUid).isDefined
      }

      // OutHandler
      override def onPull(): Unit = pull(in)

      setHandler(in, if (inControlStream) controlStreamInHandler else ordinaryStreamInHandler)
      setHandler(out, this)
    }
}