-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pserver Save state #2716
Pserver Save state #2716
Changes from 10 commits
5ef1425
f1330e2
e6c98e4
65afbe1
bfc3b43
40295b9
8426beb
774604c
2f2ffd9
87e7924
0ad7053
e8296ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,21 @@ | ||
package pserver | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"crypto/md5" | ||
"encoding/gob" | ||
"encoding/hex" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
"strconv" | ||
"sync" | ||
"time" | ||
|
||
log "github.com/sirupsen/logrus" | ||
) | ||
|
||
// ElementType is the type of elements of a Parameter. | ||
|
@@ -39,26 +51,55 @@ type ParameterWithConfig struct { | |
Config []byte // parameter configuration in Proto Buffer format | ||
} | ||
|
||
// parameterCheckPoint is the checkpoint record for a single parameter: the
// parameter with its configuration, plus the optimizer state bytes.
type parameterCheckPoint struct { | ||
// ParamConfig carries the parameter (name, element type, content) and its configuration.
ParamConfig ParameterWithConfig | ||
// State is the optimizer state; doCheckpoint fills it from GetStates and
// NewService hands it back to newOptimizer when restoring.
State []byte | ||
} | ||
|
||
// checkpointMeta is the signature of a saved checkpoint. doCheckpoint
// marshals it to JSON and hands it to the etcd client.
type checkpointMeta struct { | ||
// UUID is the checkpoint file name: the checkpoint path followed by the pserver index.
UUID string `json:"uuid"` | ||
// Md5sum is a hex string derived from the gob-encoded checkpoint bytes,
// meant as an integrity checksum.
Md5sum string `json:"md5sum"` | ||
// Timestamp is time.Now().String() taken when the checkpoint is created.
Timestamp string `json:"timestamp"` | ||
} | ||
|
||
// Checkpoint is the state of one pserver shard as persisted to a file: one
// parameterCheckPoint entry per parameter.
type Checkpoint []parameterCheckPoint | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The exported type is a slice of an unexported type, which may be inconvenient to use. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
// Gradient is the gradient of the parameter. | ||
type Gradient Parameter | ||
|
||
// Service is the RPC service for pserver.
//
// optMap maps a parameter name to its optimizer; doCheckpoint holds mu while
// reading it. initialized is the channel doCheckpoint receives from before it
// starts working. checkpointInterval, checkpointPath and client are set by
// NewService from its seconds, path and client arguments.
type Service struct { | ||
initialized chan struct{} | ||
idx int | ||
|
||
mu sync.Mutex | ||
optMap map[string]*optimizer | ||
initialized chan struct{} | ||
idx int | ||
checkpointInterval time.Duration | ||
checkpointPath string | ||
client *EtcdClient | ||
mu sync.Mutex | ||
optMap map[string]*optimizer | ||
} | ||
|
||
// NewService creates the pserver service with index idx.
//
// seconds is the checkpoint interval in seconds, path is the prefix of the
// checkpoint file name, and client is the etcd client stored on the service
// (the tests in this change pass nil). When cp is non-nil, one optimizer per
// checkpoint entry is rebuilt from the saved parameter and optimizer state.
// The returned error is always nil.
func NewService(idx int) (*Service, error) { | ||
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { | ||
s := &Service{ | ||
idx: idx, | ||
idx: idx, | ||
// seconds is a plain count of seconds; convert it to a time.Duration.
checkpointInterval: time.Second * time.Duration(seconds), | ||
checkpointPath: path, | ||
client: client, | ||
} | ||
s.optMap = make(map[string]*optimizer) | ||
s.initialized = make(chan struct{}) | ||
|
||
// Restore optimizers from the supplied checkpoint, keyed by parameter name.
if cp != nil { | ||
for _, item := range cp { | ||
p := item.ParamConfig | ||
st := item.State | ||
s.optMap[p.Param.Name] = newOptimizer(p, st) | ||
} | ||
} | ||
return s, nil | ||
} | ||
|
||
|
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er | |
// TODO(helin): check if paramWithConfigs.Param.Content is | ||
// properly memory aligned, if not, make copy to a memory | ||
// aligned region. | ||
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) | ||
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil) | ||
return nil | ||
} | ||
|
||
|
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { | |
return nil | ||
} | ||
|
||
// Save tells the parameter server to save parameters. | ||
func (s *Service) Save(path string, dummy *int) error { | ||
// pserver save checkpoint | ||
func (s *Service) doCheckpoint() error { | ||
<-s.initialized | ||
s.mu.Lock() | ||
defer s.mu.Unlock() | ||
|
||
cp := make([]parameterCheckPoint, 0, len(s.optMap)) | ||
index := 0 | ||
for name, opt := range s.optMap { | ||
var pc parameterCheckPoint | ||
pc.ParamConfig.Param.Name = name | ||
pc.ParamConfig.Param.ElementType = opt.elementType | ||
pc.ParamConfig.Param.Content = opt.GetWeights() | ||
pc.State = opt.GetStates() | ||
cp[index] = pc | ||
index++ | ||
} | ||
var buf bytes.Buffer | ||
encoder := gob.NewEncoder(&buf) | ||
err := encoder.Encode(cp) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
cpMeta := checkpointMeta{} | ||
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) | ||
cpMeta.Timestamp = time.Now().String() | ||
h := md5.New() | ||
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) | ||
|
||
// TODO | ||
cpMetajson, err := json.Marshal(cpMeta) | ||
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) | ||
if err != nil { | ||
return err | ||
} | ||
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { | ||
log.Info("checkpoint does not exists.") | ||
} else { | ||
err = os.Remove(cpMeta.UUID) | ||
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) | ||
} | ||
f, err := os.Create(cpMeta.UUID) | ||
defer f.Close() | ||
if err != nil { | ||
log.Errorln(err) | ||
} | ||
writer := bufio.NewWriter(f) | ||
_, err = writer.Write(buf.Bytes()) | ||
writer.Flush() | ||
if err != nil { | ||
log.Errorln(err) | ||
} | ||
return nil | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,8 @@ const ( | |
) | ||
|
||
func TestServiceFull(t *testing.T) { | ||
s, err := pserver.NewService(0) | ||
var cp pserver.Checkpoint | ||
s, err := pserver.NewService(0, 1, "", nil, cp) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) { | |
} | ||
|
||
func TestMultipleInit(t *testing.T) { | ||
s, err := pserver.NewService(0) | ||
var cp pserver.Checkpoint | ||
s, err := pserver.NewService(0, 1, "", nil, cp) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
@@ -102,15 +104,17 @@ func TestMultipleInit(t *testing.T) { | |
} | ||
|
||
func TestUninitialized(t *testing.T) { | ||
s, err := pserver.NewService(0) | ||
var cp pserver.Checkpoint | ||
s, err := pserver.NewService(0, 1, "", nil, cp) | ||
err = s.SendGrad(pserver.Gradient{}, nil) | ||
if err.Error() != pserver.Uninitialized { | ||
t.FailNow() | ||
} | ||
} | ||
|
||
func TestBlockUntilInitialized(t *testing.T) { | ||
s, err := pserver.NewService(0) | ||
var cp pserver.Checkpoint | ||
s, err := pserver.NewService(0, 1, "", nil, cp) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) { | |
ch <- struct{}{} | ||
}() | ||
|
||
wg.Add(1) | ||
go func() { | ||
err := s.Save("", nil) | ||
if err != nil { | ||
errCh <- err | ||
} | ||
wg.Done() | ||
ch <- struct{}{} | ||
}() | ||
|
||
time.Sleep(50 * time.Millisecond) | ||
|
||
select { | ||
|
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) { | |
|
||
wg.Wait() | ||
} | ||
|
||
func TestCheckpointSpeed(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Speed can be tested with a benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leaving a TODO here; this will be tested after reaching an agreement with @Yancey1989 on the recovery logic. |
||
//TODO: test speed | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A default of 10 seconds may be too frequent?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, a fixed interval is not appropriate for every training job; the time consumed is always determined by the amount of training data.
A round count may be better here. Change it to 10 min (600 seconds).