forked from kubeflow/katib
/
defaultcontroller.go
123 lines (116 loc) · 3.44 KB
/
defaultcontroller.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package studycontroller
import (
"context"
"log"
"time"
"github.com/kubeflow/katib/pkg/api"
"github.com/kubeflow/katib/pkg/db"
"google.golang.org/grpc"
)
// StudyControllerDefault is the study controller implementation handed out by
// NewStudyControllerDefault. Its zero value holds no database handle; Run
// installs one before doing any work.
type StudyControllerDefault struct {
	// dbIf is the database interface used to look up the study controller
	// record and to store its state. It is assigned in Run from db.New().
	dbIf db.VizierDBInterface
}
// NewStudyControllerDefault returns a freshly allocated, zero-valued
// StudyControllerDefault as an Interface. The database handle is not set here;
// it is created when Run is called.
func NewStudyControllerDefault() Interface {
	ctl := new(StudyControllerDefault)
	return ctl
}
// Run looks up the study controller identified by sctlId and starts driving it
// in the background.
//
// It assigns a new database handle from db.New() to the receiver, fetches the
// study controller record, and launches getAndRun in its own goroutine against
// the manager at managerAddr. A lookup error is returned unchanged and nothing
// is started. On success Run returns nil immediately, without waiting for the
// goroutine to finish.
func (s *StudyControllerDefault) Run(managerAddr string, sctlId string) error {
	s.dbIf = db.New()

	studyCtl, lookupErr := s.dbIf.GetStudyController(sctlId)
	if lookupErr != nil {
		return lookupErr
	}

	// Fire and forget: the goroutine ends on its own when getAndRun returns.
	go s.getAndRun(managerAddr, studyCtl)
	return nil
}
// getAndRun drives a single study controller until its requested trials are
// done or an RPC fails.
//
// It dials the manager at managerAddr, requests sctl.RequestSuggestionNum
// suggestions once, marks the controller RUNNING, and then on every
// sctl.UpdateInterval tick:
//
//   - fetches the study's workers and metrics from the manager,
//   - counts workers that are COMPLETED/KILLED and RUNNING/PENDING,
//   - marks the controller COMPLETED and returns once enough have completed,
//   - otherwise starts new trials so that at most sctl.MaxParallel workers are
//     running or pending.
//
// Every error is logged and ends the function; the stored controller state is
// left as it was in that case.
func (s *StudyControllerDefault) getAndRun(managerAddr string, sctl *api.StudyController) {
	// time.NewTicker panics on a non-positive interval, which would take the
	// whole process down from inside this goroutine.
	if sctl.UpdateInterval <= 0 {
		log.Printf("invalid update interval %v for study controller %v", sctl.UpdateInterval, sctl.StudyControllerId)
		return
	}

	conn, err := grpc.Dial(managerAddr, grpc.WithInsecure())
	if err != nil {
		log.Printf("could not connect: %v", err)
		return
	}
	defer conn.Close()
	c := api.NewManagerClient(conn)

	updateTicker := time.NewTicker(time.Duration(sctl.UpdateInterval) * time.Second)
	defer updateTicker.Stop()

	// A nil channel is never ready in a select, so when early stopping is
	// disabled the corresponding case below simply never fires.
	var earlyStopC <-chan time.Time
	if sctl.EarlystoppingInterval > 0 && sctl.EarlystoppingAlgorithm != "" && sctl.EarlystoppingAlgorithm != "none" {
		earlyStopTicker := time.NewTicker(time.Duration(sctl.EarlystoppingInterval) * time.Second)
		defer earlyStopTicker.Stop()
		earlyStopC = earlyStopTicker.C
	}

	ctx := context.Background()
	suggestReply, err := c.GetSuggestions(ctx, &api.GetSuggestionsRequest{
		StudyId:             sctl.StudyId,
		SuggestionAlgorithm: sctl.SuggestionAlgorithm,
		RequestNumber:       sctl.RequestSuggestionNum,
		ParamId:             sctl.SuggestionParamId,
	})
	if err != nil {
		log.Printf("GetSuggestions Error %v", err)
		return
	}
	s.dbIf.UpdateStudyControllerState(sctl.StudyControllerId, api.State_RUNNING, "")

	// IDs of the workers started by this controller; sent with each metrics
	// request.
	workerIds := []string{}
	for {
		select {
		case <-updateTicker.C:
			workersReply, err := c.GetWorkers(ctx, &api.GetWorkersRequest{StudyId: sctl.StudyId})
			if err != nil {
				log.Printf("GetWorker Error %v", err)
				return
			}
			_, err = c.GetMetrics(ctx, &api.GetMetricsRequest{
				StudyId:   sctl.StudyId,
				WorkerIds: workerIds,
			})
			if err != nil {
				log.Printf("GetMetrics Error %v", err)
				return
			}

			var running, complete int32
			for _, w := range workersReply.Workers {
				switch w.Status {
				case api.State_COMPLETED, api.State_KILLED:
					complete++
				case api.State_RUNNING, api.State_PENDING:
					running++
				}
			}

			// ">=" rather than "==": if the count ever overshoots the
			// request, an equality test would never terminate the loop.
			if complete >= sctl.RequestSuggestionNum {
				s.dbIf.UpdateStudyControllerState(sctl.StudyControllerId, api.State_COMPLETED, "")
				return
			}

			// The next trial to start is the one after all workers already
			// accounted for. Fill the free parallel slots, but never go past
			// the requested number of trials or the suggestions actually
			// returned, so the index below stays in range.
			started := int(running) + int(complete)
			launch := int(sctl.MaxParallel) - int(running)
			if remaining := int(sctl.RequestSuggestionNum) - started; launch > remaining {
				launch = remaining
			}
			if available := len(suggestReply.Trials) - started; launch > available {
				launch = available
			}
			if launch > 0 && sctl.WorkerConfig == nil {
				log.Printf("study controller %v has no worker config", sctl.StudyControllerId)
				return
			}

			for i := 0; i < launch; i++ {
				t := suggestReply.Trials[started+i]

				// Copy the template and detach its Command slice so that the
				// per-trial arguments never land in the shared backing array.
				ws := *sctl.WorkerConfig
				ws.Command = append(ws.Command[:0:0], ws.Command...)
				for _, p := range t.ParameterSet {
					ws.Command = append(ws.Command, p.Name, p.Value)
				}

				runTrialReply, err := c.RunTrial(ctx, &api.RunTrialRequest{
					StudyId:      sctl.StudyId,
					TrialId:      t.TrialId,
					Runtime:      "kubernetes",
					WorkerConfig: &ws,
				})
				if err != nil {
					log.Printf("RunTrial Error %v", err)
					return
				}
				workerIds = append(workerIds, runTrialReply.WorkerId)
			}
		case <-earlyStopC:
			// Early stopping is not implemented; the tick is consumed and
			// ignored.
		}
	}
}