In [28]:
import random
from captcha.image import ImageCaptcha
from PIL import Image
import numpy as np
import sys
import os
import tensorflow as tf

In [38]:
def gen_random_text(src, text_len=4):
    """Build a random code string.

    Draws ``text_len`` items uniformly, with replacement, from ``src`` using the
    module-level ``random`` generator and concatenates them.

    Args:
        src: Non-empty sequence of single-character strings to draw from.
        text_len: Number of characters in the result (default 4).

    Returns:
        The concatenated string of length ``text_len``.
    """
    picks = (random.choice(src) for _ in range(text_len))
    return ''.join(picks)


def gen_verify_code(char_set, output_dir):
    """Render one random 4-character captcha image and save it in ``output_dir``.

    The file is named ``<code>.jpg`` so the label can be recovered from the file
    name later. Codes are drawn with replacement, so a repeated code overwrites
    the earlier file with the same name.

    Args:
        char_set: Sequence of single-character strings the code is drawn from.
        output_dir: Existing directory to write the image into.

    Returns:
        The path of the written image file.
    """
    image = ImageCaptcha()
    rand_text = gen_random_text(char_set)
    output_path = os.path.join(output_dir, rand_text + '.jpg')
    # ImageCaptcha.write encodes as PNG by default; request JPEG explicitly so
    # the file contents match the '.jpg' extension. (The previous extra
    # image.generate() call rendered a second image that was never used.)
    image.write(rand_text, output_path, format='jpeg')
    return output_path
    

