Skip to content
This repository has been archived by the owner on Aug 1, 2019. It is now read-only.

Commit

Permalink
Add FSMWatcher to wait for an FSM to reach a certain index
Browse files Browse the repository at this point in the history
  • Loading branch information
freeekanayaka committed Jan 23, 2018
1 parent 9a1aace commit 72b4398
Show file tree
Hide file tree
Showing 16 changed files with 419 additions and 146 deletions.
25 changes: 9 additions & 16 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,7 @@ func Cluster(t testing.TB, fsms []raft.FSM, knobs ...Knob) ([]*raft.Raft, func()
}

for _, knob := range knobs {
knob.init(cluster)
}

servers := make([]raft.Server, 0)
for i, node := range cluster.nodes {
if !node.Bootstrap {
continue
}
server := raft.Server{
ID: raft.ServerID(strconv.Itoa(i)),
Address: node.Transport.LocalAddr(),
}
servers = append(servers, server)
knob.pre(cluster)
}

bootstrapCluster(t, cluster.nodes)
Expand All @@ -79,6 +67,10 @@ func Cluster(t testing.TB, fsms []raft.FSM, knobs ...Knob) ([]*raft.Raft, func()
rafts[i] = raft
}

for _, knob := range knobs {
knob.post(rafts)
}

