forked from yzzueong/Amazon-DIN-TFrecord-estimator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
66 lines (57 loc) · 2.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
import os

import tensorflow.compat.v1 as tf

import model
from feature import *
import config
import input_data
# Command-line flag values shared across this module (flags are defined in config.py).
FLAGS = config.FLAGS
def main(unused_argv):
    """Build the estimator, train/evaluate it, and (on the chief) export a SavedModel.

    Args:
        unused_argv: Positional CLI arguments forwarded by ``tf.app.run`` (ignored).
    """
    # Feature columns / TFRecord parsing spec built by the project-local feature module.
    feature_configs = FeatureConfig().create_features_columns()
    classifier = tf.estimator.Estimator(
        model_fn=model.build_base_model,  # swap for model.build_DIN_model to train DIN
        config=tf.estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            keep_checkpoint_max=3),
        params={
            "feature_configs": feature_configs,
            "hidden_units": FLAGS.hidden_units.split(","),
            "learning_rate": FLAGS.learning_rate,
            "attention_hidden_units": FLAGS.attention_hidden_units.split(","),
            "dropout_rate": FLAGS.dropout_rate,
        },
    )

    def train_eval_model():
        """Run train_and_evaluate over the TFRecord input pipelines."""
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_data.train_input_fn(FLAGS.train_record_dir, FLAGS.batch_size),
            max_steps=FLAGS.train_steps)
        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_data.eval_input_fn(FLAGS.test_record_dir, FLAGS.batch_size),
            start_delay_secs=60,
            throttle_secs=30,
            steps=200)
        tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

    def export_model():
        """Export a SavedModel with raw-tensor placeholders for serving."""
        feature_spec = feature_configs.feature_spec
        feature_map = {}
        # NOTE(review): the original code filtered keys with
        # `if key not in fe.feature_configs`, but `fe` is never defined in this
        # file (NameError at runtime). All keys from feature_spec are exported
        # here; restore the filter against the correct object if one was intended.
        for key, feature in feature_spec.items():
            if isinstance(feature, tf.io.VarLenFeature):  # variable-length feature
                feature_map[key] = tf.placeholder(dtype=feature.dtype, shape=[1], name=key)
            elif isinstance(feature, tf.io.FixedLenFeature):  # fixed-length feature
                # assumes feature.shape has at least one dimension — TODO confirm
                feature_map[key] = tf.placeholder(
                    dtype=feature.dtype, shape=[None, feature.shape[0]], name=key)
        serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map)
        classifier.export_saved_model(FLAGS.output_model, serving_input_receiver_fn)

    # Train and periodically evaluate the model.
    train_eval_model()

    # Export the model; on a cluster only the chief exports, and only once.
    if FLAGS.run_on_cluster:
        # BUG FIX: `task_type` was referenced but never defined; derive it from
        # the standard TF_CONFIG environment variable used by distributed TF.
        task_type = json.loads(os.environ.get("TF_CONFIG", "{}")).get("task", {}).get("type", "")
        if task_type == "chief":
            export_model()
    #export_model()
if __name__ == "__main__":
    # Silence TensorFlow C++ logging below FATAL before any TF work starts.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    # Python-side TF logging stays at INFO so training progress is visible.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)