diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 31ef450f032f7..0ecb1242c3c3d 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -20,6 +20,8 @@ func main() { "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") + checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") + checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() @@ -31,18 +33,20 @@ func main() { log.SetLevel(level) var idx int + var cp pserver.Checkpoint + var e *pserver.EtcdClient if *index >= 0 { idx = *index } else { timeout := time.Second * time.Duration((*etcdTimeout)) - e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) idx, err = e.Register() if err != nil { panic(err) } } - s, err := pserver.NewService(idx) + s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 37b8d522c1bd0..1f77787150d16 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -18,6 +18,8 @@ const ( PsDesired = "/ps_desired" // PsAddr is the base dir for pserver to store their addr PsPath = "/ps/" + // PsCheckpoint is the etcd path for store checkpoints information + PsCheckpoint = "/checkpoints/" ) // EtcdClient is the etcd client that the pserver uses for fault @@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { return idx, nil } + +// PutKey put into etcd with value by key specified +func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + _, err := e.etcdClient.Put(ctx, key, string(value)) + cancel() + if err != nil { + return err + } + return nil +} diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 54d108209402c..2d7882d1a75ef 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return (*[1 << 30]byte)(p)[:len:len] } -func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { +func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { o := &optimizer{} o.elementType = paramWithConfigs.Param.ElementType p := paramWithConfigs.Param c := paramWithConfigs.Config + s := State log.WithFields(log.Fields{ "ElementType": p.ElementType, "ParamSize": len(p.Content), "ConfigSize": len(c), + "StateSize": len(s), }).Info("New Optimizer Created with config:") var cbuffer unsafe.Pointer cbuffer = C.malloc(C.size_t(len(p.Content))) C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) + var cstate unsafe.Pointer + if len(s) != 0 { + cstate = unsafe.Pointer(&s[0]) + } + o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), - C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), - (*C.char)(nullPtr), 0) + C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s))) return o } @@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte { return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float) } +func (o *optimizer) GetStates() []byte { + var cbuffer *C.char + cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer) + return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len)) +} + func (o *optimizer) UpdateParameter(g Gradient) error { if o.elementType != g.ElementType { return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 0b2f4cfa41a63..d19e9de92e0b3 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) { Param: p, Config: config, } - o := newOptimizer(param) + o := newOptimizer(param, nil) o.Cleanup() } diff --git a/go/pserver/service.go b/go/pserver/service.go index ad16a5708d10b..6b52d0d896f8b 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,9 +1,21 @@ package pserver import ( + "bufio" + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" "errors" "fmt" + "os" + "path/filepath" + "strconv" "sync" + "time" + + log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -39,26 +51,55 @@ type ParameterWithConfig struct { Config []byte // parameter configuration in Proto Buffer format } +// ParameterCheckpoint is Parameter and State checkpoint +type ParameterCheckpoint struct { + ParamConfig ParameterWithConfig + State []byte +} + +// checkpoint signature +type checkpointMeta struct { + UUID string `json:"uuid"` + Md5sum string `json:"md5sum"` + Timestamp string `json:"timestamp"` +} + +// Checkpoint is the pserver shard persist in file +type Checkpoint []ParameterCheckpoint + // Gradient is the gradient of the parameter. type Gradient Parameter // Service is the RPC service for pserver. type Service struct { - initialized chan struct{} - idx int - - mu sync.Mutex - optMap map[string]*optimizer + initialized chan struct{} + idx int + checkpointInterval time.Duration + checkpointPath string + client *EtcdClient + mu sync.Mutex + optMap map[string]*optimizer } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(idx int) (*Service, error) { +func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { s := &Service{ - idx: idx, + idx: idx, + checkpointInterval: time.Second * time.Duration(seconds), + checkpointPath: path, + client: client, } s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) + + if cp != nil { + for _, item := range cp { + p := item.ParamConfig + st := item.State + s.optMap[p.Param.Name] = newOptimizer(p, st) + } + } return s, nil } @@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er // TODO(helin): check if paramWithConfigs.Param.Content is // properly memory aligned, if not, make copy to a memory // aligned region. - s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) + s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) return nil } @@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { return nil } -// Save tells the parameter server to save parameters. -func (s *Service) Save(path string, dummy *int) error { +// pserver save checkpoint +func (s *Service) doCheckpoint() error { <-s.initialized + s.mu.Lock() + defer s.mu.Unlock() + + cp := make([]ParameterCheckpoint, 0, len(s.optMap)) + index := 0 + for name, opt := range s.optMap { + var pc ParameterCheckpoint + pc.ParamConfig.Param.Name = name + pc.ParamConfig.Param.ElementType = opt.elementType + pc.ParamConfig.Param.Content = opt.GetWeights() + pc.State = opt.GetStates() + cp[index] = pc + index++ + } + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(cp) + if err != nil { + return err + } + + cpMeta := checkpointMeta{} + cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) + cpMeta.Timestamp = time.Now().String() + h := md5.New() + cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) - // TODO + cpMetajson, _ := json.Marshal(cpMeta) + err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + if err != nil { + return err + } + if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { + log.Info("checkpoint does not exists.") + } else { + err = os.Remove(cpMeta.UUID) + log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + } + f, err := os.Create(cpMeta.UUID) + defer f.Close() + if err != nil { + return err + } + writer := bufio.NewWriter(f) + _, err = writer.Write(buf.Bytes()) + writer.Flush() + if err != nil { + return err + } return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index b6d20d2c8b7ba..9bf1a48a596f3 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -15,7 +15,8 @@ const ( ) func TestServiceFull(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - wg.Add(1) - go func() { - err := s.Save("", nil) - if err != nil { - errCh <- err - } - wg.Done() - ch <- struct{}{} - }() - time.Sleep(50 * time.Millisecond) select { @@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) { wg.Wait() } + +func TestCheckpointSpeed(t *testing.T) { + //TODO(zhihong): test speed +}