|
100 | 100 | train_loss_op = tf.reduce_mean(train_cross_entropy)
|
101 | 101 |
|
102 | 102 | # train auc
|
103 |
| -train_auc_op = tf.metrics.auc(predictions=train_logits, labels=train_label) |
| 103 | +train_auc, _ = tf.metrics.auc(predictions=train_logits, labels=train_label) |
104 | 104 |
|
105 | 105 | # valid cross entropy loss
|
106 | 106 | valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
|
|
109 | 109 | valid_loss_op = tf.reduce_mean(valid_cross_entropy)
|
110 | 110 |
|
111 | 111 | # valid auc
|
112 |
| -valid_auc_op = tf.metrics.auc(predictions=valid_logits, labels=valid_label) |
| 112 | +valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label) |
113 | 113 |
|
114 | 114 | # saver
|
115 | 115 | checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
|
116 | 116 | saver = tf.train.Saver()
|
117 | 117 |
|
118 | 118 | # train loop
|
119 | 119 | with tf.Session() as sess:
|
120 |
| - init_op = tf.initialize_all_variables() |
121 |
| - sess.run(init_op) |
| 120 | + sess.run(tf.initialize_all_variables()) |
122 | 121 | sess.run(tf.initialize_local_variables())
|
123 | 122 |
|
124 | 123 | if FLAGS.train_from_checkpoint:
|
|
133 | 132 | while not coord.should_stop():
|
134 | 133 | _, step = sess.run([train_op, global_step])
|
135 | 134 | if step % FLAGS.steps_to_validate == 0:
|
136 |
| - train_loss, valid_loss = sess.run([train_loss_op, valid_loss_op]) |
137 |
| - print("Step: {}, train loss: {}, valid loss: {}".format( |
138 |
| - step, train_loss, valid_loss)) |
| 135 | + train_loss_val, train_auc_val, valid_loss_val, valid_auc_val = sess.run([train_loss_op, train_auc, valid_loss_op, valid_auc]) |
| 136 | + print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format( |
| 137 | + step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val)) |
139 | 138 | except tf.errors.OutOfRangeError:
|
140 | 139 | print("training done")
|
141 | 140 | finally:
|
|
0 commit comments