Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TaskFail interface #2719

Merged
merged 11 commits into from
Jul 11, 2017
7 changes: 6 additions & 1 deletion go/master/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
c.taskFinished(t.Meta.ID)
}
}

Expand Down Expand Up @@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
}

// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}

// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is
Expand Down
10 changes: 8 additions & 2 deletions go/master/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i)
}

err = c.taskFinished(tasks[0].ID)
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}

err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}

tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
Expand All @@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task)

for _, task := range tasks {
err = c.taskFinished(task.ID)
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
Expand Down
111 changes: 70 additions & 41 deletions go/master/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,36 @@ type Chunk struct {
Index recordio.Index // chunk index
}

// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}

// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Meta TaskMeta
Chunks []Chunk
}

type taskEntry struct {
Epoch int
NumTimeout int
Task Task
Task Task
// A task fails if it's timeout or trainer reports it exits unnormally.
NumFailure int
}

type taskQueues struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []Task
Failed []taskEntry
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Task改成了taskEntry的原因是,我觉得Failed task应该保留进入错误队列时候的上下文状态。

}

// Service is the master server service.
type Service struct {
chunksPerTask int
timeoutDur time.Duration
timeoutMax int
failureMax int
ready chan struct{}
store Store

Expand All @@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
id++
result = append(result, cur)
cur.Task.Chunks = nil
Expand All @@ -83,19 +89,19 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}

if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
result = append(result, cur)
}

return result
}

// NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{}
s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax
s.failureMax = failureMax
s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{})
Expand Down Expand Up @@ -257,6 +263,34 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil
}

func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check or failed status report.
return
}

defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()

delete(s.taskQueues.Pending, t.Task.Meta.ID)

t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t)
return
}

log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}

func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
Expand All @@ -267,30 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return
}

if t.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
return
}

defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()

delete(s.taskQueues.Pending, t.Task.ID)

t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可能也会被failed调用,所以不一定都是time out,可以用泛化点的描述。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}

log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
s.processFailedTask(t, epoch)
}
}

Expand Down Expand Up @@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}

t := s.taskQueues.Todo[0]
t.Epoch++
t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t
s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot()
if err != nil {
return err
}

*task = t.Task
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete unused *task

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行是返回值,是有用的。:)

log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)

time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil
}

Expand All @@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {

t, ok := s.taskQueues.Pending[taskID]
if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err
return nil
}

// task finished, reset timeout
t.NumTimeout = 0
t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)

Expand All @@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
return err
}

// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}

s.mu.Lock()
defer s.mu.Unlock()

t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
}

s.processFailedTask(t, meta.Epoch)
return nil
}
2 changes: 1 addition & 1 deletion go/master/service_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.ID != i {
if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i)
}
}
Expand Down
7 changes: 4 additions & 3 deletions go/pserver/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func initClient() [numPserver]int {
ports[i] = p

go func(l net.Listener) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -174,7 +175,7 @@ func TestNativeClient(t *testing.T) {
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2)
}