-
Notifications
You must be signed in to change notification settings - Fork 5
/
lcm_utils.go
326 lines (277 loc) · 11.7 KB
/
lcm_utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
/*
* Copyright 2017-2018 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package lcm
import (
"fmt"
"io/ioutil"
"math/rand"
"strconv"
"strings"
"time"
"github.com/cenkalti/backoff"
"gopkg.in/yaml.v2"
"github.com/AISphere/ffdl-commons/config"
"github.com/AISphere/ffdl-commons/logger"
"github.com/AISphere/ffdl-commons/util"
"github.com/AISphere/ffdl-lcm/coord"
"github.com/AISphere/ffdl-lcm/service"
client "github.com/AISphere/ffdl-lcm/trainer-client"
"github.com/AISphere/ffdl-trainer/trainer/grpc_trainer_v2"
"golang.org/x/net/context"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// createEtcdNodes creates all the etcd nodes ("znodes") used by a training job
// before it is deployed.
//
// Each path is written through lcm.etcdClient.PutIfKeyMissing. The function
// returns the first error reported by the etcd client, or an error if any of
// the paths was already present. Map iteration order is random, so the order in
// which the paths are created is unspecified.
func createEtcdNodes(lcm *lcmService, jobName string, userID string, trainingID string, numOfLearners int, framework string, logr *logger.LocLoggingEntry) error {
	learnersPath := trainingID + "/" + zkLearners + "/"

	pathToValueMapping := map[string]string{
		trainingID + "/" + zkNotes:     "",
		trainingID + "/" + zkUserID:    userID,
		trainingID + "/" + zkFramework: framework,
		// strconv.Itoa renders the count as decimal text; a plain string(int)
		// conversion would instead produce the rune with that code point.
		learnersPath + zkTotLearners:                strconv.Itoa(numOfLearners),
		trainingID + "/" + zkJobName:                jobName,
		learnersPath + zkLearnerLock:                "",
		learnersPath + zkLearnerCounter:             "1",
		learnersPath + zkAliveLearners:              "0",
		trainingID + "/" + zkGlobalCursor + "/" + zkGCState: "0",
	}

	for path, val := range pathToValueMapping {
		pathCreated, err := lcm.etcdClient.PutIfKeyMissing(path, val, logr)
		if err != nil {
			return err
		}
		if !pathCreated {
			return fmt.Errorf("Failed to create the path %v , since it was already present", path)
		}
	}
	return nil
}
// constructJMName returns the job monitor name for the given job name,
// i.e. the job name prefixed with "jobmonitor-".
func constructJMName(jobName string) string {
	const prefix = "jobmonitor-"
	return prefix + jobName
}
// constructLearnerName returns the learner name for the given learner ID and
// job name, in the form "learner-<learnerID>-<jobName>".
func constructLearnerName(learnerID int, jobName string) string {
	id := strconv.Itoa(learnerID)
	name := "learner-" + id + "-" + jobName
	return name
}
// constructLearnerHelperName returns the learner helper name for the given
// learner ID and job name, in the form "lhelper-<learnerID>-<jobName>".
func constructLearnerHelperName(learnerID int, jobName string) string {
	id := strconv.Itoa(learnerID)
	name := "lhelper-" + id + "-" + jobName
	return name
}
// constructLearnerServiceName returns the learner service name for the given
// learner ID and job name. It is identical to the learner name.
func constructLearnerServiceName(learnerID int, jobName string) string {
	serviceName := constructLearnerName(learnerID, jobName)
	return serviceName
}
// constructLearnerVolumeClaimName returns the learner volume claim name for the
// given learner ID and job name. It is identical to the learner name.
func constructLearnerVolumeClaimName(learnerID int, jobName string) string {
	claimName := constructLearnerName(learnerID, jobName)
	return claimName
}
// constructPSName returns the parameter server name for the given job name,
// i.e. the job name prefixed with "grpc-ps-".
func constructPSName(jobName string) string {
	const prefix = "grpc-ps-"
	return prefix + jobName
}
// getStorageSize returns the disk size requested for a job.
//
// When r.Storage is positive, the result is calcStorage(r) multiplied by
// 1024*1024 and truncated to int64; otherwise the configured default from
// config.GetVolumeSize() is returned.
func getStorageSize(r *service.ResourceRequirements) int64 {
	// The default size for all jobs.
	defaultSize := config.GetVolumeSize()

	// Nothing was requested explicitly: fall back to the default.
	if !(r.Storage > 0) {
		return defaultSize
	}

	return int64(calcStorage(r) * 1024 * 1024)
}
// getStaticVolume returns the name of a static volume to use for a job, or ""
// when none is available.
//
// The candidate volumes are read from the YAML file
// /etc/static-volumes-v2/PVCs-v2.yaml. If that file cannot be read or parsed,
// or lists no volumes, "" is returned. When the first listed volume carries a
// zone and zone is non-empty, a random volume from the same zone is picked; if
// no volume is in that zone, "" is returned. Otherwise a random volume from
// the whole list is picked.
func getStaticVolume(zone string, logr *logger.LocLoggingEntry) string {
	type Items struct {
		Name         string `yaml:"name"`
		Status       string `yaml:"status"`
		StorageClass string `yaml:"storage_class"`
		Zone         string `yaml:"zone"`
	}
	type Volumes struct {
		Volumes []Items `yaml:"static-volumes-v2"`
	}

	var staticVolumes Volumes
	pvcConfigMap := "/etc/static-volumes-v2/PVCs-v2.yaml"

	data, err := ioutil.ReadFile(pvcConfigMap)
	if err != nil {
		logr.Warnf("Unable to load %s: %s", pvcConfigMap, err)
		return ""
	}
	if err = yaml.Unmarshal(data, &staticVolumes); err != nil {
		// Report the parse failure instead of silently ignoring it.
		logr.Warnf("Unable to parse %s: %s", pvcConfigMap, err)
		return ""
	}
	if len(staticVolumes.Volumes) == 0 {
		return ""
	}

	// If configmap contains zone, find co-located static volumes
	if staticVolumes.Volumes[0].Zone != "" && zone != "" {
		logr.Debugf("searching for static volumes with zone %s", zone)
		var colocatedVolumeNames []string
		for _, vol := range staticVolumes.Volumes {
			if vol.Zone == zone {
				colocatedVolumeNames = append(colocatedVolumeNames, vol.Name)
			}
		}
		// Guard against an empty candidate list: picking a random index out
		// of zero elements would panic.
		if len(colocatedVolumeNames) == 0 {
			logr.Warnf("no static volumes found in zone %s", zone)
			return ""
		}
		return colocatedVolumeNames[rand.Intn(len(colocatedVolumeNames))]
	}

	// TODO: Remove this fallback logic when all configmaps contain zone
	// No zone found, assume single region zone; use any static volume
	logr.Debugf("no zone found in configmap, using any static volume")
	return staticVolumes.Volumes[rand.Intn(len(staticVolumes.Volumes))].Name
}
// handleDeploymentFailure reacts to a failed deployment of the given component.
//
// It first reports the training job tID as FAILED via updateJobStatus (with
// the INTERNAL_ERROR status message and the failed-deploy error code), then
// calls s.killDeployedJob to clean up the job's resources. Failures of either
// step are logged and not returned to the caller.
func handleDeploymentFailure(s *lcmService, dlaasJobName string, tID string,
	userID string, component string, logr *logger.LocLoggingEntry) {

	logr.Errorf("updating status to FAILED")
	if errUpd := updateJobStatus(tID, grpc_trainer_v2.Status_FAILED, userID, service.StatusMessages_INTERNAL_ERROR.String(), client.ErrCodeFailedDeploy, logr); errUpd != nil {
		logr.WithError(errUpd).Errorf("after failed %s, error while calling Trainer service client update", component)
	}

	//Cleaning up resources out of an abundance of caution
	logr.Errorf("training FAILED so going ahead and cleaning up resources")
	if errKill := s.killDeployedJob(dlaasJobName, tID, userID); errKill != nil {
		// Include the job name; the message previously ended with a dangling "for job ".
		logr.WithError(errKill).Errorf("after failed %s, problem calling KillDeployedJob for job %s", component, dlaasJobName)
	}
}
// jobBasePath returns the etcd base path of a training job: the configured
// etcd prefix immediately followed by the training ID.
func jobBasePath(trainingID string) string {
	prefix := config.GetEtcdPrefix()
	return prefix + trainingID
}
// learnerEtcdBasePath returns the etcd path under which the learner nodes of a
// training job live: the job base path followed by "/learners".
func learnerEtcdBasePath(trainingID string) string {
	const learnersSuffix = "/learners"
	return jobBasePath(trainingID) + learnersSuffix
}
// learnerNodeEtcdStatusPath returns the etcd path of the status node of one
// learner: "<learners base path>/learner_<learnerID>/status".
func learnerNodeEtcdStatusPath(trainingID string, learnerID int) string {
	base := learnerEtcdBasePath(trainingID)
	return fmt.Sprintf("%s/learner_%d/status", base, learnerID)
}
// learnerNodeEtcdStatusPathRelative returns the status path of one learner
// relative to the training ID: "<trainingID>/learner_<learnerID>/status".
func learnerNodeEtcdStatusPathRelative(trainingID string, learnerID int) string {
	const layout = "%s/learner_%d/status"
	relPath := fmt.Sprintf(layout, trainingID, learnerID)
	return relPath
}
// learnerNodeEtcdBasePath returns the etcd base path of one learner's nodes:
// "<learners base path>/learner_<learnerID>/" (with a trailing slash).
func learnerNodeEtcdBasePath(trainingID string, learnerID int) string {
	base := learnerEtcdBasePath(trainingID)
	return fmt.Sprintf("%s/learner_%d/", base, learnerID)
}
// calcMemory converts the memory figure of the DLaaS resource requirements
// (r.Memory expressed in r.MemoryUnit) to the default MB notation.
func calcMemory(r *service.ResourceRequirements) float64 {
	memoryMB := calcSize(r.Memory, r.MemoryUnit)
	return memoryMB
}
// calcStorage converts the storage figure of the DLaaS resource requirements
// (r.Storage expressed in r.StorageUnit) to the default MB notation.
func calcStorage(r *service.ResourceRequirements) float64 {
	storageMB := calcSize(r.Storage, r.StorageUnit)
	return storageMB
}
// calcSize converts a size expressed in the given unit to the default MB
// notation (decimal megabytes). Converted values are rounded to two decimal
// places with util.RoundPlus.
//
// Any unit not listed below is assumed to already be MB and is returned as-is.
func calcSize(size float64, unit service.ResourceRequirements_MemoryUnit) float64 {
	switch unit {
	case service.ResourceRequirements_MiB:
		// 1 MiB = 2^20 bytes = 1.048576 MB
		return util.RoundPlus(size*1.048576, 2)
	case service.ResourceRequirements_GB:
		return util.RoundPlus(size*1000, 2)
	case service.ResourceRequirements_TB:
		return util.RoundPlus(size*1000*1000, 2)
	case service.ResourceRequirements_GiB:
		// 1 GiB = 2^30 bytes = 1073.741824 MB
		return util.RoundPlus(size*1073.741824, 2)
	case service.ResourceRequirements_TiB:
		// 1 TiB = 2^40 bytes = 1024 GiB = 1099511.627776 MB
		// (not 1073.741824 squared, which overstates the size by ~4.9%).
		return util.RoundPlus(size*1099511.627776, 2)
	default:
		return size // assume MB
	}
}
// updateJobStatus reports a new status for a training job to the Trainer
// service.
//
// It builds an UpdateRequest from the arguments, creates a trainer client and
// calls UpdateTrainingJob through util.Retry (invoked with 10 and 100ms;
// presumably attempt count and interval, see util.Retry). It returns the error
// from creating the client, or the error reported by util.Retry, or nil once
// the request has been sent.
func updateJobStatus(trainingID string, updStatus grpc_trainer_v2.Status, userID string, statusMessage string, errorCode string, logr *logger.LocLoggingEntry) error {
	logr.Debugf("(updateJobStatus) Updating status of %s to %s", trainingID, updStatus.String())
	updateRequest := &grpc_trainer_v2.UpdateRequest{TrainingId: trainingID, Status: updStatus, UserId: userID, StatusMessage: statusMessage, ErrorCode: errorCode}

	trainer, err := client.NewTrainer()
	if err != nil {
		logr.WithError(err).Errorf("(updateJobStatus) Creating training client for status update failed. Training ID %s New Status %s", trainingID, updStatus.String())
		logr.Errorf("(updateJobStatus) Error while creating training client is %s", err.Error())
		// Without a client there is nothing to close or call; bail out
		// instead of using an invalid trainer below.
		return err
	}
	defer trainer.Close()

	err = util.Retry(10, 100*time.Millisecond, "UpdateTrainingJob", logr, func() error {
		// NOTE: the call deliberately uses a background context; a
		// per-call timeout (context.WithTimeout) is currently not applied.
		_, updErr := trainer.Client().UpdateTrainingJob(context.Background(), updateRequest)
		if updErr != nil {
			logr.WithError(updErr).Error("Failed to update status to the trainer. Retrying")
			logr.Infof("WARNING: Status updates for %s may be temporarily inconsistent due to failure to communicate with Trainer.", trainingID)
		}
		return updErr
	})
	if err != nil {
		logr.WithError(err).Errorf("Failed to update status to the trainer. Already retried several times.")
		logr.Infof("WARNING : Status of job %s will likely be incorrect", trainingID)
		return err
	}

	logr.Debugf("(updateJobStatus) Status update request for %s sent to trainer", trainingID)
	return nil
}
// isJobDone reports whether the given job status, as decoded by
// client.GetStatus, is one of COMPLETED, FAILED or HALTED.
func isJobDone(jobStatus string, logr *logger.LocLoggingEntry) bool {
	switch client.GetStatus(jobStatus, logr).Status {
	case grpc_trainer_v2.Status_COMPLETED,
		grpc_trainer_v2.Status_FAILED,
		grpc_trainer_v2.Status_HALTED:
		return true
	default:
		return false
	}
}
// setServiceTypeLabel sets the DLaaS service type label ("service") on an
// object's metadata.
// This label is used to configure Calico network policy rules for the pod.
func setServiceTypeLabel(spec *metav1.ObjectMeta, value string) {
	// Writing to a nil map panics, so create the label map on demand.
	if spec.Labels == nil {
		spec.Labels = make(map[string]string)
	}
	spec.Labels["service"] = value
}
// k8sInteractionBackoff returns a new exponential backoff whose MaxInterval is
// set to 1 minute and whose MaxElapsedTime is set to 3 minutes; all other
// settings are the library defaults from backoff.NewExponentialBackOff.
func k8sInteractionBackoff() *backoff.ExponentialBackOff {
	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.MaxInterval = 1 * time.Minute
	expBackoff.MaxElapsedTime = 3 * time.Minute
	return expBackoff
}
// etdInteractionBackoff returns a new exponential backoff with the given
// MaxElapsedTime and MaxInterval; all other settings are the library defaults
// from backoff.NewExponentialBackOff.
func etdInteractionBackoff(maxElapsedTime, maxInterval time.Duration) *backoff.ExponentialBackOff {
	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.MaxInterval = maxInterval
	expBackoff.MaxElapsedTime = maxElapsedTime
	return expBackoff
}
// coordinator creates a coord.Coordinator from the etcd settings in config
// (endpoints, prefix, certificate location, username and password).
//
// The creation is retried via backoff.RetryNotify using
// etdInteractionBackoff(1 minute, 30 seconds); every failed attempt is logged.
// The instance from the last attempt is returned together with the error
// reported by RetryNotify, so callers still see the failure if connecting never
// succeeded.
func coordinator(logr *logger.LocLoggingEntry) (coord.Coordinator, error) {
	var instance coord.Coordinator

	// One connection attempt; the configuration is re-read on every try.
	connect := func() error {
		var connErr error
		instance, connErr = coord.NewCoordinator(coord.Config{
			Endpoints: config.GetEtcdEndpoints(),
			Prefix:    config.GetEtcdPrefix(),
			Cert:      config.GetEtcdCertLocation(),
			Username:  config.GetEtcdUsername(),
			Password:  config.GetEtcdPassword(),
		}, logr)
		return connErr
	}

	// Called after each failed attempt.
	onFailure := func(attemptErr error, _ time.Duration) {
		logr.WithError(attemptErr).Errorf("failed to establish connection with etcd")
	}

	err := backoff.RetryNotify(connect, etdInteractionBackoff(1*time.Minute, 30*time.Second), onFailure)
	return instance, err
}
// isSplitMode reports whether split-learner mode should be used, which is the
// case exactly when getStaticVolume returns a non-empty volume name for zone.
func isSplitMode(zone string, logr *logger.LocLoggingEntry) bool {
	return len(getStaticVolume(zone, logr)) > 0
}
// getClusterEnv derives the cluster environment from the server's git version
// string. It returns "icp" when the version contains "icp" (case-insensitive)
// and "iks" otherwise.
//
// Armada/IKS clusters will look like GitVersion:"v1.8.15+IKS"
// ICP clusters will look like GitVersion:"v1.11.1+icp-ee"
func getClusterEnv(gitVersion string, logr *logger.LocLoggingEntry) string {
	lowered := strings.ToLower(gitVersion)
	if !strings.Contains(lowered, "icp") {
		logr.Infof("defaulting to Armada/IKS, server version: %s", gitVersion)
		return "iks"
	}
	logr.Infof("server version is: %s", gitVersion)
	return "icp"
}