Skip to content

Commit ecb1f3f

Browse files
committed
fix
1 parent 034e0de commit ecb1f3f

File tree

2 files changed

+46
-49
lines changed

2 files changed

+46
-49
lines changed

python/deep_model.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,10 @@ def concat(self, fields, sparse_ids, sparse_vals):
2222
mapping_ints = tf.constant([0])
2323
table = tf.contrib.lookup.index_table_from_tensor(mapping=mapping_ints, num_oov_buckets=100000, dtype=tf.int64)
2424
sparse_id_in_this_field = table.lookup(sparse_ids[i])
25-
embedding_variable = tf.Variable(tf.truncated_normal([100002, self.embedding_size], stddev=0.1))
26-
embedding = tf.nn.embedding_lookup_sparse(embedding_variable, sparse_id_in_this_field, sparse_vals[i], "mod", combiner="sum")
27-
emb.append(embedding)
25+
with tf.variable_scope("emb_"+str(field_id)):
26+
embedding_variable = tf.Variable(tf.truncated_normal([100002, self.embedding_size], stddev=0.1))
27+
embedding = tf.nn.embedding_lookup_sparse(embedding_variable, sparse_id_in_this_field, sparse_vals[i], "mod", combiner="sum")
28+
emb.append(embedding)
2829
self.embedding.append(embedding_variable)
2930

3031
return tf.concat(emb, 1, name='concat_embedding')
@@ -34,7 +35,7 @@ def forward(self, sparse_id, sparse_val):
3435
forward graph
3536
'''
3637

37-
with tf.variable_scope("forward", reuse=tf.AUTO_REUSE):
38+
with tf.variable_scope("forward"):
3839
self.embedding = []
3940
self.hiddenW = []
4041
self.hiddenB = []

python/train.py

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@
3838
if not os.path.exists(FLAGS.tensorboard_dir):
3939
os.makedirs(FLAGS.tensorboard_dir)
4040

41-
# train loop
42-
with tf.Graph().as_default():
41+
with tf.device('/cpu:0'):
4342
# data iter
4443
data = Data(FLAGS.sparse_fields)
4544
train_label, train_sparse_id, train_sparse_val = data.ReadBatch(FLAGS.train_file,
@@ -82,55 +81,52 @@
8281
print("Error: unknown optimizer: {}".format(FLAGS.optimizer))
8382
exit(1)
8483

85-
with tf.device("/cpu:0"):
86-
global_step = tf.Variable(0, name='global_step', trainable=False)
84+
global_step = tf.Variable(0, name='global_step', trainable=False)
8785
train_op = optimizer.minimize(cost, global_step=global_step)
8886

89-
# eval
87+
# to eval
9088
tf.get_variable_scope().reuse_variables()
9189

92-
# valid cross entropy loss
93-
#valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
94-
#valid_label = tf.to_int64(valid_label)
95-
#valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
96-
#valid_loss_op = tf.reduce_mean(valid_cross_entropy)
97-
98-
# valid auc
99-
#valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
90+
# valid metric
91+
valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
92+
valid_label = tf.to_int64(valid_label)
93+
valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
94+
valid_loss = tf.reduce_mean(valid_cross_entropy)
95+
valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
10096

10197
# saver
10298
checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
10399
saver = tf.train.Saver()
104100

105-
with tf.Session() as sess:
106-
sess.run(tf.initialize_all_variables())
107-
sess.run(tf.initialize_local_variables())
108-
sess.run(tf.tables_initializer())
109-
110-
if FLAGS.train_from_checkpoint:
111-
checkpoint_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
112-
if checkpoint_state and checkpoint_state.model_checkpoint_path:
113-
print("Continue training from checkpoint {}".format(checkpoint_state.model_checkpoint_path))
114-
saver.restore(sess, checkpoint_state.model_checkpoint_path)
115-
116-
coord = tf.train.Coordinator()
117-
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
118-
try:
119-
while not coord.should_stop():
120-
_, step, train_loss_val, train_auc_val = sess.run([train_op, global_step, loss, auc])
121-
#if step % FLAGS.steps_to_validate == 0:
122-
#valid_loss_val, valid_auc_val = sess.run([valid_loss_op, valid_auc])
123-
#print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format(
124-
# step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val))
125-
except tf.errors.OutOfRangeError:
126-
print("training done")
127-
finally:
128-
coord.request_stop()
129-
130-
saver.save(sess, checkpoint_file)
131-
tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.pb', as_text=False)
132-
tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.txt', as_text=True)
133-
134-
# wait for threads to exit
135-
coord.join(threads)
136-
sess.close()
101+
with tf.Session() as sess:
102+
sess.run(tf.global_variables_initializer())
103+
sess.run(tf.local_variables_initializer())
104+
sess.run(tf.tables_initializer())
105+
106+
if FLAGS.train_from_checkpoint:
107+
checkpoint_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
108+
if checkpoint_state and checkpoint_state.model_checkpoint_path:
109+
print("Continue training from checkpoint {}".format(checkpoint_state.model_checkpoint_path))
110+
saver.restore(sess, checkpoint_state.model_checkpoint_path)
111+
112+
coord = tf.train.Coordinator()
113+
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
114+
try:
115+
while not coord.should_stop():
116+
_, step, train_loss_val, train_auc_val = sess.run([train_op, global_step, loss, auc])
117+
if step % FLAGS.steps_to_validate == 0:
118+
valid_loss_val, valid_auc_val = sess.run([valid_loss, valid_auc])
119+
print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format(
120+
step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val))
121+
except tf.errors.OutOfRangeError:
122+
print("training done")
123+
finally:
124+
coord.request_stop()
125+
126+
saver.save(sess, checkpoint_file)
127+
tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.pb', as_text=False)
128+
tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.txt', as_text=True)
129+
130+
# wait for threads to exit
131+
coord.join(threads)
132+
sess.close()

0 commit comments

Comments
 (0)