 if not os.path.exists(FLAGS.tensorboard_dir):
     os.makedirs(FLAGS.tensorboard_dir)

-# train loop
-with tf.Graph().as_default():
+with tf.device('/cpu:0'):
     # data iter
     data = Data(FLAGS.sparse_fields)
     train_label, train_sparse_id, train_sparse_val = data.ReadBatch(FLAGS.train_file,
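One behavioral note on this hunk: `tf.Graph().as_default()` created a fresh default graph for the whole build, while `tf.device('/cpu:0')` only pins op placement to the CPU, so construction now happens in the process-wide default graph. If both behaviors are wanted, the two context managers can be nested; a minimal sketch, not what this commit does:

    import tensorflow as tf

    # Fresh default graph plus CPU placement for everything built inside.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.Variable(0, name='global_step', trainable=False)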
         print("Error: unknown optimizer: {}".format(FLAGS.optimizer))
         exit(1)

-    with tf.device("/cpu:0"):
-        global_step = tf.Variable(0, name='global_step', trainable=False)
+    global_step = tf.Variable(0, name='global_step', trainable=False)
     train_op = optimizer.minimize(cost, global_step=global_step)

-    # eval
+    # reuse variables for eval
     tf.get_variable_scope().reuse_variables()

-    # valid cross entropy loss
-    #valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
-    #valid_label = tf.to_int64(valid_label)
-    #valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
-    #valid_loss_op = tf.reduce_mean(valid_cross_entropy)
-
-    # valid auc
-    #valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label)
+    # valid metric
+    valid_logits, _ = model.forward(valid_sparse_id, valid_sparse_val)
+    valid_label = tf.to_int64(valid_label)
+    valid_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=valid_logits, labels=valid_label)
+    valid_loss = tf.reduce_mean(valid_cross_entropy)
+    valid_auc, _ = tf.metrics.auc(predictions=valid_logits, labels=valid_label)

     # saver
     checkpoint_file = FLAGS.checkpoint_dir + "/model.checkpoint"
     saver = tf.train.Saver()

-    with tf.Session() as sess:
-        sess.run(tf.initialize_all_variables())
-        sess.run(tf.initialize_local_variables())
-        sess.run(tf.tables_initializer())
-
-        if FLAGS.train_from_checkpoint:
-            checkpoint_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-            if checkpoint_state and checkpoint_state.model_checkpoint_path:
-                print("Continue training from checkpoint {}".format(checkpoint_state.model_checkpoint_path))
-                saver.restore(sess, checkpoint_state.model_checkpoint_path)
-
-        coord = tf.train.Coordinator()
-        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
-        try:
-            while not coord.should_stop():
-                _, step, train_loss_val, train_auc_val = sess.run([train_op, global_step, loss, auc])
-                #if step % FLAGS.steps_to_validate == 0:
-                #    valid_loss_val, valid_auc_val = sess.run([valid_loss_op, valid_auc])
-                #    print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format(
-                #        step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val))
-        except tf.errors.OutOfRangeError:
-            print("training done")
-        finally:
-            coord.request_stop()
-
-        saver.save(sess, checkpoint_file)
-        tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.pb', as_text=False)
-        tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.txt', as_text=True)
-
-        # wait for threads to exit
-        coord.join(threads)
-        sess.close()
+with tf.Session() as sess:
+    sess.run(tf.global_variables_initializer())
+    sess.run(tf.local_variables_initializer())
+    sess.run(tf.tables_initializer())
+
+    if FLAGS.train_from_checkpoint:
+        checkpoint_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
+        if checkpoint_state and checkpoint_state.model_checkpoint_path:
+            print("Continue training from checkpoint {}".format(checkpoint_state.model_checkpoint_path))
+            saver.restore(sess, checkpoint_state.model_checkpoint_path)
+
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
+    try:
+        while not coord.should_stop():
+            _, step, train_loss_val, train_auc_val = sess.run([train_op, global_step, loss, auc])
+            if step % FLAGS.steps_to_validate == 0:
+                valid_loss_val, valid_auc_val = sess.run([valid_loss, valid_auc])
+                print("Step: {}, train loss: {}, train auc: {}, valid loss: {}, valid auc: {}".format(
+                    step, train_loss_val, train_auc_val, valid_loss_val, valid_auc_val))
+    except tf.errors.OutOfRangeError:
+        print("training done")
+    finally:
+        coord.request_stop()
+
+    saver.save(sess, checkpoint_file)
+    tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.pb', as_text=False)
+    tf.train.write_graph(sess.graph.as_graph_def(), FLAGS.model_dir, 'graph.txt', as_text=True)
+
+    # wait for threads to exit
+    coord.join(threads)
+    sess.close()
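Two caveats on the new validation metrics, assuming the two-class softmax output implied by `sparse_softmax_cross_entropy_with_logits`: `tf.metrics.auc` expects prediction scores in [0, 1], not raw logits, and it is a streaming metric, so the update op it returns (discarded as `_` above) must be run for the value to move off 0. A sketch of one way to wire it, with hypothetical placeholders standing in for the tensors built in the graph:

    import tensorflow as tf

    # Hypothetical stand-ins for valid_logits / valid_label from the diff.
    valid_logits = tf.placeholder(tf.float32, shape=[None, 2])
    valid_label = tf.placeholder(tf.int64, shape=[None])

    # Convert logits to the positive-class probability before the AUC op.
    valid_probs = tf.nn.softmax(valid_logits)[:, 1]

    # Keep the update op and run it alongside reading the metric value.
    valid_auc, valid_auc_update = tf.metrics.auc(labels=valid_label,
                                                 predictions=valid_probs)

In the training loop that would mean running something like `sess.run([valid_auc_update, valid_auc])` at the validation steps. On a smaller point, the trailing `sess.close()` is redundant: the `with tf.Session()` block already closes the session on exit.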