Skip to content

Commit 3f24230

Browse files
committed
auc
1 parent 776ec48 commit 3f24230

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

python/train.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
train_loss_op = tf.reduce_mean(train_cross_entropy)
101101

102102
# train auc
103-
train_auc_op = tf.metrics.auc(predictions=train_logits, labels=train_label)
103+
train_auc, _ = tf.metrics.auc(predictions=train_logits, labels=train_label)
104104

105105
# valid cross entropy loss
106106
valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
@@ -109,16 +109,15 @@
109109
valid_loss_op = tf.reduce_mean(valid_cross_entropy)
110110

111111
# valid auc
112-
valid_auc_op = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
112+
valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
113113

114114
# saver
115115
checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
116116
saver = tf.train.Saver()
117117

118118
# train loop
119119
with tf.Session() as sess:
120-
init_op = tf.initialize_all_variables()
121-
sess.run(init_op)
120+
sess.run(tf.initialize_all_variables())
122121
sess.run(tf.initialize_local_variables())
123122

124123
if FLAGS.train_from_checkpoint:
@@ -133,9 +132,9 @@
133132
while not coord.should_stop():
134133
_, step = sess.run([train_op, global_step])
135134
if step % FLAGS.steps_to_validate == 0:
136-
train_loss, valid_loss = sess.run([train_loss_op, valid_loss_op])
137-
print("Step: {}, train loss: {}, valid loss: {}".format(
138-
step, train_loss, valid_loss))
135+
train_loss_val, train_auc_val, valid_loss_val, valid_auc_val = sess.run([train_loss_op, train_auc, valid_loss_op, valid_auc])
136+
print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format(
137+
step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val))
139138
except tf.errors.OutOfRangeError:
140139
print("training done")
141140
finally:

0 commit comments

Comments
 (0)