# 畳み込みフィルターを用いた手書き文字の分類

## セッション情報の保存機能
TensorFlowでは、tf.train.Saver を用いることで、トレーニング処理を実施中のセッションの状態（変数の値）をファイルに保存し、後から復元することができる。

## 単層CNNによる手書き文字の分類

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Seed NumPy's global RNG and TensorFlow's graph-level RNG with one shared value
# so that repeated runs of this notebook are reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

Instructions for updating:
Use the retry module or similar alternatives.


In [2]:
# Download MNIST into MNIST_DIR (if not already there) and load it with
# one-hot encoded labels, as expected by the 10-wide label placeholder.
MNIST_DIR = "/tmp/data/"
mnist = input_data.read_data_sets(MNIST_DIR, one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [4]:
# Single convolution + max-pooling layer: each 28x28 input image becomes
# num_filters feature maps of size 14x14.
num_filters = 16  # number of convolution filters (= output channels)

x = tf.placeholder(tf.float32, [None, 784])  # flattened 28*28 images
x_image = tf.reshape(x, [-1, 28, 28, 1])     # batch x 28 x 28 x 1 channel

W_conv = tf.Variable(tf.truncated_normal([5,5,1,num_filters], stddev = 0.1))# filters are trainable variables, learned during training
# Shape: filter height x filter width x input channels x output channels.
# stddev=0.1 is the standard deviation of the initial values; truncated_normal
# re-draws samples more than 2 stddev from the mean, so values lie within +-0.2.

h_conv = tf.nn.conv2d(x_image, W_conv, strides = [1,1,1,1], padding='SAME')
# Stride 1 with 'SAME' padding keeps the 28x28 spatial size.

h_pool = tf.nn.max_pool(h_conv, ksize = [1,2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# 2x2 max-pooling with stride 2 halves each spatial dimension: 28x28 -> 14x14.
# The goal here is general feature extraction, not edge detection, so no
# absolute value is taken of the convolution output.
# Values may therefore be negative, but they are still meaningful as extracted
# image features.

In [5]:
# Fully connected part of the network: flatten the pooled feature maps, pass
# them through a 1024-unit ReLU hidden layer, then a softmax over 10 classes.

# Width of the flattened feature vector, taken from h_pool's static shape
# (14*14*num_filters for 28x28 inputs with 2x2 pooling) instead of being
# hard-coded twice, so it stays in sync with the conv/pool settings.
# This is a static-shape query only and creates no graph ops.
num_units1 = h_pool.get_shape()[1:].num_elements()
num_units2 = 1024  # hidden-layer width

h_pool_flat = tf.reshape(h_pool, [-1, num_units1])

# NOTE(review): no stddev is passed, so truncated_normal uses its default
# (1.0), much larger than the 0.1 used for W_conv. Kept as-is so training
# results stay comparable; consider a smaller stddev if training is unstable.
w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
b2 = tf.Variable(tf.zeros([num_units2]))
hidden2 = tf.nn.relu(tf.matmul(h_pool_flat, w2) + b2)

# Output layer (initialised to zeros) producing class probabilities p.
w0 = tf.Variable(tf.zeros([num_units2, 10]))
b0 = tf.Variable(tf.zeros([10]))
p = tf.nn.softmax(tf.matmul(hidden2, w0) + b0)

In [6]:
# Training objective and evaluation metrics.
t = tf.placeholder(tf.float32, [None, 10])  # one-hot ground-truth labels

# Cross-entropy summed over every example fed in. p is clipped away from 0
# before the log: if any softmax output underflows to 0, tf.log returns -inf,
# and both the true-class term and 0 * -inf turn the loss (and its gradients)
# into NaN, silently wrecking training.
loss = -tf.reduce_sum(t * tf.log(tf.clip_by_value(p, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(0.0005).minimize(loss)

# A prediction is correct when the most probable class equals the label.
correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [8]:
# Open an interactive session and initialise every variable defined so far.
sess = tf.InteractiveSession()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Saver used by the training loop to checkpoint the session state to files.
# With default settings it keeps only the most recent checkpoints (the later
# `ls` of mdc_session* shows just steps 3600-4000).
# A saved state can be restored with e.g. saver.restore(sess, './mdc_session-4000').
saver = tf.train.Saver()

In [13]:
# Train with mini-batches from the training set. Every eval_interval steps,
# evaluate on the full test set, print the results and checkpoint the session
# as ./mdc_session-<step>.
# Note: `loss` is a reduce_sum, so the printed value is summed over all test
# images (not averaged) and scales with the test-set size.
num_steps = 4000
batch_size = 100
eval_interval = 100
test_feed = {x: mnist.test.images, t: mnist.test.labels}  # loop-invariant

for i in range(1, num_steps + 1):
    batch_xs, batch_ts = mnist.train.next_batch(batch_size)
    sess.run(train_step, feed_dict={x: batch_xs, t: batch_ts})
    if i % eval_interval == 0:
        loss_val, acc_val = sess.run([loss, accuracy], feed_dict=test_feed)
        print('Step: %d, Loss: %f, Accuracy: %f' % (i, loss_val, acc_val))
        saver.save(sess, './mdc_session', global_step=i)

Step: 100, Loss: 1246.051880, Accuracy: 0.963100
Step: 200, Loss: 1056.027588, Accuracy: 0.969600
Step: 300, Loss: 893.175171, Accuracy: 0.972400
Step: 400, Loss: 1024.038330, Accuracy: 0.970200
Step: 500, Loss: 899.188599, Accuracy: 0.972900
Step: 600, Loss: 908.228394, Accuracy: 0.972500
Step: 700, Loss: 735.593567, Accuracy: 0.978000
Step: 800, Loss: 1042.673950, Accuracy: 0.967700
Step: 900, Loss: 777.806580, Accuracy: 0.977600
Step: 1000, Loss: 829.814941, Accuracy: 0.975900
Step: 1100, Loss: 811.404541, Accuracy: 0.976200
Step: 1200, Loss: 765.644409, Accuracy: 0.978200
Step: 1300, Loss: 773.578979, Accuracy: 0.977200
Step: 1400, Loss: 690.804932, Accuracy: 0.980000
Step: 1500, Loss: 677.963501, Accuracy: 0.978000
Step: 1600, Loss: 738.879333, Accuracy: 0.978000
Step: 1700, Loss: 669.488770, Accuracy: 0.979500
Step: 1800, Loss: 828.950073, Accuracy: 0.975400
Step: 1900, Loss: 767.830078, Accuracy: 0.977200
Step: 2000, Loss: 699.538208, Accuracy: 0.978900
Step: 2100, Loss: 655.494

In [14]:
# List the checkpoint files written by saver.save(): each checkpoint step has
# .data-*, .index and .meta files.
!ls mdc_session*

mdc_session-3600.data-00000-of-00001 mdc_session-3800.meta
mdc_session-3600.index               mdc_session-3900.data-00000-of-00001
mdc_session-3600.meta                mdc_session-3900.index
mdc_session-3700.data-00000-of-00001 mdc_session-3900.meta
mdc_session-3700.index               mdc_session-4000.data-00000-of-00001
mdc_session-3700.meta                mdc_session-4000.index
mdc_session-3800.data-00000-of-00001 mdc_session-4000.meta
mdc_session-3800.index