cleanup := func() {
Shutdown(t, rafts)
}
Expand All @@ -89,7 +81,8 @@ func Cluster(t testing.TB, fsms []raft.FSM, knobs ...Knob) ([]*raft.Raft, func()
// Knob can be used to tweak the dependencies of test Raft nodes created with
// Cluster() or Node().
type Knob interface {
init(*cluster)
pre(*cluster)
post([]*raft.Raft)
}

// Shutdown all the given raft nodes and fail the test if any of them errors
Expand Down Expand Up @@ -148,7 +141,7 @@ func newDefaultNode(t testing.TB, i int) *node {
out := TestingWriter(t)
config := raft.DefaultConfig()
config.LocalID = raft.ServerID(addr)
config.Logger = log.New(out, fmt.Sprintf("%s: ", addr), 0)
config.Logger = log.New(out, fmt.Sprintf("%s: ", addr), log.Ltime|log.Lmicroseconds)

// Decrease timeouts, since everything happens in-memory by
// default.
Expand All @@ -161,7 +154,7 @@ func newDefaultNode(t testing.TB, i int) *node {
Config: config,
Logs: raft.NewInmemStore(),
Stable: raft.NewInmemStore(),
Snapshots: raft.NewDiscardSnapshotStore(),
Snapshots: raft.NewInmemSnapshotStore(),
Transport: transport,
Bootstrap: true,
}
Expand Down
5 changes: 4 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ type configKnob struct {
f func(int, *raft.Config)
}

func (k *configKnob) init(cluster *cluster) {
func (k *configKnob) pre(cluster *cluster) {
for i, node := range cluster.nodes {
k.f(i, node.Config)
}
}

func (k *configKnob) post([]*raft.Raft) {
}
82 changes: 82 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2017 Canonical Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rafttest_test

import (
"fmt"
"testing"
"time"

"github.com/CanonicalLtd/raft-test"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/require"
)

func Example() {
t := &testing.T{}

// Create 3 raft FSMs and wrap them with a watcher.
fsms := rafttest.FSMs(3)
watcher := rafttest.FSMWatcher(t, fsms)

// Create test cluster knobs.
notify := rafttest.Notify()
network := rafttest.Network()
config := rafttest.Config(func(n int, config *raft.Config) {
config.SnapshotInterval = 5 * time.Millisecond
config.SnapshotThreshold = 4
config.TrailingLogs = 10
})

// Create a cluster of raft instances.
rafts, cleanup := rafttest.Cluster(t, fsms, notify, network, config)
defer cleanup()

// Get the index of the first raft instance to acquiring leadership.
i := notify.NextAcquired(time.Second)
r := rafts[i]

// Get the index of one of the two follower raft instances.
j := rafttest.Other(rafts, i)

// Apply a log and wait for for all FSMs to apply it.
require.NoError(t, r.Apply([]byte{}, time.Second).Error())
for i := range fsms {
watcher.WaitIndex(i, 3, time.Second)
}

// Simulate a network disconnection of raft instance j.
network.Disconnect(j)

// Apply another few logs, leaving raft instance j behind.
for i := 0; i < 100; i++ {
require.NoError(t, r.Apply([]byte{}, time.Second).Error())
}
watcher.WaitIndex(i, 103, time.Second)

// Make sure a snapshot is taken by the leader i.
watcher.WaitSnapshot(i, 1, time.Second)

// Reconnect raft instance j.
network.Reconnect(j)

// Wait for raft instance j to use a snapshot shipped by the leader to
// catch up with logs.
watcher.WaitRestore(j, 1, time.Second)

// Output:
// 103
fmt.Println(rafts[j].AppliedIndex())
}
127 changes: 127 additions & 0 deletions fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
package rafttest

import (
"fmt"
"io"
"sync"
"testing"
"time"

"github.com/hashicorp/raft"
)
Expand All @@ -34,6 +38,90 @@ func FSMs(n int) []raft.FSM {
return fsms
}

// FSMWatcher creates watchers for the given FSMs.
func FSMWatcher(t testing.TB, fsms []raft.FSM) *FSMWatcherAPI {
api := &FSMWatcherAPI{
t: t,
wrappers: make([]*fsmWrapper, len(fsms)),
}
for i, fsm := range fsms {
wrapper := &fsmWrapper{t: t, fsm: fsm}
api.wrappers[i] = wrapper
fsms[i] = wrapper
}
return api
}

// FSMWatcherAPI implements methods to wait for the underlying FSMs to reach
// certain states.
type FSMWatcherAPI struct {
t testing.TB
wrappers []*fsmWrapper
}

// WaitIndex blocks until the FSM with the given index has reached at least the
// given log.
//
// If the timeout expires, test will fail.
func (w *FSMWatcherAPI) WaitIndex(i int, index uint64, timeout time.Duration) {
helper, ok := w.t.(testingHelper)
if ok {
helper.Helper()
}

wrapper := w.wrappers[i]
check := func() bool {
wrapper.mu.Lock()
defer wrapper.mu.Unlock()
return wrapper.index >= index
}

message := fmt.Sprintf("fsm %d did not reach index %d", i, index)
wait(w.t, check, 25*time.Millisecond, timeout, message)
}

// WaitSnapshot blocks until the FSM with the given index has reached at least the
// given snapshot number.
//
// If the timeout expires, test will fail.
func (w *FSMWatcherAPI) WaitSnapshot(i int, n uint64, timeout time.Duration) {
helper, ok := w.t.(testingHelper)
if ok {
helper.Helper()
}

wrapper := w.wrappers[i]
check := func() bool {
wrapper.mu.Lock()
defer wrapper.mu.Unlock()
return wrapper.snapshots >= n
}

message := fmt.Sprintf("fsm %d did not reach snapshot %d", i, n)
wait(w.t, check, 25*time.Millisecond, timeout, message)
}

// WaitRestore blocks until the FSM with the given index has reached at least the
// given number of snapshot restores.
//
// If the timeout expires, test will fail.
func (w *FSMWatcherAPI) WaitRestore(i int, n uint64, timeout time.Duration) {
helper, ok := w.t.(testingHelper)
if ok {
helper.Helper()
}

wrapper := w.wrappers[i]
check := func() bool {
wrapper.mu.Lock()
defer wrapper.mu.Unlock()
return wrapper.restores >= n
}

message := fmt.Sprintf("fsm %d did not reach restore %d", i, n)
wait(w.t, check, 25*time.Millisecond, timeout, message)
}

// fsm is a dummy raft finite state machine that does nothing and
// always no-ops.
type fsm struct{}
Expand All @@ -58,3 +146,42 @@ func (s *fsmSnapshot) Persist(sink raft.SnapshotSink) error { return nil }

// Release is a no-op.
func (s *fsmSnapshot) Release() {}

// Wraps a raft.FSM, tracking logs, snapshots and restores.
type fsmWrapper struct {
t testing.TB
fsm raft.FSM
index uint64
snapshots uint64
restores uint64
mu sync.Mutex
}

// Apply always return a nil error without doing anything.
func (f *fsmWrapper) Apply(log *raft.Log) interface{} {
f.mu.Lock()
f.index = log.Index
f.mu.Unlock()

return f.fsm.Apply(log)
}

// Snapshot always return a dummy snapshot and no error without doing
// anything.
func (f *fsmWrapper) Snapshot() (raft.FSMSnapshot, error) {
f.mu.Lock()
f.snapshots++
f.mu.Unlock()

return f.fsm.Snapshot()
}

// Restore always return a nil error without reading anything from
// the reader.
func (f *fsmWrapper) Restore(reader io.ReadCloser) error {
f.mu.Lock()
f.restores++
f.mu.Unlock()

return f.fsm.Restore(reader)
}
81 changes: 81 additions & 0 deletions fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
package rafttest_test

import (
"sync"
"testing"
"time"

"github.com/CanonicalLtd/raft-test"
"github.com/hashicorp/raft"
"github.com/stretchr/testify/assert"
)

func TestFSM_Restore(t *testing.T) {
Expand All @@ -26,3 +30,80 @@ func TestFSM_Restore(t *testing.T) {
t.Fatal(err)
}
}

func TestFSMWatcher_WaitIndex(t *testing.T) {
fsms := rafttest.FSMs(2)
watcher := rafttest.FSMWatcher(t, fsms)

go func() {
fsms[0].Apply(&raft.Log{Index: 1})
fsms[0].Apply(&raft.Log{Index: 2})
}()

go func() {
fsms[1].Apply(&raft.Log{Index: 1})
fsms[1].Apply(&raft.Log{Index: 2})
fsms[1].Apply(&raft.Log{Index: 3})
}()

watcher.WaitIndex(0, 2, time.Second)
watcher.WaitIndex(1, 3, time.Second)
}

func TestFSMWatcher_WaitIndexTimeout(t *testing.T) {
fsms := rafttest.FSMs(2)

testingT := &testing.T{}
watcher := rafttest.FSMWatcher(testingT, fsms)

succeeded := false

fsms[0].Apply(&raft.Log{Index: 1})
fsms[0].Apply(&raft.Log{Index: 2})

wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.WaitIndex(0, 3, time.Microsecond)
succeeded = true
}()

wg.Wait()

assert.False(t, succeeded)
}

func TestFSMWatcher_WaitSnapshot(t *testing.T) {
fsms := rafttest.FSMs(2)
watcher := rafttest.FSMWatcher(t, fsms)

go func() {
fsms[0].Snapshot()
}()

go func() {
fsms[1].Snapshot()
fsms[1].Snapshot()
}()

watcher.WaitSnapshot(0, 1, time.Second)
watcher.WaitSnapshot(1, 2, time.Second)
}

func TestFSMWatcher_WaitRestore(t *testing.T) {
fsms := rafttest.FSMs(2)
watcher := rafttest.FSMWatcher(t, fsms)

go func() {
fsms[0].Restore(nil)
}()

go func() {
fsms[1].Restore(nil)
fsms[1].Restore(nil)
}()

watcher.WaitRestore(0, 1, time.Second)
watcher.WaitRestore(1, 2, time.Second)
}

0 comments on commit 72b4398

Please sign in to comment.