In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
slim = tf.contrib.slim

In [None]:
# Character set: the ten decimal digits '0'..'9', one class per digit.
CHAR_SET = list(map(str, range(10)))
CHAR_SET_LEN = len(CHAR_SET)
# Number of examples drawn per batch.
BATCH_SIZE = 1
# TFRecord file the examples are read from (the test split, judging by its name).
TFRECORD_FILE = 'captcha/test.tfrecord'

In [None]:
def read_and_decode(serial_exmp):
    """Parse one serialized captcha record into a normalized image and four labels.

    Args:
        serial_exmp: a single serialized record string, as yielded by the
            TFRecordDataset this function is mapped over.

    Returns:
        Tuple (image, label0, label1, label2, label3): `image` is a float32
        tensor reshaped to [224, 224] with values scaled into [-1, 1]; each
        label is an int32 scalar tensor.
    """
    label_keys = ('label0', 'label1', 'label2', 'label3')

    # Feature spec: raw image bytes plus one int64 label per character.
    spec = {'image': tf.FixedLenFeature([], tf.string)}
    for key in label_keys:
        spec[key] = tf.FixedLenFeature([], tf.int64)
    parsed = tf.parse_single_example(serial_exmp, features=spec)

    # Raw bytes -> uint8 pixels -> fixed 224x224 grid.
    pixels = tf.reshape(tf.decode_raw(parsed['image'], tf.uint8), [224, 224])
    # Map [0, 255] onto [-1, 1]: scale to [0, 1], shift to [-0.5, 0.5], stretch x2.
    scaled = tf.cast(pixels, tf.float32) / 255.0
    image = (scaled - 0.5) * 2.0

    labels = tuple(tf.cast(parsed[key], tf.int32) for key in label_keys)
    return (image,) + labels

# Input pipeline: after map(), each element is (image, label0, label1, label2, label3).
# shuffle() is given no seed, so the example order differs between runs;
# repeat() with no count cycles through the file indefinitely.
dataset = (
    tf.data.TFRecordDataset(TFRECORD_FILE)
    .map(read_and_decode)
    .shuffle(buffer_size=2000)
    .batch(BATCH_SIZE)
    .repeat()
)

In [None]:
def alexnet_v2_captcha_multi(inputs,
                             num_classes=10,
                             is_training=True,
                             dropout_keep_prob=0.5,
                             spatial_squeeze=True,
                             scope_name='alexnet_v2_captcha_multi',
                             global_pool=False,
                             num_heads=4):
    '''
    AlexNet v2, based on the TensorFlow (slim) GitHub source code and modified
    for multi-task learning: a shared convolutional trunk followed by one
    independent classification head per captcha character.

    Args:
        inputs: 4-D float tensor [batch, height, width, channels]. For the
            224x224 inputs used in this notebook, pool5 is 5x5 and the VALID
            5x5 `fc6` reduces it to 1x1.
        num_classes: number of classes per head (10 for digits). Must be truthy.
        is_training: whether dropout is active. Pass False for inference.
        dropout_keep_prob: keep probability used by dropout6 and dropout7.
        spatial_squeeze: if True, squeeze each head's logits from
            [batch, 1, 1, num_classes] to [batch, num_classes].
        scope_name: variable scope wrapping every layer. Changing it breaks
            restoring existing checkpoints.
        global_pool: if True, average the fc7 features over the spatial axes
            before the heads (as in the slim reference), so inputs larger than
            224x224 still yield 1x1 logits.
        num_heads: number of parallel heads, scoped `fc8_0`, `fc8_1`, ...

    Returns:
        Tuple of `num_heads` logits tensors (four by default), one per character.

    Raises:
        ValueError: if `num_classes` is falsy, since no heads could be built.
    '''
    if not num_classes:
        raise ValueError('alexnet_v2_captcha_multi requires a positive num_classes '
                         'to build its classification heads, got %r' % (num_classes,))

    with tf.variable_scope(scope_name):
        # Shared convolutional trunk (slim conv2d / max_pool2d defaults).
        net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
        net = slim.conv2d(net, 192, [5, 5], scope='conv2')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
        net = slim.conv2d(net, 384, [3, 3], scope='conv3')
        net = slim.conv2d(net, 384, [3, 3], scope='conv4')
        net = slim.conv2d(net, 256, [3, 3], scope='conv5')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')

        # "Fully connected" layers implemented as convolutions.
        with slim.arg_scope([slim.conv2d],
                            weights_initializer=tf.truncated_normal_initializer(0.0, 0.005),
                            weights_regularizer=slim.l2_regularizer(0.0005),
                            biases_initializer=tf.constant_initializer(0.1)):
            net = slim.conv2d(net, 4096, [5, 5], padding='VALID', scope='fc6')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
            net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
            if global_pool:
                # Collapse any remaining spatial extent to 1x1.
                net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout7')

            heads = []
            for i in range(num_heads):
                # One linear 1x1 head per character. The scope names fc8_0..fc8_3
                # match the original layers so saved checkpoints still restore.
                logits = slim.conv2d(net, num_classes, [1, 1],
                                     activation_fn=None,
                                     normalizer_fn=None,
                                     biases_initializer=tf.zeros_initializer(),
                                     scope='fc8_%d' % i)
                if spatial_squeeze:
                    # 4-D to 2-D: [batch, 1, 1, num_classes] -> [batch, num_classes].
                    logits = tf.squeeze(logits, [1, 2], name='fc8_%d/squeezed' % i)
                heads.append(logits)
    return tuple(heads)

