Skip to content

Commit eb49da2

Browse files
committed
fix
1 parent 15aad77 commit eb49da2

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

python/train.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -85,34 +85,40 @@
8585
else:
8686
print("Error: unknown optimizer: {}".format(FLAGS.optimizer))
8787
exit(1)
88+
8889
with tf.device("/cpu:0"):
8990
global_step = tf.Variable(0, name='global_step', trainable=False)
9091
train_op = optimizer.minimize(cost, global_step=global_step)
9192

92-
# eval acc
93+
# eval
94+
label_num = 2
9395
tf.get_variable_scope().reuse_variables()
96+
97+
# train cross entropy loss
98+
train_logits, _ = model.forward(train_sparse_id, train_sparse_val)
99+
train_label = tf.to_int64(train_label)
100+
train_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_logits, labels=train_label)
101+
train_loss_op = tf.reduce_mean(train_cross_entropy)
102+
103+
# train auc
104+
train_auc_op = tf.metrics.auc(predictions=train_logits, labels=train_label)
105+
106+
# valid cross entropy loss
94107
valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
95-
valid_softmax = tf.nn.softmax(valid_logits)
96108
valid_label = tf.to_int64(valid_label)
97-
correct_prediction = tf.equal(tf.argmax(valid_softmax, 1), valid_label)
98-
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
109+
valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
110+
valid_loss_op = tf.reduce_mean(valid_cross_entropy)
99111

100-
# eval auc
101-
auc = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
112+
# valid auc
113+
valid_auc_op = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
102114

103115
# saver
104116
checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
105117
saver = tf.train.Saver()
106118

107-
# summary
108-
#tf.scalar_summary('loss', loss)
109-
#tf.scalar_summary('accuracy', accuracy)
110-
#summary_op = tf.merge_all_summaries()
111-
112119
# train loop
113120
with tf.Session() as sess:
114121
init_op = tf.initialize_all_variables()
115-
#writer = tf.train.SummaryWriter(FLAGS.tensorboard_dir, sess.graph)
116122
sess.run(init_op)
117123
sess.run(tf.initialize_local_variables())
118124

@@ -126,12 +132,11 @@
126132
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
127133
try:
128134
while not coord.should_stop():
129-
_, loss_value, step = sess.run([train_op, loss, global_step])
135+
_, step = sess.run([train_op, global_step])
130136
if step % FLAGS.steps_to_validate == 0:
131-
auc_value = sess.run([auc])
132-
print("Step: {}, loss: {}, auc: {}".format(
133-
step, loss_value, auc_value))
134-
#writer.add_summary(summary_value, step)
137+
train_loss, valid_loss = sess.run([train_loss_op, valid_loss_op])
138+
print("Step: {}, train loss: {}, valid loss: {}".format(
139+
step, train_loss, valid_loss))
135140
except tf.errors.OutOfRangeError:
136141
print("training done")
137142
finally:

0 commit comments

Comments
 (0)