/
RaftSimulator.scala
452 lines (394 loc) · 16.7 KB
/
RaftSimulator.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
package riff.raft.integration
package simulator
import riff.raft.log.LogAppendResult
import riff.raft.messages.{ReceiveHeartbeatTimeout, RequestOrResponse, SendHeartbeatTimeout, TimerMessage}
import riff.raft.node.{RaftCluster, RaftNode, _}
import riff.raft.timer.RaftClock
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration._
/**
* Exposes a means to have raft nodes generate messages and replies, as well as time-outs in a repeatable, testable,
* fast way w/o having to test using 'eventually'.
*
* Instead of being a real, async, multi-threaded context, however, events are put on a shared [[Timeline]] which then
* manually (via the [[riff.raft.integration.IntegrationTest]]) advances time -- which just means popping the next event off the timeline,
* feeding it to its targeted recipient, and pushing any resulting events from that application back onto the Timeline.
*
* As such, there should *always* be at least one event on the timeline after advancing time in a non-empty cluster, as
* something should always have set a timeout (e.g. to receive or send a heartbeat).
*
* This implementation should resemble (a little bit) what other glue code looks like as a [[NodeState]] gets "lifted"
* into some other context.
*
* That is to say, we may drive the implementation via akka actors, monix or fs2 streams, REST frameworks, etc.
* Each of which would have a handle on single node and use its output to send/enqueue events, complete futures, whatever.
*
* Doing it this way we're just directly applying node request/responses to each other in a single (test) thread,
* making life much simpler (and faster) to debug
*
*/
class RaftSimulator private (
    val nextSendTimeout: Iterator[FiniteDuration],
    val nextReceiveTimeout: Iterator[FiniteDuration],
    clusterNodes: List[String],
    val defaultLatency: FiniteDuration,
    newNode: (String, RaftCluster, RaftClock) => RaftNode[String])
    extends HasTimeline[TimelineType] {

  /** Simulates the effect of making a node unresponsive by not sending requests/responses to the node.
    * The messages sent, however, will remain "popped" from the timeline, which is what we want... just as if e.g.
    * we sent a REST request which wasn't ever received.
    *
    * @param nodeName the node to make unresponsive
    */
  def killNode(nodeName: String) = {
    stoppedNodes = stoppedNodes + nodeName
  }

  /**
    * Just makes the node respond to events again
    *
    * @param nodeName the node to restart
    */
  def restartNode(nodeName: String) = {
    stoppedNodes = stoppedNodes - nodeName
  }

  // keeps track of events. This is a var that can change via 'updateTimeline', but doesn't need to be locked/volatile,
  // as our IntegrationTest is single-threaded (which makes stepping through, debugging and testing a *LOT* easier
  // as we don't have to wait for arbitrary times for things NOT to happen, or otherwise set short (but non-zero) delays
  // which are tuned for different environments. In the end, the tests prove that the system is functionally correct and
  // can drive itself via the events it generates, which gives us a HUGE amount of confidence that things are correct in
  // a repeatable way.
  //
  // In practice too, the NodeState which drives a given member in the cluster isn't threadsafe anyway, and so should be
  // put behind something else which drives that concern.
  //
  // NOTE: initialization order matters here -- these vars must be initialized BEFORE 'clusterByName' below, as
  // constructing the nodes (via 'makeNode') registers heartbeat timeouts which are pushed onto this timeline.
  private var sharedSimulatedTimeline = Timeline[TimelineType]()

  // events which could not be delivered (e.g. ones targeting stopped or removed nodes)
  private var undeliveredTimeline = Timeline[TimelineType]()

  // a separate collection of the nodes which we want to be unresponsive. We don't just remove them from the 'clusterByName'
  // map, as we still want the cluster view to be correct (e.g. the leader node should still know about the stopped/unresponsive members)
  private var stoppedNodes: Set[String] = Set.empty

  private var clusterByName: Map[String, RaftNode[String]] = {
    clusterNodes
      .ensuring(_.distinct.size == clusterNodes.size)
      .map { name => //
        name -> makeNode(name)
      }
      .toMap
  }

  // a handy place for breakpoints to watch for all updates
  private[simulator] def updateTimeline(newTimeline: Timeline[TimelineType]) = {
    require(newTimeline.currentTime >= sharedSimulatedTimeline.currentTime)
    sharedSimulatedTimeline = newTimeline
  }

  private def markUndelivered(newTimeline: Timeline[TimelineType]) = {
    undeliveredTimeline = newTimeline
  }

  /** @return the next node name not already used by the cluster */
  def nextNodeName(): String =
    Iterator.from(clusterByName.size).map(nameForIdx).dropWhile(clusterByName.keySet.contains).next()

  def addNodeCommand() = RaftSimulator.addNode(nextNodeName())

  def removeNodeCommand(idx: Int) = RaftSimulator.removeNode(nameForIdx(idx))

  /** Interprets a committed log entry as a cluster-membership command (see [[RaftSimulator.addNode]] and
    * [[RaftSimulator.removeNode]]); anything else is ignored.
    *
    * @param data the committed log entry
    */
  def updateCluster(data: String): Unit = {
    data match {
      case RaftSimulator.AddCommand(name) =>
        addNode(name)
      case RaftSimulator.RemoveCommand(name) =>
        removeNode(name)
      case _ => // not a membership command -- nothing to do
    }
  }

  override def currentTimeline(): Timeline[TimelineType] = {
    sharedSimulatedTimeline
  }

  /**
    * convenience method to append to whatever the leader node is -- will error if there is no leader
    *
    * @param data the log entries to append
    * @param latency the simulated latency used when enqueueing the resulting append requests
    * @return the result of the append on the leader
    */
  def appendToLeader(data: Array[String], latency: FiniteDuration = defaultLatency): NodeAppendResult[String] = {
    val ldr = currentLeader()
    val ldrId = ldr.state().id
    // as 'ldr' is the confirmed current leader, appendIfLeader is expected to produce an AddressedRequest
    ldr.appendIfLeader(data) match {
      case appendResults @ NodeAppendResult(_, AddressedRequest(requests)) =>
        val newTimeline = requests.zipWithIndex.foldLeft(currentTimeline) {
          case (timeline, ((to, msg), i)) =>
            //
            // ensure we insert the 'SendRequest' after any other SendRequest which came FROM this node
            // ...so we're not reordering our requests
            //
            val (newTimeline, _) = timeline.pushAfter(latency + i.millis, SendRequest(ldrId, to, msg)) {
              case SendRequest(from, _, _) => from == ldrId
            }
            newTimeline
        }
        updateTimeline(newTimeline)
        appendResults
    }
  }

  /** @param name the node whose RaftCluster this is
    * @return an implementation of RaftCluster for each node as a view onto the simulator
    */
  private def clusterForNode(name: String): RaftCluster = {
    new RaftCluster {
      override def peers: Iterable[String] = {
        clusterByName.keySet - name
      }
      override def contains(key: String): Boolean = {
        clusterByName.contains(key)
      }
    }
  }

  /** Creates a node whose log drives cluster membership on commit. */
  private def makeNode(name: String): RaftNode[String] = {
    val node = newNode(name, clusterForNode(name), new SimulatedClock(this, name))
    val newLog = node.log.onCommit { entry => //
      updateCluster(entry.data)
    }
    node.withLog(newLog)
  }

  private def removeNode(name: String) = {
    clusterByName = clusterByName - name
  }

  private def addNode(name: String) = {
    if (!clusterByName.contains(name)) {
      val node = makeNode(name)
      clusterByName = clusterByName.updated(name, node)
    }
    this
  }

  def currentLeader(): RaftNode[String] = clusterByName(leaderState.id)

  def nodesWithRole(role: NodeRole): List[RaftNode[String]] = nodes.filter(_.state().role == role)

  def nodes(): List[RaftNode[String]] = clusterByName.values.toList

  def leaderState(): LeaderNodeState = leaderStateOpt.get

  /** @return the current leader (if there is one)
    */
  def leaderStateOpt(): Option[LeaderNodeState] = {
    val leaders = clusterByName.values.map(_.state()) collect {
      case leader: LeaderNodeState => leader
    }
    leaders.toList match {
      case Nil          => None
      case List(leader) => Option(leader)
      case many         => throw new IllegalStateException(s"Multiple simultaneous leaders! ${many}")
    }
  }

  def takeSnapshot(): Map[String, NodeSnapshot[String]] = clusterByName.map {
    case (key, n) => (key, NodeSnapshot(n))
  }

  /**
    * advance the timeline by one event
    *
    * @param latency the simulated latency used when enqueueing request/responses
    * @return the result of applying the next timeline event
    */
  def advance(latency: FiniteDuration = defaultLatency): AdvanceResult = {
    advanceSafe(latency).getOrElse(
      sys.error(s"Timeline is empty! This should never happen - there should always be timeouts queued"))
  }

  /**
    * flush the timeline applying all current events
    *
    * @param latency the simulated latency used when enqueueing request/responses
    * @return the combined result of applying all currently-queued events
    */
  def advanceAll(latency: FiniteDuration = defaultLatency): AdvanceResult = {
    // fix: previously 'latency' was silently dropped, so advanceBy always used the default latency
    advanceBy(currentTimeline().size, latency)
  }

  def advanceUntil(predicate: AdvanceResult => Boolean): AdvanceResult = {
    advanceUntil(100, defaultLatency, predicate)
  }

  def advanceUntilDebug(predicate: AdvanceResult => Boolean): AdvanceResult = {
    advanceUntil(100, defaultLatency, predicate, debug = true)
  }

  /**
    * Continues to advance the timeline until the given 'predicate' condition is true (up to 'max')
    *
    * @param max the maximum number to advance (the timeline should continually get timeouts registered, so without a max this could potentially loop forever)
    * @param latency the latency to use for enqueueing request/responses
    * @param predicate the condition we're moving towards
    * @param debug if true this spits out the timeline at each step as a convenience
    * @return the result
    */
  def advanceUntil(
      max: Int,
      latency: FiniteDuration,
      predicate: AdvanceResult => Boolean,
      debug: Boolean = false): AdvanceResult = {
    def logDebug(newTimeline: Timeline[TimelineType]) = {
      if (debug) {
        println("- " * 50)
        println(debugString(Option(newTimeline)))
      }
    }
    logDebug(currentTimeline())
    var next: AdvanceResult = advance(latency)
    val list = ListBuffer[AdvanceResult](next)
    while (!predicate(next) && list.size < max) {
      logDebug(next.beforeTimeline)
      next = advance(latency)
      list += next
    }
    logDebug(next.beforeTimeline)
    // fix: check the predicate itself; the old 'list.size < max' check wrongly failed when the
    // condition was met on exactly the max-th iteration
    require(predicate(next), s"The condition was never met after $max iterations")
    concatResults(list.toList)
  }

  /**
    * flush 'nr' events
    *
    * @param nr the number of events to advance
    * @param latency the simulated latency used when enqueueing request/responses
    * @return the combined result of the 'nr' advances
    */
  def advanceBy(nr: Int, latency: FiniteDuration = defaultLatency): AdvanceResult = {
    val list = (1 to nr).map { _ =>
      advance(latency)
    }.toList
    concatResults(list)
  }

  // combines a non-empty list of results into one spanning the first 'before' state and the last 'after' state
  private def concatResults(list: List[AdvanceResult]): AdvanceResult = {
    val last: AdvanceResult = list.last
    last.copy(
      beforeTimeline = list.head.beforeTimeline,
      beforeStateByName = list.head.beforeStateByName,
      advanceEvents = list.flatMap(_.advanceEvents))
  }

  /**
    * pops the next event from the sharedSimulatedTimeline and enqueues the result onto the sharedSimulatedTimeline.
    *
    * @param latency the time buffer to assume for sending/receiving messages.
    * @return the result of advancing the next event in the timeline, if there was a next event
    */
  private def advanceSafe(latency: FiniteDuration = defaultLatency): Option[AdvanceResult] = {
    val beforeState: Map[String, NodeSnapshot[String]] = takeSnapshot()
    val beforeTimeline = currentTimeline
    // pop one off our timeline stack, then subsequently update our sharedSimulatedTimeline
    beforeTimeline.pop().map {
      case (newTimeline, e) =>
        // first ensure the latest timeline is the advanced, popped one
        updateTimeline(newTimeline)
        val (recipient, result) = applyTimelineEvent(e, latency)
        AdvanceResult(
          recipient,
          beforeState,
          beforeTimeline,
          e,
          result,
          currentTimeline,
          undeliveredTimeline,
          takeSnapshot())
    }
  }

  def nodeFor(idx: Int): RaftNode[String] = {
    clusterByName.getOrElse(nameForIdx(idx), sys.error(s"Couldn't find ${nameForIdx(idx)} in ${clusterByName.keySet}"))
  }

  def snapshotFor(idx: Int): NodeSnapshot[String] = NodeSnapshot(nodeFor(idx))

  /**
    * Applies the next timeline event.
    *
    * Typically we allow the simulator to use 'advance*' methods to let the nodes drive their own logic, but this
    * CAN be called directly by tests if we want to force an event (e.g. force an election, etc)
    *
    * @param nextEvent the timeline event to deliver
    * @param latency the simulated latency used when enqueueing any resulting messages
    * @param currentTime the timeline time at which the event is considered delivered
    * @return the recipient node name paired with the node's result
    */
  def applyTimelineEvent(
      nextEvent: TimelineType,
      latency: FiniteDuration = defaultLatency,
      currentTime: Long = currentTimeline().currentTime): (String, RaftNode[String]#Result) = {
    val res @ (recipient, result) = processNextEvent(nextEvent, currentTime)
    applyResult(latency, recipient, result)
    res
  }

  // pushes the messages produced by a node back onto the shared timeline; recurses to unwrap
  // committed/append wrappers until reaching the addressed messages (or a no-op)
  @tailrec
  private def applyResult(latency: FiniteDuration, node: String, result: RaftNodeResult[String]): Unit = {
    result match {
      case _: NoOpResult                 =>
      case LeaderCommittedResult(_, msg) => applyResult(latency, node, msg)
      case NodeAppendResult(_, msg)      => applyResult(latency, node, msg)
      case AddressedRequest(msgs) =>
        val newTimeline = msgs.zipWithIndex.foldLeft(currentTimeline) {
          case (time, ((to, msg), i)) =>
            val (newTime, _) =
              // if we just blindly use the latency, the we'll end up sending messages based off whatever the current
              // time is in the timeline, when really it should be done after the next
              time.pushAfter(latency + i.millis, SendRequest(node, to, msg)) {
                case SendRequest(from, _, _) => from == node
              }
            newTime
        }
        updateTimeline(newTimeline)
      case AddressedResponse(to, msg) =>
        val (newTime, _) = currentTimeline.insertAfter(latency, SendResponse(node, to, msg))
        updateTimeline(newTime)
    }
  }

  /**
    * process this event (assumed to be the next in the timeline)
    *
    * @return a tuple of the recipient of the event and its result; undeliverable events (stopped or
    *         removed recipients) are recorded on the undelivered timeline and yield a NoOpResult
    */
  private def processNextEvent(
      nextEvent: TimelineType,
      currentTime: Long = currentTimeline().currentTime): (String, RaftNode[String]#Result) = {
    def deliverMsg(from: String, to: String, msg: RequestOrResponse[String]) = {
      clusterByName.get(to) match {
        case Some(node) if !stoppedNodes.contains(to) =>
          node.handleMessage(from, msg)
        case _ =>
          markUndelivered(undeliveredTimeline.insertAfter(currentTime.millis, nextEvent)._1)
          NoOpResult(s"Can't deliver msg from $from to $to : $msg")
      }
    }
    def deliverTimerMsg(to: String, msg: TimerMessage) = {
      clusterByName.get(to) match {
        case Some(node) if !stoppedNodes.contains(to) => node.onTimerMessage(msg)
        case _ =>
          markUndelivered(undeliveredTimeline.insertAfter(currentTime.millis, nextEvent)._1)
          NoOpResult(s"Can't deliver timer msg for $to : $msg")
      }
    }
    nextEvent match {
      case SendTimeout(node)            => (node, deliverTimerMsg(node, SendHeartbeatTimeout))
      case ReceiveTimeout(node: String) => (node, deliverTimerMsg(node, ReceiveHeartbeatTimeout))
      case SendRequest(from, to, request) =>
        val result = deliverMsg(from, to, request)
        (to, result)
      case SendResponse(from, to, response) =>
        val result = deliverMsg(from, to, response)
        (to, result)
    }
  }

  override def toString(): String = debugString()

  /** @param previousTimeline an optional timeline to render instead of the current one
    * @return a human-readable dump of the timeline and each node's snapshot (stopped nodes are marked)
    */
  def debugString(previousTimeline: Option[Timeline[TimelineType]] = None): String = {
    val timelineString = pretty("", previousTimeline)
    val strings = takeSnapshot().map {
      case (id, node) if stoppedNodes(id) => node.copy(name = s"${node.name} [stopped]").pretty()
      case (_, node)                      => node.pretty()
    }
    strings.mkString(s"${timelineString}\n", "\n", "")
  }
}
object RaftSimulator {

