From 5ef1425adb75eb1b0212518e0f12fefd8d9a8970 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 3 Jul 2017 21:13:20 +0800 Subject: [PATCH 1/8] "init saving model" --- go/pserver/optimizer.go | 16 ++++++++++++++-- go/pserver/service.go | 12 +++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index b4a040f46bff5..427251f900836 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -40,17 +40,23 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { o.elementType = paramWithConfigs.Param.ElementType p := paramWithConfigs.Param c := paramWithConfigs.Config + s := paramWithConfigs.State log.WithFields(log.Fields{ "ElementType": p.ElementType, "ParamSize": len(p.Content), "ConfigSize": len(c), + "StateSize": len(s), }).Info("New Optimizer Created with config:") var cbuffer unsafe.Pointer cbuffer = C.malloc(C.size_t(len(p.Content))) C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) + var cstate unsafe.Pointer + if len(s) != 0 { + cstate = unsafe.Pointer(&s[0]) + } + o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), - C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), - (*C.char)(nullPtr), 0) + C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s))) return o } @@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte { return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float) } +func (o *optimizer) GetStates() []byte { + var cbuffer *C.char + cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer) + return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len)) +} + func (o *optimizer) UpdateParameter(g Gradient) error { if o.elementType != g.ElementType { return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType) diff --git a/go/pserver/service.go b/go/pserver/service.go index e15a4e5a58a3b..a5ff86290333e 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -38,6 +38,7 @@ type Parameter struct { type ParameterWithConfig struct { Param Parameter Config []byte // parameter configuration in Proto Buffer format + State []byte // parameter training state } // Gradient is the gradient of the parameter. @@ -58,7 +59,7 @@ func NewService(idx int) (*Service, error) { s := &Service{ idx: idx, } - s.optMap = make(map[string]*optimizer) + s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) return s, nil } @@ -143,7 +144,12 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { // Save tells the parameter server to save parameters. func (s *Service) Save(path string, dummy *int) error { <-s.initialized - - // TODO + for opt, ok := range s.optMap { + if ok != nil { + return fmt.Errorf("parameter optimizerMap error: ", ok) + } + state := opt.GetStates() + weights := opt.GetWeights() + } return nil } From f1330e216a1b8130bb578b69ff2d6a67357cdd1b Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 3 Jul 2017 23:20:39 +0800 Subject: [PATCH 2/8] "saving checkpoint" --- go/pserver/service.go | 79 +++++++++++++++++++++++++++++++++++--- go/pserver/service_test.go | 6 +++ 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/go/pserver/service.go b/go/pserver/service.go index a5ff86290333e..a4cf3e4750744 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,9 +1,19 @@ package pserver import ( + "bufio" + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" "errors" "fmt" + "os" + "strconv" "sync" + "time" + + log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -14,6 +24,10 @@ const ( Uninitialized = "pserver not fully initialized" ) +const ( + checkpoint_path = "/checkpoints/" +) + // Supported element types const ( Int32 ElementType = iota @@ -53,6 +67,24 @@ type Service struct { optMap map[string]*optimizer } +type Checkpoint struct { + uuid string + md5sum string + timestamp string +} + +//serialize ParameterWithConfig to byte stream +func GetBytes(content ...interface{}) ([]byte, error) { + + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(content) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + // NewService creates a new service, will bypass etcd registration if no // endpoints specified. func NewService(idx int) (*Service, error) { @@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { // Save tells the parameter server to save parameters. func (s *Service) Save(path string, dummy *int) error { + //FIXME: checkpoint is only used by pserver + // and has a constant path of */checkpoints/{pserver_idx}* <-s.initialized - for opt, ok := range s.optMap { - if ok != nil { - return fmt.Errorf("parameter optimizerMap error: ", ok) + s.mu.Lock() + defer s.mu.Unlock() + var paramWithConfig ParameterWithConfig + for name, opt := range s.optMap { + paramWithConfig.Param.Name = name + paramWithConfig.Param.ElementType = opt.elementType + paramWithConfig.Param.Content = opt.GetWeights() + paramWithConfig.State = opt.GetStates() + content, err := GetBytes(paramWithConfig) + if err != nil { + log.Errorln(err) + } + ck := Checkpoint{} + h := md5.New() + ck.md5sum = hex.EncodeToString(h.Sum(content)) + ck.timestamp = time.Now().String() + ck.uuid = checkpoint_path + strconv.Itoa(s.idx) + ckbytes, err := GetBytes(ck) + if err != nil { + log.Errorln(err) + } + // TODO: according design doc, need to save uuid to etcd in json format + // {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx} + log.Infof("parameter checkpoint %s", ckbytes) + + if _, err = os.Stat(ck.uuid); os.IsNotExist(err) { + log.Info("checkpoint not exists.") + } else { + err = os.Remove(ck.uuid) + log.Infof("remove %s", ck.uuid) + } + f, err := os.Create(ck.uuid) + defer f.Close() + if err != nil { + log.Errorln(err) + } + writer := bufio.NewWriter(f) + _, err = writer.Write(content) + if err != nil { + log.Errorln(err) } - state := opt.GetStates() - weights := opt.GetWeights() } return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index f86619447c28b..28956e4d851a6 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) { if !reflect.DeepEqual(param1, p) { t.FailNow() } + var dummy int + s.Save("", &dummy) } func TestMultipleInit(t *testing.T) { @@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) { wg.Wait() } + +func TestCheckpointSpeed(t *testing.T) { + //TODO: test speed +} From 65afbe11853c2e32ca4196965e309e33ab843fd1 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 3 Jul 2017 23:38:21 +0800 Subject: [PATCH 3/8] "fix gob register error" --- go/pserver/service.go | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/go/pserver/service.go b/go/pserver/service.go index a4cf3e4750744..decd3682aec94 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -25,7 +25,7 @@ const ( ) const ( - checkpoint_path = "/checkpoints/" + checkpoint_path = "./checkpoints/" ) // Supported element types @@ -67,10 +67,10 @@ type Service struct { optMap map[string]*optimizer } -type Checkpoint struct { - uuid string - md5sum string - timestamp string +type checkpoint struct { + Uuid string + Md5sum string + Timestamp string } //serialize ParameterWithConfig to byte stream @@ -93,6 +93,8 @@ func NewService(idx int) (*Service, error) { } s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) + gob.Register(ParameterWithConfig{}) + gob.Register(checkpoint{}) return s, nil } @@ -190,32 +192,33 @@ func (s *Service) Save(path string, dummy *int) error { if err != nil { log.Errorln(err) } - ck := Checkpoint{} + ck := checkpoint{} h := md5.New() - ck.md5sum = hex.EncodeToString(h.Sum(content)) - ck.timestamp = time.Now().String() - ck.uuid = checkpoint_path + strconv.Itoa(s.idx) + ck.Md5sum = hex.EncodeToString(h.Sum(content)) + ck.Timestamp = time.Now().String() + ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) ckbytes, err := GetBytes(ck) if err != nil { log.Errorln(err) } - // TODO: according design doc, need to save uuid to etcd in json format - // {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx} + // TODO: according design doc, need to save Uuid to etcd in json format + // {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx} log.Infof("parameter checkpoint %s", ckbytes) - if _, err = os.Stat(ck.uuid); os.IsNotExist(err) { + if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { log.Info("checkpoint not exists.") } else { - err = os.Remove(ck.uuid) - log.Infof("remove %s", ck.uuid) + err = os.Remove(ck.Uuid) + log.Infof("remove %s", ck.Uuid) } - f, err := os.Create(ck.uuid) + f, err := os.Create(ck.Uuid) defer f.Close() if err != nil { log.Errorln(err) } writer := bufio.NewWriter(f) _, err = writer.Write(content) + writer.Flush() if err != nil { log.Errorln(err) } From 40295b9ed9ede878c930c6fc9ce6719c8270db07 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 7 Jul 2017 19:56:29 +0800 Subject: [PATCH 4/8] "fix pserver saving etcd" --- go/cmd/pserver/pserver.go | 5 +- go/pserver/etcd_client.go | 13 +++ go/pserver/optimizer.go | 4 +- go/pserver/service.go | 170 +++++++++++++++++++++---------------- go/pserver/service_test.go | 5 +- 5 files changed, 116 insertions(+), 81 deletions(-) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 31ef450f032f7..56c1f6e1db609 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -20,6 +20,8 @@ func main() { "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") + checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") + checkpointInterval := flag.Int("checkpoint-interval", "10", "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() @@ -31,6 +33,7 @@ func main() { log.SetLevel(level) var idx int + var cp pserver.Checkpoint if *index >= 0 { idx = *index } else { @@ -42,7 +45,7 @@ func main() { } } - s, err := pserver.NewService(idx) + s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 37b8d522c1bd0..20041d04d089a 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -18,6 +18,8 @@ const ( PsDesired = "/ps_desired" // PsAddr is the base dir for pserver to store their addr PsPath = "/ps/" + // PsCheckpoint is the etcd path for store checkpoints information + PsCheckpoint = "/checkpoints/" ) // EtcdClient is the etcd client that the pserver uses for fault @@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { return idx, nil } + +// PutKey put into etcd with value by key specified +func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { + ctx, err := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + _, err = e.Put(ctx, key, value) + cancel() + if err != nil { + return err + } + return nil +} diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 1c84e728e0b24..2d7882d1a75ef 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return (*[1 << 30]byte)(p)[:len:len] } -func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { +func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { o := &optimizer{} o.elementType = paramWithConfigs.Param.ElementType p := paramWithConfigs.Param c := paramWithConfigs.Config - s := paramWithConfigs.State + s := State log.WithFields(log.Fields{ "ElementType": p.ElementType, "ParamSize": len(p.Content), diff --git a/go/pserver/service.go b/go/pserver/service.go index d1d041de59e0d..f27feb247d84a 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -5,10 +5,11 @@ import ( "bytes" "crypto/md5" "encoding/gob" - "encoding/hex" + "encoding/json" "errors" "fmt" "os" + "path/filepath" "strconv" "sync" "time" @@ -26,10 +27,6 @@ const ( Uninitialized = "pserver not fully initialized" ) -const ( - checkpoint_path = "./checkpoints/" -) - // Supported element types const ( Int32 ElementType = iota @@ -51,49 +48,68 @@ type Parameter struct { type ParameterWithConfig struct { Param Parameter Config []byte // parameter configuration in Proto Buffer format - State []byte // parameter training state } +// Checkpoint of Parameter and State +type parameterCheckPoint struct { + ParamConfig ParameterWithConfig + State []byte +} + +// checkpoint signature +type checkpointMeta struct { + UUID string `json:"uuid"` + Md5sum string `json:"md5sum"` + Timestamp string `json:"timestamp"` +} + +// Checkpoint is the pserver shard persist in file +type Checkpoint []parameterCheckPoint + // Gradient is the gradient of the parameter. -type Gradient Parameter // Service is the RPC service for pserver. type Service struct { - initialized chan struct{} - idx int - - mu sync.Mutex - optMap map[string]*optimizer + initialized chan struct{} + idx int + checkpointInterval int + checkpointPath string + client *EtcdClient + mu sync.Mutex + optMap map[string]*optimizer } -type checkpoint struct { - Uuid string - Md5sum string - Timestamp string -} +// //serialize ParameterWithConfig to byte stream +// func GetBytes(content ...interface{}) ([]byte, error) { -//serialize ParameterWithConfig to byte stream -func GetBytes(content ...interface{}) ([]byte, error) { - - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - err := encoder.Encode(content) - if err != nil { - return nil, err - } - return buf.Bytes(), nil -} +// var buf bytes.Buffer +// encoder := gob.NewEncoder(&buf) +// err := encoder.Encode(content) +// if err != nil { +// return nil, err +// } +// return buf.Bytes(), nil +// } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(idx int) (*Service, error) { +func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { s := &Service{ - idx: idx, + idx: idx, + checkpointInterval: time.Second * time.Duration(seconds), + checkpointPath: path, + client: client, } s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) - gob.Register(ParameterWithConfig{}) - gob.Register(checkpoint{}) + + if cp != nil { + for _, item := range cp { + p := item.ParamConfig + st := item.State + s.optMap[p.Param.Name] = newOptimizer(p, st) + } + } return s, nil } @@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { return nil } -// Save tells the parameter server to save parameters. -func (s *Service) Save(path string, dummy *int) error { - //FIXME: checkpoint is only used by pserver - // and has a constant path of */checkpoints/{pserver_idx}* +// pserver save checkpoint +func (s *Service) doCheckpoint() error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() - var paramWithConfig ParameterWithConfig + + cp := make([]parameterCheckPoint, 0, len(s.optMap)) + index := 0 for name, opt := range s.optMap { - paramWithConfig.Param.Name = name - paramWithConfig.Param.ElementType = opt.elementType - paramWithConfig.Param.Content = opt.GetWeights() - paramWithConfig.State = opt.GetStates() - content, err := GetBytes(paramWithConfig) - if err != nil { - log.Errorln(err) - } - ck := checkpoint{} - h := md5.New() - ck.Md5sum = hex.EncodeToString(h.Sum(content)) - ck.Timestamp = time.Now().String() - ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) - ckbytes, err := GetBytes(ck) - if err != nil { - log.Errorln(err) - } - // TODO: according design doc, need to save Uuid to etcd in json format - // {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx} - log.Infof("parameter checkpoint %s", ckbytes) - - if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { - log.Info("checkpoint not exists.") - } else { - err = os.Remove(ck.Uuid) - log.Infof("remove %s", ck.Uuid) - } - f, err := os.Create(ck.Uuid) - defer f.Close() - if err != nil { - log.Errorln(err) - } - writer := bufio.NewWriter(f) - _, err = writer.Write(content) - writer.Flush() - if err != nil { - log.Errorln(err) - } + var pc parameterCheckPoint + pc.ParamConfig.Param.Name = name + pc.ParamConfig.Param.ElementType = opt.elementType + pc.ParamConfig.Param.Content = opt.GetWeights() + pc.State = opt.GetStates() + cp[index] = pc + index++ + } + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(cp) + if err != nil { + return err + } + + cpMeta := checkpointMeta{} + cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) + cpMeta.Timestamp = time.Now().String() + h := md5.New() + cpMeta.Md5sum = h.Sum(buf.Bytes()) + + cpMetajson, err := json.Marshal(cpMeta) + s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + if err != nil { + return err + } + if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { + log.Info("checkpoint does not exists.") + } else { + err = os.Remove(cpMeta.UUID) + log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + } + f, err := os.Create(cpMeta.UUID) + defer f.Close() + if err != nil { + log.Errorln(err) + } + writer := bufio.NewWriter(f) + _, err = writer.Write(buf.Bytes()) + writer.Flush() + if err != nil { + log.Errorln(err) } return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 65a791ae477fc..75d4732ea7862 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -15,7 +15,8 @@ const ( ) func TestServiceFull(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) { if !reflect.DeepEqual(param1, p) { t.FailNow() } - var dummy int - s.Save("", &dummy) } func TestMultipleInit(t *testing.T) { From 774604cdb8ec563efb85832b77a861e8aad36eb3 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 7 Jul 2017 20:17:13 +0800 Subject: [PATCH 5/8] "add more NewService argument" --- go/pserver/etcd_client.go | 4 ++-- go/pserver/optimizer_test.go | 2 +- go/pserver/service.go | 20 +++++--------------- go/pserver/service_test.go | 19 ++++++------------- 4 files changed, 14 insertions(+), 31 deletions(-) diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 20041d04d089a..1f77787150d16 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -191,8 +191,8 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { // PutKey put into etcd with value by key specified func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { - ctx, err := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) - _, err = e.Put(ctx, key, value) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + _, err := e.etcdClient.Put(ctx, key, string(value)) cancel() if err != nil { return err diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index 0b2f4cfa41a63..d19e9de92e0b3 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) { Param: p, Config: config, } - o := newOptimizer(param) + o := newOptimizer(param, nil) o.Cleanup() } diff --git a/go/pserver/service.go b/go/pserver/service.go index f27feb247d84a..cb3741af7ab95 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -5,6 +5,7 @@ import ( "bytes" "crypto/md5" "encoding/gob" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -67,30 +68,19 @@ type checkpointMeta struct { type Checkpoint []parameterCheckPoint // Gradient is the gradient of the parameter. +type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} idx int - checkpointInterval int + checkpointInterval time.Duration checkpointPath string client *EtcdClient mu sync.Mutex optMap map[string]*optimizer } -// //serialize ParameterWithConfig to byte stream -// func GetBytes(content ...interface{}) ([]byte, error) { - -// var buf bytes.Buffer -// encoder := gob.NewEncoder(&buf) -// err := encoder.Encode(content) -// if err != nil { -// return nil, err -// } -// return buf.Bytes(), nil -// } - // NewService creates a new service, will bypass etcd registration if no // endpoints specified. func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { @@ -129,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er // TODO(helin): check if paramWithConfigs.Param.Content is // properly memory aligned, if not, make copy to a memory // aligned region. - s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) + s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) return nil } @@ -218,7 +208,7 @@ func (s *Service) doCheckpoint() error { cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.Timestamp = time.Now().String() h := md5.New() - cpMeta.Md5sum = h.Sum(buf.Bytes()) + cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMetajson, err := json.Marshal(cpMeta) s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 75d4732ea7862..f365a4539a2d0 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -87,7 +87,8 @@ func TestServiceFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -103,7 +104,8 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -111,7 +113,8 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -129,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - wg.Add(1) - go func() { - err := s.Save("", nil) - if err != nil { - errCh <- err - } - wg.Done() - ch <- struct{}{} - }() - time.Sleep(50 * time.Millisecond) select { From 87e7924e4e904b98c2a1a4f817ae5f6646e69138 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Sat, 8 Jul 2017 00:31:33 +0800 Subject: [PATCH 6/8] "pserver flags type error" --- go/cmd/pserver/pserver.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 56c1f6e1db609..4abce3beb3872 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -21,7 +21,7 @@ func main() { etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") - checkpointInterval := flag.Int("checkpoint-interval", "10", "save checkpoint per interval seconds") + checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() @@ -34,11 +34,12 @@ func main() { var idx int var cp pserver.Checkpoint + var e *pserver.EtcdClient if *index >= 0 { idx = *index } else { timeout := time.Second * time.Duration((*etcdTimeout)) - e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) idx, err = e.Register() if err != nil { panic(err) From 0ad7053e96b540120854daedfdb37230f3fc70ab Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Sun, 9 Jul 2017 15:12:09 +0800 Subject: [PATCH 7/8] "make parameterCheckpoint exported" --- go/cmd/pserver/pserver.go | 2 +- go/pserver/service.go | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 4abce3beb3872..0ecb1242c3c3d 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -21,7 +21,7 @@ func main() { etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") - checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds") + checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() diff --git a/go/pserver/service.go b/go/pserver/service.go index cb3741af7ab95..6b52d0d896f8b 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -51,8 +51,8 @@ type ParameterWithConfig struct { Config []byte // parameter configuration in Proto Buffer format } -// Checkpoint of Parameter and State -type parameterCheckPoint struct { +// ParameterCheckpoint is Parameter and State checkpoint +type ParameterCheckpoint struct { ParamConfig ParameterWithConfig State []byte } @@ -65,7 +65,7 @@ type checkpointMeta struct { } // Checkpoint is the pserver shard persist in file -type Checkpoint []parameterCheckPoint +type Checkpoint []ParameterCheckpoint // Gradient is the gradient of the parameter. type Gradient Parameter @@ -186,10 +186,10 @@ func (s *Service) doCheckpoint() error { s.mu.Lock() defer s.mu.Unlock() - cp := make([]parameterCheckPoint, 0, len(s.optMap)) + cp := make([]ParameterCheckpoint, 0, len(s.optMap)) index := 0 for name, opt := range s.optMap { - var pc parameterCheckPoint + var pc ParameterCheckpoint pc.ParamConfig.Param.Name = name pc.ParamConfig.Param.ElementType = opt.elementType pc.ParamConfig.Param.Content = opt.GetWeights() @@ -210,8 +210,8 @@ func (s *Service) doCheckpoint() error { h := md5.New() cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) - cpMetajson, err := json.Marshal(cpMeta) - s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + cpMetajson, _ := json.Marshal(cpMeta) + err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) if err != nil { return err } @@ -224,13 +224,13 @@ func (s *Service) doCheckpoint() error { f, err := os.Create(cpMeta.UUID) defer f.Close() if err != nil { - log.Errorln(err) + return err } writer := bufio.NewWriter(f) _, err = writer.Write(buf.Bytes()) writer.Flush() if err != nil { - log.Errorln(err) + return err } return nil } From e8296ff29186944a4b0b2ee3729c7e3433eacd1f Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 9 Jul 2017 19:28:41 +0800 Subject: [PATCH 8/8] restart teamcity JOB --- go/pserver/service_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index f365a4539a2d0..9bf1a48a596f3 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -166,5 +166,5 @@ func TestBlockUntilInitialized(t *testing.T) { } func TestCheckpointSpeed(t *testing.T) { - //TODO: test speed + //TODO(zhihong): test speed }