forked from ganyc717/LeNet
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Train.py
59 lines (51 loc) · 1.98 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
import config as cfg
import os
import lenet
from lenet import Lenet
import numpy as np
import matplotlib.pyplot as plt
def _checkpoint_exists(parameter_path):
    """Return True if a checkpoint has previously been saved under *parameter_path*.

    ``tf.train.Saver.save`` (V2 format) writes ``<prefix>.index`` and
    ``<prefix>.data-*`` files rather than a file named exactly ``<prefix>``,
    so checking only the bare prefix would almost never find a saved model.
    The bare prefix is still accepted to cover the legacy single-file format.
    """
    return (os.path.exists(parameter_path + ".index")
            or os.path.exists(parameter_path))


def _plot_accuracy(steps, train_acc, test_acc):
    """Plot training and test accuracy against the training step and show the figure."""
    plt.plot(steps, train_acc, label='training accuracy')
    plt.plot(steps, test_acc, label='test accuracy')
    plt.title('LeNet')
    plt.xlabel('step')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()


def main():
    """Train LeNet on MNIST, checkpoint the weights and plot accuracy curves.

    Parameters are restored from ``cfg.PARAMETER_FILE`` when a checkpoint is
    found there; otherwise all variables are freshly initialised.

    Every ``eval_interval`` steps (including step 0), the function records:
      - training accuracy and loss on the current mini-batch;
      - accuracy on the full MNIST test set.

    After training, the weights are saved back to ``cfg.PARAMETER_FILE`` and
    the train/test accuracy curves are plotted.
    """
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    batch_size = cfg.BATCH_SIZE
    parameter_path = cfg.PARAMETER_FILE
    # Named `model` so it does not shadow the imported `lenet` module.
    model = Lenet()
    # The Saver must be built after Lenet() so it captures the graph's variables.
    saver = tf.train.Saver()

    num_evals = 10
    eval_interval = 100
    total_steps = num_evals * eval_interval
    # Evaluation happens at step 0 and after every eval_interval steps up to
    # total_steps inclusive, i.e. num_evals + 1 points; the arrays are sized to
    # match so no evaluation overwrites another.
    result_step = np.arange(0, total_steps + 1, eval_interval)
    result_acc = np.zeros(num_evals + 1)
    result_loss = np.zeros(num_evals + 1)
    result_test = np.zeros(num_evals + 1)

    with tf.Session() as sess:
        if _checkpoint_exists(parameter_path):
            saver.restore(sess, parameter_path)
        else:
            sess.run(tf.global_variables_initializer())

        for i in range(total_steps + 1):
            images, labels = mnist.train.next_batch(batch_size)
            feed = {model.raw_input_image: images, model.raw_input_label: labels}
            if i % eval_interval == 0:
                r = i // eval_interval
                # Fetch accuracy and loss in a single run on the same mini-batch.
                result_acc[r], result_loss[r] = sess.run(
                    [model.train_accuracy, model.loss], feed_dict=feed)
                print("step %d, training accuracy %g, training loss %g" % (i, result_acc[r], result_loss[r]))
                # The graph's accuracy op (named train_accuracy) evaluated on the test set.
                result_test[r] = sess.run(model.train_accuracy, feed_dict={
                    model.raw_input_image: mnist.test.images,
                    model.raw_input_label: mnist.test.labels,
                })
                print("test accuracy %g" % (result_test[r]))
            sess.run(model.train_op, feed_dict=feed)

        save_path = saver.save(sess, parameter_path)
        print("model saved to %s" % save_path)

    _plot_accuracy(result_step, result_acc, result_test)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()