Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pserver etcd registration #2544

Merged
merged 3 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,35 @@ import (
"net/http"
"net/rpc"
"strconv"
"time"

"github.com/namsral/flag"

"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
)

func main() {
port := flag.Int("port", 0, "port of the pserver")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()

s := pserver.NewService()
err := rpc.Register(s)
level, err := log.ParseLevel(*logLevel)
if err != nil {
panic(err)
}
log.SetLevel(level)

timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
if err != nil {
panic(err)
}
err = rpc.Register(s)
if err != nil {
panic(err)
}
Expand All @@ -27,7 +44,9 @@ func main() {
panic(err)
}

log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil)

if err != nil {
panic(err)
}
Expand Down
8 changes: 6 additions & 2 deletions go/pserver/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/PaddlePaddle/Paddle/go/pserver"
)
Expand All @@ -30,9 +31,12 @@ func init() {
port[i] = p

go func(l net.Listener) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
panic(err)
}
server := rpc.NewServer()
err := server.Register(s)
err = server.Register(s)
if err != nil {
panic(err)
}
Expand Down
125 changes: 122 additions & 3 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package pserver

import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"

"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)

// ElementType is the type of elements of a Parameter.
Expand All @@ -24,6 +33,9 @@ const (
Float64
)

// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
Name string
Expand All @@ -47,14 +59,121 @@ type Service struct {
mu sync.Mutex
opt *optimizer
paramMap map[string]Parameter

etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
}

// NewService creates a new service.
func NewService() *Service {
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
return s
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout

var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}

if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

len(resp.Kvs) will always be 1 if err == nil.

Btw, In the previous example https://github.com/coreos/etcd/blob/master/clientv3/concurrency/mutex.go#L64 I asked the coreos develop who wrote this code over IRC, he said checking len(ownerKey) == 0 is not necessary.
You can see from this line ownerKey := resp.Responses[1].GetResponseRange().Kvs: https://github.com/coreos/etcd/blob/master/clientv3/concurrency/mutex.go#L63 , they don't check slice len as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, will fix this in later PRs.

s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}

// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)

if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
_, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// InitParam initializes a parameter.
Expand Down
25 changes: 17 additions & 8 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import (
)

func TestFull(t *testing.T) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil {
t.FailNow()
}
Expand Down Expand Up @@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
}

func TestMultipleInit(t *testing.T) {
s := pserver.NewService()
err := s.FinishInitParams(0, nil)
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
}
Expand All @@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
}

func TestUninitialized(t *testing.T) {
s := pserver.NewService()
err := s.SendGrad(pserver.Gradient{}, nil)
s, err := pserver.NewService("", time.Second*5)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
}
}

func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2)
errCh := make(chan error, 2)
var wg sync.WaitGroup
Expand Down Expand Up @@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil {
t.FailNow()
}
Expand Down
45 changes: 45 additions & 0 deletions go/utils/networkhelper/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package networkhelper

import (
"errors"
"net"
)

// GetExternalIP returns the ip address of local network interface, not the
// loopback device.
func GetExternalIP() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}
10 changes: 10 additions & 0 deletions go/utils/networkhelper/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package networkhelper

import "testing"

func TestGetIP(t *testing.T) {
_, err := GetExternalIP()
if err != nil {
t.Errorf("GetExternalIP returns error : %v\n", err)
}
}