diff --git a/train.py b/train.py index 8aefeb8..c8765ad 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,10 @@ +import os import tensorflow as tf import driving_data import model +LOGDIR = './save' + sess = tf.InteractiveSession() loss = tf.reduce_mean(tf.square(tf.sub(model.y_, model.y))) @@ -18,5 +21,8 @@ print("step %d, train loss %g"%(i, loss.eval(feed_dict={ model.x:xs, model.y_: ys, model.keep_prob: 1.0}))) if i % 100 == 0: - save_path = saver.save(sess, "save/model.ckpt") - print("Model saved in file: %s" % save_path) + if not os.path.exists(LOGDIR): + os.makedirs(LOGDIR) + checkpoint_path = os.path.join(LOGDIR, "model.ckpt") + filename = saver.save(sess, checkpoint_path) + print("Model saved in file: %s" % filename)