diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 25cd1cafcdf32..54fa254863156 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -1,45 +1,69 @@ package main import ( + "fmt" "net" "net/http" "net/rpc" "strconv" + "strings" "time" "github.com/namsral/flag" + log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" ) func main() { port := flag.Int("port", 8080, "port of the master server.") - - faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") + ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") + endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() - if *faultTolerance { - panic("fault tolernance not implemented.") + if *endpoints == "" { + log.Warningln("-endpoints not set, fault tolerance not be enabled.") + } + + var store master.Store + if *endpoints != "" { + eps := strings.Split(*endpoints, ",") + ip, err := networkhelper.GetExternalIP() + if err != nil { + log.Fatal(err) + } + addr := fmt.Sprintf("%s:%d", ip, *port) + store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) + if err != nil { + log.Fatal(err) + } + } else { + store = &master.InMemStore{} + } + + s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) + if err != nil { + log.Fatal(err) } - s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) - err := rpc.Register(s) + err = rpc.Register(s) if err != nil { - panic(err) + log.Fatal(err) } rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) if err != nil { - panic(err) + log.Fatal(err) } err = http.Serve(l, nil) if err != nil { - panic(err) + log.Fatal(err) } } diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 00fcca0e2cf44..251225780ae30 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -47,9 +47,13 @@ func TestGetFinishTask(t *testing.T) { } go func(l net.Listener) { - s := NewService(chunkPerTask, time.Second, 1) + s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/client_test.go b/go/master/client_test.go index 2b3f873ecf3a6..85a86761c2e58 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -33,9 +33,13 @@ func TestNextRecord(t *testing.T) { } go func(l net.Listener) { - s := master.NewService(10, time.Second, 1) + s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go new file mode 100644 index 0000000000000..b7293a759896f --- /dev/null +++ b/go/master/etcd_client.go @@ -0,0 +1,144 @@ +package master + +import ( + "context" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +const ( + // DefaultLockPath is the default etcd master lock path. + DefaultLockPath = "/master/lock" + // DefaultStatePath is the default etcd key for master state. + DefaultStatePath = "/master/state" + // DefaultAddrPath is the default etcd key for master address. + DefaultAddrPath = "/master/addr" +) + +// EtcdClient is the etcd client that master uses for fault tolerance +// and service registry. +type EtcdClient struct { + lockPath string + statePath string + client *clientv3.Client + lock *concurrency.Mutex +} + +// NewEtcdClient creates a new EtcdClient. +func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { + log.Debugf("Connecting to etcd at %v", endpoints) + // TODO(helin): gracefully shutdown etcd store. Becuase etcd + // store holds a etcd lock, even though the lock will expire + // when the lease timeout, we need to implement graceful + // shutdown to release the lock. + cli, err := clientv3.New(clientv3.Config{ + Endpoints: endpoints, + DialTimeout: dialTimeout, + }) + if err != nil { + return nil, err + } + + sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec)) + if err != nil { + return nil, err + } + + lock := concurrency.NewMutex(sess, lockPath) + // It's fine for the lock to get stuck, in this case we have + // multiple master servers running (only configured to have + // one master running, but split-brain problem may cuase + // multiple master servers running), and the cluster management + // software will kill one of them. + log.Debugf("Trying to acquire lock at %s.", lockPath) + err = lock.Lock(context.TODO()) + if err != nil { + return nil, err + } + log.Debugf("Successfully acquired lock at %s.", lockPath) + + put := clientv3.OpPut(addrPath, string(addr)) + resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Fatal("No longer owns the master lock. Exiting.") + } + + e := &EtcdClient{ + lockPath: lockPath, + statePath: statePath, + client: cli, + lock: lock, + } + + return e, nil +} + +// Save saves the state into the etcd. +func (e *EtcdClient) Save(state []byte) error { + ctx := context.TODO() + put := clientv3.OpPut(e.statePath, string(state)) + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() + if err != nil { + return err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock again") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := e.lock.Lock(ctx) + cancel() + if err != nil { + // We lost the master lock and can not acquire + // it back, it means some other master is + // already started. We don't want cluster + // managment system to kill the master server + // who is holding the lock and running + // correctly. So the most feasible solution is + // to kill current master server. The current + // state is not saved, but the trainer's RPC + // call will fail, so the trainer will retry. + log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) + } + log.Infof("Successfully acquired lock at %s.", e.lockPath) + return e.Save(state) + } + + return nil +} + +// Load loads the state from etcd. +func (e *EtcdClient) Load() ([]byte, error) { + ctx := context.TODO() + get := clientv3.OpGet(e.statePath) + + resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Errorln("No longer owns the lock, trying to lock and load again.") + err = e.lock.Lock(context.Background()) + if err != nil { + return nil, err + } + + return e.Load() + } + + kvs := resp.Responses[0].GetResponseRange().Kvs + if len(kvs) == 0 { + // No state exists + return nil, nil + } + + state := kvs[0].Value + return state, nil +} diff --git a/go/master/inmem_store.go b/go/master/inmem_store.go new file mode 100644 index 0000000000000..bcd549b20e463 --- /dev/null +++ b/go/master/inmem_store.go @@ -0,0 +1,28 @@ +package master + +import "sync" + +// InMemStore is an in memory implementation of Store interface. +// +// It does not tolerate the fault that casues the program to crash. +type InMemStore struct { + mu sync.Mutex + buf []byte +} + +// Save saves the state into the in-memory store. +func (m *InMemStore) Save(state []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.buf = state + return nil +} + +// Load loads the state from the in-memory store. +func (m *InMemStore) Load() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.buf, nil +} diff --git a/go/master/service.go b/go/master/service.go index 55e1e2d1a4a5c..58e68e7448599 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -1,6 +1,9 @@ package master import ( + "bytes" + "compress/gzip" + "encoding/gob" "errors" "os" "path/filepath" @@ -12,24 +15,54 @@ import ( "github.com/PaddlePaddle/recordio" ) +const ( + dialTimeout = 5 * time.Second +) + +// Store is the interface for save and load the master state. +type Store interface { + Save([]byte) error + Load() ([]byte, error) +} + +// Chunk is a chunk of data consisted of several data instances. +type Chunk struct { + Path string + Index recordio.Index // chunk index +} + +// Task is the basic unit of data instances assigned to trainers. +type Task struct { + ID int + Chunks []Chunk +} + +type taskEntry struct { + Epoch int + NumTimeout int + Task Task +} + +type taskQueues struct { + Todo []taskEntry + Pending map[int]taskEntry // map from task ID to task entry + Done []taskEntry + Failed []Task +} + // Service is the master server service. type Service struct { chunksPerTask int timeoutDur time.Duration timeoutMax int ready chan struct{} + store Store mu sync.Mutex initDone bool taskQueues taskQueues } -// Recover recovers service state from etcd. -func Recover() (*Service, error) { - // TODO(helin): recover from snapshot state from etcd. - return nil, nil -} - func partition(chunks []Chunk, chunksPerTask int) []taskEntry { id := 0 if chunksPerTask <= 0 { @@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { s := &Service{} s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur @@ -66,38 +99,82 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) - return s -} + s.store = store + recovered, err := s.recover() + if err != nil { + return nil, err + } -// Chunk is a chunk of data consisted of several data instances. -type Chunk struct { - Path string - Index recordio.Index // chunk index -} + if recovered { + // Recovered. Now the state is already initialized, + // and the master is ready. + s.initDone = true + close(s.ready) + log.Info("Master recovered from saved state.") + } -// Task is the basic unit of data instances assigned to trainers. -type Task struct { - ID int - Chunks []Chunk + return s, nil } -type taskEntry struct { - Epoch int - NumTimeout int - Task Task -} +// recover recovers service state from etcd. +func (s *Service) recover() (bool, error) { + state, err := s.store.Load() + if err != nil { + return false, err + } -type taskQueues struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []Task + if state == nil { + log.Infoln("No state exists, not recovered.") + return false, nil + } + + log.Infof("Loaded snapshot of size: %d bytes.", len(state)) + gr, err := gzip.NewReader(bytes.NewReader(state)) + if err != nil { + return false, err + } + + dec := gob.NewDecoder(gr) + var tqs taskQueues + err = dec.Decode(&tqs) + if err != nil { + return false, err + } + + err = gr.Close() + if err != nil { + // Only close failed, recover actually succeed, so + // just log error. + log.Errorln(err) + } + + s.taskQueues = tqs + return true, nil } -// *must* be called with s.mu being held. +// snapshot *must* be called with s.mu being held. func (s *Service) snapshot() error { - // TODO(helin): snapshot state on etcd. - return nil + // TOOD(helin): etcd request has a size limit, so the snapshot + // size is limited by the max request size. We should either + // divide the snapshot into smaller chunks and save under + // different keys, or configure the request size to be big + // enough: + // https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44 + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + enc := gob.NewEncoder(gw) + err := enc.Encode(s.taskQueues) + if err != nil { + return err + } + err = gw.Close() + if err != nil { + return err + } + + state := buf.Bytes() + log.Infof("Saving snapshot of size: %d bytes.", len(state)) + return s.store.Save(state) } func readChunks(globPaths []string) ([]Chunk, error) { @@ -207,12 +284,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { t.NumTimeout++ if t.NumTimeout > s.timeoutMax { - log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) return } - log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) s.taskQueues.Todo = append(s.taskQueues.Todo, t) } } diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 92a41b7f54348..bbaf43d9f1434 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, if err != nil { if err.Error() == pserver.AlreadyInitialized { - log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) + log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) return C.PSERVER_OK } log.Errorln(err) @@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } @@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } }