In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
slim = tf.contrib.slim

In [None]:
# Character set: the ten digit characters '0' through '9'
CHAR_SET = list(map(str, range(10)))
CHAR_SET_LEN = len(CHAR_SET)
# Number of examples in the training set
TRAIN_NUM = 4000
# Examples per mini-batch
BATCH_SIZE = 100
# Number of training iterations
TOTAL_STEPS = 4000
# Path of the TFRecord file holding the training data
TFRECORD_FILE = 'captcha/train.tfrecord'
# Initial learning rate
LEARNING_RATE = 1e-3

# TFRecordDataset

In [None]:
def read_and_decode(serial_exmp):
    """Parse one serialized example into a normalized image and four labels.

    Args:
        serial_exmp: a scalar string tensor holding one serialized example
            with an ``image`` bytes feature and ``label0`` .. ``label3``
            int64 features.

    Returns:
        A 5-tuple ``(image, label0, label1, label2, label3)``: ``image`` is a
        float32 tensor reshaped to [224, 224] and scaled to the range
        [-1, 1]; each label is cast to int32.
    """
    label_keys = ['label0', 'label1', 'label2', 'label3']

    # One raw-bytes image feature plus one scalar int64 feature per label.
    feature_spec = {'image': tf.FixedLenFeature([], tf.string)}
    for key in label_keys:
        feature_spec[key] = tf.FixedLenFeature([], tf.int64)
    parsed = tf.parse_single_example(serial_exmp, features=feature_spec)

    # Raw bytes -> uint8 pixels -> 224x224 float image.
    pixels = tf.decode_raw(parsed['image'], tf.uint8)
    pixels = tf.reshape(pixels, [224, 224])
    pixels = tf.cast(pixels, tf.float32) / 255.0    # 0 to 1
    pixels = tf.subtract(pixels, 0.5)               # -0.5 to 0.5
    pixels = tf.multiply(pixels, 2.0)               # -1 to 1

    labels = tuple(tf.cast(parsed[key], tf.int32) for key in label_keys)
    return (pixels,) + labels

# Input pipeline: read records, decode them, then shuffle / batch / repeat.
dataset = (
    tf.data.TFRecordDataset(TFRECORD_FILE)
    # After this map each element is (image, label0, label1, label2, label3)
    .map(read_and_decode)
    .shuffle(buffer_size=2000)
    .batch(BATCH_SIZE)
    .repeat()
)

# 定義神經網路

In [None]:
def alexnet_v2_captcha_multi(inputs,
                             num_classes=10,
                             is_training=True,
                             dropout_keep_prob=0.5,
                             spatial_squeeze=True,
                             scope_name='alexnet_v2_captcha_multi',
                             global_pool=False):
    '''
    AlexNet-v2 style network with four parallel classification heads.

    Adapted from the TensorFlow (slim) GitHub source code and changed to
    multi-task learning: a shared convolutional trunk (conv1 .. fc7) feeds
    four independent 1x1 convolution output layers ``fc8_0`` .. ``fc8_3``.

    Args:
        inputs: 4-D image tensor fed to the first convolution.
        num_classes: number of output units of every head. If 0 or None the
            heads are omitted and the shared features after ``dropout7`` are
            returned (unsqueezed) in each of the four output positions.
        is_training: forwarded to both dropout layers.
        dropout_keep_prob: keep probability of both dropout layers.
        spatial_squeeze: if True, squeeze axes 1 and 2 of every head output,
            i.e. [batch, 1, 1, num_classes] -> [batch, num_classes].
        scope_name: name of the enclosing variable scope.
        global_pool: accepted for signature compatibility; currently not
            used by this implementation.

    Returns:
        A tuple ``(net0, net1, net2, net3)`` with one tensor per head.
    '''
    head_count = 4

    with tf.variable_scope(scope_name):
        # Shared convolutional trunk.
        net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
        net = slim.conv2d(net, 192, [5, 5], scope='conv2')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
        net = slim.conv2d(net, 384, [3, 3], scope='conv3')
        net = slim.conv2d(net, 384, [3, 3], scope='conv4')
        net = slim.conv2d(net, 256, [3, 3], scope='conv5')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')

        # Fully connected layers expressed as convolutions.
        with slim.arg_scope([slim.conv2d],
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.005),
                            weights_regularizer=slim.l2_regularizer(0.0005),
                            biases_initializer=tf.constant_initializer(0.1)):
            net = slim.conv2d(net, 4096, [5, 5], padding='VALID', scope='fc6')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
            net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout7')

            if not num_classes:
                # Without heads there are no logits to squeeze or return.
                # Hand back the shared features instead of failing on
                # unassigned head variables.
                return net, net, net, net

            # One linear (no activation, no normalizer) 1x1 conv per head.
            heads = []
            for idx in range(head_count):
                heads.append(slim.conv2d(net, num_classes, [1, 1],
                                         activation_fn=None,
                                         normalizer_fn=None,
                                         biases_initializer=tf.zeros_initializer(),
                                         scope='fc8_%d' % idx))

            # Collapse 4-D to 2-D: [batch, 1, 1, 10] -> [batch, 10]
            if spatial_squeeze:
                heads = [tf.squeeze(head, [1, 2], name='fc8_%d/squeezed' % idx)
                         for idx, head in enumerate(heads)]

    net0, net1, net2, net3 = heads
    return net0, net1, net2, net3