def bytes_feature(values):
    """Wrap bytes data in a ``tf.train.Feature`` holding a ``BytesList``.

    Mirrors ``int64_feature``: a single bytes value is wrapped in a one-element
    list, while a list or tuple is used directly as the list of values.

    Args:
        values: A bytes object, or a list/tuple of bytes objects.

    Returns:
        A ``tf.train.Feature`` with its ``bytes_list`` populated.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def int64_feature(values):
    """Wrap integer data in a ``tf.train.Feature`` holding an ``Int64List``.

    Args:
        values: A single integer, or a list/tuple of integers (used as-is).

    Returns:
        A ``tf.train.Feature`` with its ``int64_list`` populated.
    """
    is_sequence = isinstance(values, (tuple, list))
    value_list = values if is_sequence else [values]
    int_list = tf.train.Int64List(value=value_list)
    return tf.train.Feature(int64_list=int_list)


def image_to_tfexample(image_data, label0, label1, label2, label3):
    """Pack raw image bytes and four per-character labels into a ``tf.train.Example``.

    Args:
        image_data: Raw image bytes, stored under the 'image' key.
        label0, label1, label2, label3: Integer labels, stored under
            'label0'..'label3' respectively.

    Returns:
        A ``tf.train.Example`` containing the five features.
    """
    feature_map = {}
    feature_map['image'] = bytes_feature(image_data)
    feature_map['label0'] = int64_feature(label0)
    feature_map['label1'] = int64_feature(label1)
    feature_map['label2'] = int64_feature(label2)
    feature_map['label3'] = int64_feature(label3)
    packed_features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=packed_features)
    

def _load_image_bytes(path, size=(224, 224)):
    """Read an image file and return its resized, 8-bit grayscale pixels as raw bytes."""
    # `with` closes the underlying file handle; resize() returns an independent
    # image, so the pixel data stays valid after the source image is closed.
    with Image.open(path) as img:
        gray = img.resize(size).convert('L')
    return np.array(gray).tobytes()


def _labels_from_filename(path, code_len=4):
    """Return the integer digits encoded in the first ``code_len`` characters of the file name.

    Raises:
        ValueError: if the file name does not start with ``code_len`` digits.
    """
    name = os.path.basename(path)
    verify_code = name[:code_len]
    if len(verify_code) != code_len or not verify_code.isdigit():
        raise ValueError('%r does not start with a %d-digit code' % (name, code_len))
    return [int(c) for c in verify_code]


def tfrecord_pack(type, files, output_dir, shard=False, shard_size=5000):
    """Convert captcha image files into one or more TFRecord files.

    Each image is resized to 224x224, converted to 8-bit grayscale ('L'), and
    stored as raw bytes under 'image'. The first four characters of its file
    name are parsed as digits and stored as 'label0'..'label3'. Files that
    cannot be read, or whose names do not start with four digits, are reported
    and skipped.

    Output files are named 'Verify_image_<type>_<NN>-of-<MM>.tfrecord'.

    Args:
        type: Split name, 'train' or 'test'.
        files: List of image file paths.
        output_dir: Directory the .tfrecord files are written to.
        shard: If True, split the output into files of at most ``shard_size``
            examples; otherwise write a single file.
        shard_size: Maximum examples per shard when ``shard`` is True.

    Raises:
        ValueError: if ``shard`` is True and ``shard_size`` is not positive.
    """
    assert type in ['train', 'test']

    nr_images = len(files)
    if shard:
        if shard_size <= 0:
            raise ValueError('shard_size must be positive, got %r' % (shard_size,))
        nr_shards = (nr_images + shard_size - 1) // shard_size
    else:
        nr_shards = 1
        shard_size = nr_images

    # Separate name so the loop index does not shadow the boolean `shard` argument.
    for shard_idx in range(nr_shards):
        output = 'Verify_image_%s_%02d-of-%02d.tfrecord' % (type, shard_idx + 1, nr_shards)
        output = os.path.join(output_dir, output)

        with tf.python_io.TFRecordWriter(output) as writer:
            # start and end index of a shard
            start_idx = shard_idx * shard_size
            end_idx = min((shard_idx + 1) * shard_size, nr_images)

            for i in range(start_idx, end_idx):
                sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i + 1, nr_images, shard_idx))
                sys.stdout.flush()
                try:
                    image_data = _load_image_bytes(files[i])
                    labels = _labels_from_filename(files[i])
                    tfexample = image_to_tfexample(image_data, *labels)
                    writer.write(tfexample.SerializeToString())
                except (IOError, ValueError) as e:
                    # Leading newline so the '\r' progress line doesn't overwrite the message.
                    print('\nError: skip %s (%s)' % (files[i], e))

        sys.stdout.write('\t[Done]\n')
        sys.stdout.flush()

In [40]:
digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
images_dir = 'captcha/images'
tfrecord_dir = 'captcha/tfrecord'
# Number of captcha images to generate (not a training epoch). Codes are drawn
# with replacement and a repeated code overwrites its file, so fewer than
# `epoch` distinct images may end up on disk.
epoch = 10000
test_rate = .3     # fraction of images held out for the test split
split_seed = 42    # makes the train/test split reproducible across runs

if not os.path.exists(images_dir):
    os.makedirs(images_dir)
    print('Generating verify code images...')
    sys.stdout.flush()
    for i in range(epoch):
        sys.stdout.write('\r>>Generating verify code image %d/%d' % (i + 1, epoch))
        sys.stdout.flush()
        try:
            gen_verify_code(digits, images_dir)
        except Exception as e:
            # Catch Exception (not bare except) so Ctrl-C still interrupts the loop.
            print('\nError: Generate image. Skip it. (%s)' % e)
    sys.stdout.write('\t\t[Done]\n')
    sys.stdout.flush()
else:
    print('Existing verify code images.')

if not os.path.exists(tfrecord_dir):
    os.makedirs(tfrecord_dir)
    # Keep only generated images (ignore stray entries such as .DS_Store) and
    # sort so the seeded shuffle below doesn't depend on directory listing order.
    files = sorted(f for f in os.listdir(images_dir) if f.lower().endswith('.jpg'))
    images_filepath = [os.path.join(images_dir, f) for f in files]

    nr_test_images = int(len(images_filepath) * test_rate)
    random.Random(split_seed).shuffle(images_filepath)
    train_images = images_filepath[nr_test_images:]
    test_images = images_filepath[:nr_test_images]

    print('Converting train images to TFRecord format...')
    tfrecord_pack('train', train_images, tfrecord_dir)
    print('Converting test images to TFRecord format...')
    tfrecord_pack('test', test_images, tfrecord_dir)
else:
    print('Existing TFRecord files.')

Existing verify code images.
Converting train images to TFRecord format...
>> Converting image 6564/6564 shard 0			[Done]
Done.
Converting test images to TFRecord format...
>> Converting image 2813/2813 shard 0			[Done]
Done.