In [None]:
# Define the inference graph.
# Input: batch of 224x224 grayscale images (same shape read_and_decode produces).
x = tf.placeholder(tf.float32, [None, 224, 224])

# Add the single channel axis conv2d expects. Using -1 (not BATCH_SIZE) keeps the
# batch dimension as flexible as the placeholder, so any batch size can be fed.
X = tf.reshape(x, [-1, 224, 224, 1])
# is_training=False disables dropout. With the function's default (True) and its
# default keep probability of 0.5, half of the fc6/fc7 units would be randomly
# dropped on every prediction, making results noisy and non-deterministic.
logits0, logits1, logits2, logits3 = alexnet_v2_captcha_multi(X, is_training=False)

# Predicted class index per head: view logits as [batch, CHAR_SET_LEN], take argmax.
predict0, predict1, predict2, predict3 = [
    tf.argmax(tf.reshape(logits, [-1, CHAR_SET_LEN]), 1)
    for logits in (logits0, logits1, logits2, logits3)
]

In [None]:
# Pull test examples, run the restored model on them and plot label vs prediction.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

saver = tf.train.Saver()
CHECKPOINT_PATH = 'captcha/model/crack_captcha.model-3219'
# Grid of plotted examples; the number of examples shown follows from it.
N_ROWS, N_COLS = 2, 5

with tf.Session() as sess:
    # tf.train.Saver() with no var_list covers all saveable variables, and
    # restore() assigns each of them, so no separate initializer run is needed.
    saver.restore(sess, CHECKPOINT_PATH)

    fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(20, 10), dpi=200)
    for ax in axes.flat:
        b_image, b_label0, b_label1, b_label2, b_label3 = sess.run(next_element)
        predictions = sess.run([predict0, predict1, predict2, predict3],
                               feed_dict={x: b_image})

        # Each array holds one entry per batch element; show only the first example.
        title_label = ''.join(str(lbl[0]) for lbl in (b_label0, b_label1, b_label2, b_label3))
        title_predict = ''.join(str(pred[0]) for pred in predictions)

        # Display the image.
        ax.imshow(b_image[0], cmap='gray')
        ax.axis('off')
        ax.set_title('label=' + title_label + '\n' + 'predict=' + title_predict)

plt.show()