In [None]:
# Define the graph inputs: a batch of 224x224 images and one label vector
# per captcha character position.
x = tf.placeholder(tf.float32, [None, 224, 224])
y0 = tf.placeholder(tf.float32, [None])
y1 = tf.placeholder(tf.float32, [None])
y2 = tf.placeholder(tf.float32, [None])
y3 = tf.placeholder(tf.float32, [None])
# The learning rate is a variable so the training loop can decay it. It is
# not a model weight, so it is kept out of the trainable collection.
lr = tf.Variable(LEARNING_RATE, dtype=tf.float32, trainable=False)

# Add the channel axis. -1 lets the batch dimension follow whatever is fed,
# matching the [None, 224, 224] placeholder; a fixed BATCH_SIZE here would
# reject any batch of a different size.
X = tf.reshape(x, [-1, 224, 224, 1])
logits0, logits1, logits2, logits3 = alexnet_v2_captcha_multi(X)


def _to_one_hot(labels):
    """One-hot encode a label vector over the character set."""
    return tf.one_hot(indices=tf.cast(labels, tf.int32), depth=CHAR_SET_LEN)


def _mean_cross_entropy(one_hot_labels, logits):
    """Batch-mean softmax cross entropy of one head."""
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=one_hot_labels, logits=logits))


def _head_accuracy(one_hot_labels, logits):
    """Return (softmax probabilities, per-example hits, mean accuracy) of one head."""
    probs = tf.nn.softmax(logits)
    hits = tf.equal(tf.argmax(one_hot_labels, 1), tf.argmax(probs, 1))
    return probs, hits, tf.reduce_mean(tf.cast(hits, tf.float32))


one_hot_label0 = _to_one_hot(y0)
one_hot_label1 = _to_one_hot(y1)
one_hot_label2 = _to_one_hot(y2)
one_hot_label3 = _to_one_hot(y3)

loss0 = _mean_cross_entropy(one_hot_label0, logits0)
loss1 = _mean_cross_entropy(one_hot_label1, logits1)
loss2 = _mean_cross_entropy(one_hot_label2, logits2)
loss3 = _mean_cross_entropy(one_hot_label3, logits3)

# The four heads are weighted equally.
total_loss = (loss0 + loss1 + loss2 + loss3) / 4.0
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(total_loss)

# Compute the accuracy of every head
y_pred0, correct_pre0, accuracy0 = _head_accuracy(one_hot_label0, logits0)
y_pred1, correct_pre1, accuracy1 = _head_accuracy(one_hot_label1, logits1)
y_pred2, correct_pre2, accuracy2 = _head_accuracy(one_hot_label2, logits2)
y_pred3, correct_pre3, accuracy3 = _head_accuracy(one_hot_label3, logits3)

# 檢測 next_element

In [None]:
# Pull a single batch from a fresh iterator to sanity-check the input pipeline.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    b_image, b_label0, b_label1, b_label2, b_label3 = sess.run(next_element)
    # Show the batch shape, then the first image together with its four labels.
    print(b_image.shape)
    plt.imshow(b_image[0], cmap='gray')
    print(*[labels[0] for labels in (b_label0, b_label1, b_label2, b_label3)])

# 開始訓練

In [None]:
# Train the network, logging metrics every 10 steps and checkpointing
# whenever all four heads exceed 90% accuracy on the evaluated batch.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
saver = tf.train.Saver()
# Build the decay op once. Calling tf.assign inside the loop would add new
# nodes to the graph every time the learning rate is halved.
halve_lr = tf.assign(lr, lr * 0.5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('captcha/TensorBoard/', graph=sess.graph)

    try:
        for i in range(TOTAL_STEPS):
            b_image, b_label0, b_label1, b_label2, b_label3 = sess.run(next_element)
            feed = {x: b_image, y0: b_label0, y1: b_label1, y2: b_label2, y3: b_label3}
            sess.run(train_op, feed_dict=feed)

            # Halve the learning rate every 500 steps (not at step 0).
            if i % 500 == 0 and i > 0:
                sess.run(halve_lr)

            if i % 10 == 0:
                # NOTE: the network was built with its default is_training=True,
                # so these metrics are computed with dropout active, on the
                # batch that was just trained on.
                acc0, acc1, acc2, acc3, loss_ = sess.run(
                    [accuracy0, accuracy1, accuracy2, accuracy3, total_loss],
                    feed_dict=feed)
                learning_rate = sess.run(lr)
                print("Iter:%d/%d ,  Loss:%.3f  Accuracy:%.2f,%.2f,%.2f,%.2f  Learning_rate:%.5f" % (
                    i, TOTAL_STEPS, loss_, acc0, acc1, acc2, acc3, learning_rate))

                # Checkpoint only on freshly measured accuracies. Outside this
                # block the values would be stale and a checkpoint would be
                # written on every step in between evaluations.
                if acc0 > 0.90 and acc1 > 0.90 and acc2 > 0.90 and acc3 > 0.90:
                    saver.save(sess, 'captcha/model/crack_captcha.model', global_step=i)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()