-
Notifications
You must be signed in to change notification settings - Fork 104
/
train_ps.py
141 lines (104 loc) · 5.66 KB
/
train_ps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import codecs
import json
import logging
import os
import re
import shutil

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

from model_def import get_model, HEIGHT, WIDTH, DEPTH, NUM_CLASSES
from utilities import process_input
logging.getLogger().setLevel(logging.INFO)
tf.logging.set_verbosity(tf.logging.ERROR)

# Copy the inference pre/post-processing script (and its requirements) into
# the model package so TensorFlow Serving picks it up at deploy time.
# Replaces the original `os.system('mkdir ...')` / `os.system('cp ...')`
# shell-outs: os.makedirs/shutil.copy need no shell, tolerate a pre-existing
# directory, and raise on a failed copy instead of silently producing a
# model package with no inference code.
CODE_DIR = '/opt/ml/model/code'
os.makedirs(CODE_DIR, exist_ok=True)
shutil.copy('inference.py', CODE_DIR)
shutil.copy('requirements.txt', CODE_DIR)
class CustomTensorBoardCallback(TensorBoard):
    """TensorBoard callback with per-batch processing disabled.

    ``on_batch_end`` is overridden as a no-op so the base class does no work
    at the end of each training batch (epoch-level logging is unaffected) --
    presumably to reduce per-batch logging overhead; TODO confirm intent.
    """

    def on_batch_end(self, batch, logs=None):
        # Intentionally do nothing at the end of each batch.
        pass
def save_history(path, history):
    """Serialize a Keras ``History`` object's metrics to a JSON file.

    Converts NumPy values (which ``json`` cannot serialize) to plain Python
    floats/lists before dumping.

    Args:
        path: destination file path for the JSON document.
        history: object with a ``.history`` dict mapping metric names to
            either an ``np.ndarray`` or a list of ``np.float32``/``np.float64``.
            Entries of any other shape are silently skipped, matching the
            original behavior.
    """
    history_for_json = {}
    # transform float values that aren't json-serializable
    for key in list(history.history.keys()):
        values = history.history[key]
        if isinstance(values, np.ndarray):
            # Bug fix: the original used `==` (a no-op comparison) instead of
            # `=`, so ndarray-valued metrics were silently dropped.
            history_for_json[key] = values.tolist()
        elif isinstance(values, list):
            if isinstance(values[0], (np.float32, np.float64)):
                history_for_json[key] = list(map(float, values))
    with codecs.open(path, 'w', encoding='utf-8') as f:
        json.dump(history_for_json, f, separators=(',', ':'), sort_keys=True, indent=4)
def save_model(model, output):
    """Export *model* as a TensorFlow SavedModel for TF Serving deployment.

    Args:
        model: the trained Keras model to export.
        output: directory in which to write the SavedModel.

    Bug fix: the original ignored the ``output`` parameter and saved to the
    global ``args.model_dir`` -- which only worked by accident because the
    sole call site passed ``args.model_dir``, and broke for any other caller.
    """
    # create a TensorFlow SavedModel for deployment to a SageMaker endpoint with TensorFlow Serving
    tf.contrib.saved_model.save_keras_model(model, output)
    logging.info("Model successfully saved at: {}".format(output))
    return
