 else:
     print("Error: unknown optimizer: {}".format(FLAGS.optimizer))
     exit(1)
+
 with tf.device("/cpu:0"):
     global_step = tf.Variable(0, name='global_step', trainable=False)
     train_op = optimizer.minimize(cost, global_step=global_step)

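The optimizer selection that precedes this else branch is not part of the hunk; a minimal sketch of what it presumably looks like, mapping the FLAGS.optimizer string to a tf.train optimizer (flag names other than FLAGS.optimizer, e.g. FLAGS.learning_rate, are assumptions, not taken from this commit):

# Hypothetical optimizer dispatch; only the else branch below appears in the diff itself.
if FLAGS.optimizer == "sgd":
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
elif FLAGS.optimizer == "adagrad":
    optimizer = tf.train.AdagradOptimizer(FLAGS.learning_rate)
elif FLAGS.optimizer == "adam":
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
else:
    print("Error: unknown optimizer: {}".format(FLAGS.optimizer))
    exit(1)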
-# eval acc
+# eval
+label_num = 2
 tf.get_variable_scope().reuse_variables()
+
+# train cross entropy loss
+train_logits, _ = model.forward(train_sparse_id, train_sparse_val)
+train_label = tf.to_int64(train_label)
+train_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_logits, labels=train_label)
+train_loss_op = tf.reduce_mean(train_cross_entropy)
+
+# train auc
+train_auc_op = tf.metrics.auc(predictions=train_logits, labels=train_label)
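One caveat with the new train AUC op: tf.metrics.auc expects predictions in the range [0, 1], while train_logits are raw, unbounded logits. A sketch of feeding it the positive-class probability instead, assuming the two-column logits implied by label_num = 2 (a suggested variant, not what the commit does):

# Squash logits to probabilities and keep the probability of class 1 for the AUC metric.
train_prob = tf.nn.softmax(train_logits)        # shape [batch_size, 2]
train_auc_op = tf.metrics.auc(labels=train_label, predictions=train_prob[:, 1])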
+
+# valid cross entropy loss
 valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
-valid_softmax = tf.nn.softmax(valid_logits)
 valid_label = tf.to_int64(valid_label)
-correct_prediction = tf.equal(tf.argmax(valid_softmax, 1), valid_label)
-accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
+valid_loss_op = tf.reduce_mean(valid_cross_entropy)

-# eval auc
-auc = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
+# valid auc
+valid_auc_op = tf.metrics.auc(predictions=valid_logits, labels=valid_label)

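tf.metrics.auc is a streaming metric: it returns a (value, update_op) pair backed by local variables, so nothing accumulates unless the update op is actually run. A sketch of how the validation AUC defined above could be read out inside the session (the session calls are illustrative; only valid_auc_op comes from this file):

# Unpack the (value, update_op) pair returned by tf.metrics.auc.
valid_auc_value, valid_auc_update = valid_auc_op
# Later, inside the tf.Session (after local variables are initialized):
#     sess.run(valid_auc_update)                                # accumulate the current batch
#     print("valid auc: {}".format(sess.run(valid_auc_value)))  # streaming AUC so far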
 # saver
 checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
 saver = tf.train.Saver()

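The Saver is constructed here but never called in the hunks shown; a sketch of how the checkpoint might be written during training and restored on restart (saver and checkpoint_file are from this file, the save-interval flag is hypothetical):

# Inside the training loop: periodically write a checkpoint, suffixed with the global step.
if step % FLAGS.steps_to_save == 0:        # FLAGS.steps_to_save is an assumed flag
    saver.save(sess, checkpoint_file, global_step=global_step)

# Before training: resume from the newest checkpoint if one exists.
latest = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest:
    saver.restore(sess, latest)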
-# summary
-#tf.scalar_summary('loss', loss)
-#tf.scalar_summary('accuracy', accuracy)
-#summary_op = tf.merge_all_summaries()
-
 # train loop
 with tf.Session() as sess:
     init_op = tf.initialize_all_variables()
-    #writer = tf.train.SummaryWriter(FLAGS.tensorboard_dir, sess.graph)
     sess.run(init_op)
     sess.run(tf.initialize_local_variables())

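tf.initialize_all_variables and tf.initialize_local_variables are the pre-1.0 names and are deprecated in TF 1.x; the local-variable init is what creates the counters behind tf.metrics.auc, so both calls are still needed. The non-deprecated equivalent would be:

# Drop-in replacement for the deprecated initializers above.
sess.run(tf.global_variables_initializer())   # model weights, global_step
sess.run(tf.local_variables_initializer())    # tf.metrics.* accumulators (e.g. the AUC ops)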
@@ ... @@
     threads = tf.train.start_queue_runners(coord=coord, sess=sess)
     try:
         while not coord.should_stop():
-            _, loss_value, step = sess.run([train_op, loss, global_step])
+            _, step = sess.run([train_op, global_step])
             if step % FLAGS.steps_to_validate == 0:
-                auc_value = sess.run([auc])
-                print("Step: {}, loss: {}, auc: {}".format(
-                    step, loss_value, auc_value))
-                #writer.add_summary(summary_value, step)
+                train_loss, valid_loss = sess.run([train_loss_op, valid_loss_op])
+                print("Step: {}, train loss: {}, valid loss: {}".format(
+                    step, train_loss, valid_loss))
     except tf.errors.OutOfRangeError:
         print("training done")
     finally:
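The diff cuts off inside the finally block, so its body is not shown. With the coordinator and queue-runner threads started above, the conventional cleanup would be (a sketch of the standard TF 1.x pattern, not taken from this commit):

        # Ask the input queue-runner threads to stop, then wait for them to exit.
        coord.request_stop()
        coord.join(threads)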