  type NodeResult = RaftNodeResult[String]

  /**
    * Our log has the simple 'string' type. Our RaftSimulator's state machine will check those entries against these
    * commands to add, remove (pause, etc) nodes
    */
  private val AddCommand = "ADD:(.+)".r
  private val RemoveCommand = "REMOVE:(.+)".r

  /** @return the log-entry command which, once committed, adds 'name' to the cluster */
  def addNode(name: String) = s"ADD:${name}"

  /** @return the log-entry command which, once committed, removes 'name' from the cluster */
  def removeNode(name: String) = s"REMOVE:${name}"

  // the send heartbeat timeouts should be regular, so our fixed set of 'random' values shouldn't
  // have much variance
  def sendHeartbeatTimeouts: Iterator[FiniteDuration] =
    Iterator.continually(List(100.millis, 105.millis, 95.millis, 101.millis)).flatten

  // receive timeouts vary more, so elections don't all fire at once
  def receiveHeartbeatTimeouts: Iterator[FiniteDuration] =
    Iterator.continually(List(350.millis, 280.millis, 400.millis, 370.millis)).flatten

  /** @return a new in-memory node for the given name/cluster/clock with its receive-heartbeat timer started */
  def newNode(name: String, cluster: RaftCluster, timer: RaftClock): RaftNode[String] = {
    val node: RaftNode[String] = RaftNode.inMemory[String](name)(timer).withCluster(cluster)
    node.resetReceiveHeartbeat()
    node
  }

  /** @param n the number of nodes in the simulated cluster
    * @param createNode the node factory (defaults to [[newNode]])
    * @return a simulator over 'n' freshly-created nodes
    */
  def clusterOfSize(n: Int)(
      implicit createNode: (String, RaftCluster, RaftClock) => RaftNode[String] = newNode _): RaftSimulator = {
    val names = (1 to n).map(nameForIdx).toList
    new RaftSimulator(sendHeartbeatTimeouts, receiveHeartbeatTimeouts, names, 10.millis, createNode)
  }
}