def main(args):
    """Train, evaluate, and save the model on this SageMaker host.

    Builds input pipelines from the train/validation/eval channels, fits the
    Keras model, evaluates it on the eval set, and -- on the first host only --
    saves the training history and a SavedModel export.

    Args:
        args: parsed argparse namespace (see ``__main__``) supplying data
            channel paths, hyperparameters, the host list, and output dirs.
    """
    # SM_MODULE_DIR points at .../source/sourcedir.tar.gz; rewrite that path
    # so TensorBoard logs land under the model directory instead.
    if 'sourcedir.tar.gz' in args.tensorboard_dir:
        tensorboard_dir = re.sub('source/sourcedir.tar.gz', 'model', args.tensorboard_dir)
    else:
        tensorboard_dir = args.tensorboard_dir
    logging.info("Writing TensorBoard logs to {}".format(tensorboard_dir))
    logging.info("getting data")
    train_dataset = process_input(args.epochs, args.batch_size, args.train, 'train', args.data_config)
    eval_dataset = process_input(args.epochs, args.batch_size, args.eval, 'eval', args.data_config)
    validation_dataset = process_input(args.epochs, args.batch_size, args.validation, 'validation', args.data_config)
    logging.info("configuring model")
    logging.info("Hosts: "+ os.environ.get('SM_HOSTS'))
    # Number of hosts in the training cluster; used below to scale per-host
    # step counts.
    size = len(args.hosts)
    # NOTE(review): `size` is also passed to get_model -- presumably to scale
    # the optimizer for distributed training; confirm in model_def.get_model.
    model = get_model(args.learning_rate, args.weight_decay, args.optimizer, args.momentum, size)
    callbacks = []
    # Only the first host writes checkpoints and TensorBoard logs, so hosts
    # do not clobber each other's output files.
    if args.current_host == args.hosts[0]:
        callbacks.append(ModelCheckpoint(args.output_data_dir + '/checkpoint-{epoch}.h5'))
        callbacks.append(CustomTensorBoardCallback(log_dir=tensorboard_dir))
    logging.info("Starting training")
    # Steps per epoch are divided by the cluster size: each host processes
    # its share of every epoch.
    history = model.fit(x=train_dataset[0],
                        y=train_dataset[1],
                        steps_per_epoch=(num_examples_per_epoch('train') // args.batch_size) // size,
                        epochs=args.epochs,
                        validation_data=validation_dataset,
                        validation_steps=(num_examples_per_epoch('validation') // args.batch_size) // size, callbacks=callbacks)
    score = model.evaluate(eval_dataset[0],
                           eval_dataset[1],
                           steps=num_examples_per_epoch('eval') // args.batch_size,
                           verbose=0)
    logging.info('Test loss:{}'.format(score[0]))
    logging.info('Test accuracy:{}'.format(score[1]))
    # PS: Save model and history only on worker 0
    if args.current_host == args.hosts[0]:
        save_history(args.model_dir + "/ps_history.p", history)
        save_model(model, args.model_dir)
def num_examples_per_epoch(subset='train'):
    """Return the number of examples in the given dataset subset.

    Args:
        subset: one of 'train', 'validation', or 'eval'.

    Raises:
        ValueError: if *subset* is not a known subset name.
    """
    subset_sizes = {
        'train': 40000,
        'validation': 10000,
        'eval': 10000,
    }
    if subset not in subset_sizes:
        raise ValueError('Invalid data subset "%s"' % subset)
    return subset_sizes[subset]
if __name__ == '__main__':
    # SageMaker supplies cluster topology, data channel locations, and output
    # directories through SM_* environment variables; CLI flags override them
    # (useful for local runs).
    parser = argparse.ArgumentParser()
    # Bug fix: the original used ``type=list``, which splits a CLI-supplied
    # value into individual characters. SM_HOSTS is a JSON list, so parse any
    # CLI override the same way. The ``'[]'`` fallback keeps the parser from
    # crashing when the env var is absent outside SageMaker.
    parser.add_argument('--hosts', type=json.loads,
                        default=json.loads(os.environ.get('SM_HOSTS', '[]')))
    parser.add_argument('--current-host', type=str, default=os.environ.get('SM_CURRENT_HOST'))
    parser.add_argument('--train', type=str, required=False, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--validation', type=str, required=False, default=os.environ.get('SM_CHANNEL_VALIDATION'))
    parser.add_argument('--eval', type=str, required=False, default=os.environ.get('SM_CHANNEL_EVAL'))
    parser.add_argument('--model_dir', type=str, required=True,
                        help='The directory where the model will be stored.')
    parser.add_argument('--model_output_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--output_data_dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--output-dir', type=str, default=os.environ.get('SM_OUTPUT_DIR'))
    parser.add_argument('--tensorboard-dir', type=str, default=os.environ.get('SM_MODULE_DIR'))
    parser.add_argument('--weight-decay', type=float, default=2e-4,
                        help='Weight decay for convolutions.')
    parser.add_argument('--learning-rate', type=float, default=0.001,
                        help='Initial learning rate.')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--data-config', type=json.loads, default=os.environ.get('SM_INPUT_DATA_CONFIG'))
    parser.add_argument('--fw-params', type=json.loads, default=os.environ.get('SM_FRAMEWORK_PARAMS'))
    parser.add_argument('--optimizer', type=str, default='adam')
    # 0.9 as a real float: the original passed the string '0.9' and relied on
    # argparse applying ``type`` to string defaults.
    parser.add_argument('--momentum', type=float, default=0.9)

    args = parser.parse_args()
    main(args)