Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pserver Save state #2716

Merged
merged 12 commits into from
Jul 11, 2017
8 changes: 6 additions & 2 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 10, "save checkpoint per interval seconds")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default 10 seconds maybe too quick?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, fix interval is not proper for every training job. Time consumed always determined by training data amount. Round count may be better here.
Change it to 10 min(600seconds)

logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
Expand All @@ -31,18 +33,20 @@ func main() {
log.SetLevel(level)

var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}

s, err := pserver.NewService(idx)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}
Expand Down
13 changes: 13 additions & 0 deletions go/pserver/etcd_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
)

// EtcdClient is the etcd client that the pserver uses for fault
Expand Down Expand Up @@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {

return idx, nil
}

// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {
return err
}
return nil
}
18 changes: 15 additions & 3 deletions go/pserver/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len]
}

func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}

o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
return o
}

Expand All @@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}

func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
}

func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
Expand Down
2 changes: 1 addition & 1 deletion go/pserver/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p,
Config: config,
}
o := newOptimizer(param)
o := newOptimizer(param, nil)
o.Cleanup()
}
110 changes: 99 additions & 11 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
package pserver

import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"

log "github.com/sirupsen/logrus"
)

// ElementType is the type of elements of a Parameter.
Expand Down Expand Up @@ -39,26 +51,55 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}

// Checkpoint of Parameter and State
type parameterCheckPoint struct {
ParamConfig ParameterWithConfig
State []byte
}

// checkpoint signature
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}

// Checkpoint is the pserver shard persist in file
type Checkpoint []parameterCheckPoint
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exported type is an array of unexported type, maybe inconvenience to use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// Gradient is the gradient of the parameter.
type Gradient Parameter

// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int

mu sync.Mutex
optMap map[string]*optimizer
initialized chan struct{}
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int) (*Service, error) {
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})

if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil
}

Expand All @@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil
}

Expand Down Expand Up @@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil
}

// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()

cp := make([]parameterCheckPoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc parameterCheckPoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}

cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))

// TODO
cpMetajson, err := json.Marshal(cpMeta)
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
log.Errorln(err)
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
log.Errorln(err)
}
return nil
}
26 changes: 12 additions & 14 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ const (
)

func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}

func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand All @@ -102,15 +104,17 @@ func TestMultipleInit(t *testing.T) {
}

func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
}
}

func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand All @@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{}
}()

wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()

time.Sleep(50 * time.Millisecond)

select {
Expand Down Expand Up @@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {

wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speed can be tested with benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave a TODO here, will be tested after reaching an agreement with @Yancey1989 's recover logic.

//TODO: test speed
}