forked from rubrikinc/kronos
/
state_machine_raft.go
209 lines (187 loc) · 6.56 KB
/
state_machine_raft.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
package oracle
import (
"context"
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/scaledata/etcd/snap"
"github.com/rubrikinc/kronos/kronosutil"
"github.com/rubrikinc/kronos/kronosutil/log"
"github.com/BMWE30/kronos/metadata"
"github.com/rubrikinc/kronos/pb"
)
// RaftConfig holds the parameters used to initialize a raft node and the
// state machine built on top of it.
type RaftConfig struct {
// RaftHostPort is the host:port of raft HTTP server
RaftHostPort *kronospb.NodeAddr
// GRPCHostPort is the host:port of the gRPC server.
// NOTE(review): the previous comment described this as the raft HTTP server,
// duplicating the RaftHostPort description - confirm against the caller.
GRPCHostPort *kronospb.NodeAddr
// SeedHosts is the list of kronos seed hosts in the cluster, one entry per
// host (presumably in host:port form - TODO confirm).
SeedHosts []string
// CertsDir is the directory having node and CA certificates / pems
CertsDir string
// DataDir is the directory where a Kronos server will store its snapshots
// and raft WAL. NewRaftStateMachine also passes it to the metadata package
// when fetching or persisting the node ID.
DataDir string
// SnapCount is the number of raft entries after which a raft snapshot is
// triggered.
SnapCount uint64
}
// RaftStateMachine is a distributed state machine managed by raft.
// Each node maintains an in memory state machine where the actions fed in
// are the actions (messages) committed through raft
type RaftStateMachine struct {
// proposeC is the send-only channel on which SubmitProposal publishes
// marshaled proposals; it is closed by closer.
proposeC chan<- string
// snapshotter is used to restore from a raft snapshot (see
// recoverFromSnapshot).
snapshotter *snap.Snapshotter
// stateMachine is the in memory state machine to which committed proposals
// are applied by readCommits.
stateMachine *inMemStateMachine
// getSnapshotC gates GetSnapshot: GetSnapshot blocks receiving from it and
// readCommits sends on it when it sees unblockSnapshotMsg.
getSnapshotC chan struct{}
// closer is the cleanup function invoked by Close; NewRaftStateMachine sets
// it to close the proposal channel.
closer func()
// raftNode is the underlying raft node, used for membership changes
// (AddNode, AddNewSeedHosts) and for reporting the node ID (GetID).
raftNode *raftNode
}
// convertToNodeAddrs turns every host:port string in addrs into a
// kronospb.NodeAddr, keeping the input order. Conversion stops at the first
// entry kronosutil.NodeAddr rejects; that error is returned together with a
// nil slice.
func convertToNodeAddrs(addrs []string) ([]*kronospb.NodeAddr, error) {
	var parsed []*kronospb.NodeAddr
	for i := range addrs {
		nodeAddr, err := kronosutil.NodeAddr(addrs[i])
		if err != nil {
			return nil, err
		}
		parsed = append(parsed, nodeAddr)
	}
	return parsed, nil
}
// Compile-time check that *RaftStateMachine implements StateMachine.
var _ StateMachine = (*RaftStateMachine)(nil)
// NewRaftStateMachine returns an instance of a distributed oracle state
// machine managed by raft.
//
// When nodeID is empty, the ID is fetched from (or assigned in) rc.DataDir via
// the metadata package; otherwise the supplied nodeID is persisted there.
// The function replays the commits already available from raft synchronously
// before returning, and then keeps applying further commits in a background
// goroutine.
func NewRaftStateMachine(ctx context.Context, rc *RaftConfig, nodeID string) StateMachine {
	// rsm is declared before the raft node is created so that the snapshot
	// callback handed to the raft node can refer to it; it is assigned below.
	var rsm *RaftStateMachine
	proposals := make(chan string)
	snapshotFn := func() ([]byte, error) { return rsm.GetSnapshot(ctx) }

	if nodeID != "" {
		metadata.PersistNewNodeID(ctx, nodeID, rc.DataDir)
	} else {
		nodeID = metadata.FetchOrAssignNodeID(ctx, rc.DataDir).String()
	}

	commitC, errorC, snapshotterReady, rn := newRaftNode(rc, snapshotFn, proposals, nodeID)
	rsm = &RaftStateMachine{
		proposeC: proposals,
		// Blocks until the raft node publishes its snapshotter.
		snapshotter:  <-snapshotterReady,
		stateMachine: NewMemStateMachine().(*inMemStateMachine),
		getSnapshotC: make(chan struct{}),
		raftNode:     rn,
		closer:       func() { close(proposals) },
	}

	// Replay existing commits synchronously so that the state machine is at
	// the last known state before initializing.
	rsm.readCommits(ctx, commitC, errorC)
	// From here on, read commits from raft into the state machine in the
	// background.
	go rsm.readCommits(context.Background(), commitC, errorC)
	return rsm
}
// AddNewSeedHosts converts newSeedHosts (host:port strings) to node addresses
// and hands them to the raft node.
//
// If any entry fails to parse, convertToNodeAddrs returns no addresses at
// all, so the error is logged and the raft node is left untouched instead of
// being called with a nil slice. This mirrors the error handling in AddNode.
func (s *RaftStateMachine) AddNewSeedHosts(newSeedHosts []string) {
	ctx := context.Background()
	addrs, err := convertToNodeAddrs(newSeedHosts)
	if err != nil {
		log.Error(ctx, err)
		return
	}
	s.raftNode.AddNewSeedHosts(addrs, ctx)
}
// AddNode parses addr, registers the node with the raft node under id and
// then persists the cluster. A malformed addr is logged as an error and the
// call returns without changing anything; a failure to persist the cluster is
// reported through log.Fatalf.
func (s *RaftStateMachine) AddNode(id string, addr string) {
	ctx := context.Background()

	nodeAddr, parseErr := kronosutil.NodeAddr(addr)
	if parseErr != nil {
		log.Error(ctx, parseErr)
		return
	}

	s.raftNode.addNode(id, nodeAddr)
	persistErr := s.raftNode.cluster.Persist()
	if persistErr != nil {
		log.Fatalf(ctx, "Failed to persist cluster, error: %v", persistErr)
	}
	//s.raftNode.updateClusterFromConfState(context.Background())
}
// GetID returns the node ID held by the underlying raft node.
func (s *RaftStateMachine) GetID() string {
	id := s.raftNode.nodeID
	return id
}
// Close cleans up RaftStateMachine by running its closer function.
func (s *RaftStateMachine) Close() {
	cleanup := s.closer
	cleanup()
}
// State returns a snapshot of the current state, as reported by the in memory
// state machine.
func (s *RaftStateMachine) State(ctx context.Context) *kronospb.OracleState {
	current := s.stateMachine.State(ctx)
	return current
}
// SubmitProposal submits a new proposal to the StateMachine. The state machine
// accepts the proposal if PrevID matches the ID of the StateMachine.
// Nothing is returned because the proposal is applied asynchronously: this
// method only marshals the proposal and sends it on the proposal channel.
// A marshaling failure is reported through log.Fatalf.
func (s *RaftStateMachine) SubmitProposal(ctx context.Context, proposal *kronospb.OracleProposal) {
	raw, err := protoutil.Marshal(proposal)
	if err != nil {
		log.Fatalf(ctx, "Failed to marshal proposal: %v, err: %v", proposal, err)
	}
	payload := string(raw)
	s.proposeC <- payload
}
// readCommits consumes messages committed through raft from commitC and
// applies them to the in memory state machine.
//
// Control messages are handled specially: replayedWALMsg ends the call,
// loadSnapshotMsg restores state from the latest snapshot and
// unblockSnapshotMsg releases a pending GetSnapshot. Every other message is
// decoded as an OracleProposal and submitted to the state machine.
// Once commitC is closed, the method waits on errorC and reports any error it
// receives through log.Fatalf.
func (s *RaftStateMachine) readCommits(
	ctx context.Context, commitC <-chan string, errorC <-chan error,
) {
	for msg := range commitC {
		switch msg {
		case replayedWALMsg:
			// WAL is replayed synchronously when a server restarts
			log.Info(ctx, "Done replaying WAL entries on state machine.")
			return
		case loadSnapshotMsg:
			// Signaled to load a snapshot.
			recoverErr := s.recoverFromSnapshot(ctx)
			if recoverErr != nil {
				log.Fatalf(ctx, "Failed to recover from snapshot, err: %v", recoverErr)
			}
		case unblockSnapshotMsg:
			// Signaled to unblock a snapshot.
			s.getSnapshotC <- struct{}{}
		default:
			decoded := new(kronospb.OracleProposal)
			unmarshalErr := protoutil.Unmarshal([]byte(msg), decoded)
			if unmarshalErr != nil {
				log.Fatalf(ctx, "Failed to unmarshal message %+.100q, err: %v", msg, unmarshalErr)
			}
			s.stateMachine.SubmitProposal(ctx, decoded)
		}
	}

	// commitC is closed: check whether raft reported an error.
	raftErr, received := <-errorC
	if !received {
		return
	}
	log.Fatalf(ctx, "Received error from raft, err: %v", raftErr)
}
// GetSnapshot returns the marshaled state of the in memory state machine.
// It blocks until a signal arrives on getSnapshotC.
func (s *RaftStateMachine) GetSnapshot(ctx context.Context) ([]byte, error) {
	// Wait for the unblock signal before reading the state.
	<-s.getSnapshotC
	current := s.State(ctx)
	return protoutil.Marshal(current)
}
// recoverFromSnapshot loads the snapshot from the snapshotter, decodes its
// data as an OracleState and restores the in memory state machine from it.
// It returns the error from loading or decoding, if any.
func (s *RaftStateMachine) recoverFromSnapshot(ctx context.Context) error {
	// There must always be a snapshot to load whenever we are signaled to
	// consume it, so a load failure is passed straight back to the caller.
	loaded, loadErr := s.snapshotter.Load()
	if loadErr != nil {
		return loadErr
	}

	log.Infof(
		ctx,
		"Applying snapshot to state machine, index: %d, term: %d, bytes: %d",
		loaded.Metadata.Index,
		loaded.Metadata.Term,
		len(loaded.Data),
	)

	restored := new(kronospb.OracleState)
	decodeErr := protoutil.Unmarshal(loaded.Data, restored)
	if decodeErr != nil {
		return decodeErr
	}
	s.stateMachine.restoreState(*restored)
	return nil
}