## 数据格式转化脚本

#### 将voc目标检测数据转换为TFRecord格式，方便TensorFlow读取

In [1]:
import os
import sys
import random
from math import ceil
import tensorflow as tf
import xml.etree.ElementTree as ET

# ---- Configuration -------------------------------------------------------
# Path prefix of the output shards; the shard index and '.tfrecord' suffix
# are appended when the files are written.
TFR_NAME = './TFR_Data/voc2012'
# Directory holding the JPEG images.
IMAGE_PATH = './VOC2012/JPEGImages'
# Directory holding the per-image XML annotation files.
ANNOTATION_PATH = './VOC2012/Annotations'
# Maximum number of examples written into a single TFRecord shard.
SAMPLES_PER_FILES = 2000
# Class name -> (integer class id, coarse category name).
VOC_LABELS = {
    'none': (0, 'Background'),
    'aeroplane': (1, 'Vehicle'),
    'bicycle': (2, 'Vehicle'),
    'bird': (3, 'Animal'),
    'boat': (4, 'Vehicle'),
    'bottle': (5, 'Indoor'),
    'bus': (6, 'Vehicle'),
    'car': (7, 'Vehicle'),
    'cat': (8, 'Animal'),
    'chair': (9, 'Indoor'),
    'cow': (10, 'Animal'),
    'diningtable': (11, 'Indoor'),
    'dog': (12, 'Animal'),
    'horse': (13, 'Animal'),
    'motorbike': (14, 'Vehicle'),
    'person': (15, 'Person'),
    'pottedplant': (16, 'Indoor'),
    'sheep': (17, 'Animal'),
    'sofa': (18, 'Indoor'),
    'train': (19, 'Vehicle'),
    'tvmonitor': (20, 'Indoor'),
}

# ---- Prepare output directory and list the input images -------------------
tfr_dir = os.path.split(TFR_NAME)[0]
# exist_ok avoids the check-then-create race; the guard skips the call when
# TFR_NAME has no directory component (os.makedirs('') would raise).
if tfr_dir:
    os.makedirs(tfr_dir, exist_ok=True)
if not os.path.exists(IMAGE_PATH):
    # FileNotFoundError is still a BaseException subclass, so existing
    # handlers keep working, but it no longer masquerades as a system exit.
    raise FileNotFoundError('file {} is not exists'.format(IMAGE_PATH))
file_names = sorted(os.listdir(IMAGE_PATH))
# Call random.seed; assigning to it (random.seed = 10) would replace the
# function with an int and leave the shuffle below unseeded.
random.seed(10)
random.shuffle(file_names)
sys.stdout.write('Number of images is {}'.format(len(file_names)))

  return f(*args, **kwds)


Number of images is 17125

In [2]:
def int64_feature(value):
    """Build a tf.train.Feature holding an Int64List.

    A value that is not a list is wrapped into a one-element list first;
    a list is passed through unchanged.
    """
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
 
def float_feature(value):
    """Build a tf.train.Feature holding a FloatList.

    A value that is not a list is wrapped into a one-element list first;
    a list is passed through unchanged.
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)
 
def bytes_feature(value):
    """Build a tf.train.Feature holding a BytesList.

    A value that is not a list is wrapped into a one-element list first;
    a list is passed through unchanged.
    """
    values = value if isinstance(value, list) else [value]
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)

def xml_parse(xml_path):
    """Parse one annotation XML file into plain Python lists.

    Args:
        xml_path: Path (or file object) accepted by ``ET.parse``.

    Returns:
        A 6-tuple ``(shape, bboxes, labels, labels_text, difficult, truncated)``:
          * shape: ``[height, width, depth]`` read from the ``size`` element.
          * bboxes: one ``(ymin, xmin, ymax, xmax)`` tuple per object, each
            coordinate divided by the image height or width.
          * labels: integer class id per object, looked up in ``VOC_LABELS``.
          * labels_text: class name string per object.
          * difficult / truncated: integer flag per object, 0 when the
            corresponding element is absent.

    Raises:
        KeyError: if an object's name is not a key of ``VOC_LABELS``.
    """
    root = ET.parse(xml_path).getroot()
    # Image shape.
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]
    # Find annotations.
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(VOC_LABELS[label][0])
        labels_text.append(label)

        # Compare against None: an Element with no children is falsy, so a
        # plain truth test would treat a present <difficult>1</difficult>
        # as missing and always record 0.
        difficult_node = obj.find('difficult')
        difficult.append(0 if difficult_node is None else int(difficult_node.text))
        truncated_node = obj.find('truncated')
        truncated.append(0 if truncated_node is None else int(truncated_node.text))

        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]
                       ))
    return shape, bboxes, labels, labels_text, difficult, truncated

In [3]:
# Number of shards needed to hold every image, SAMPLES_PER_FILES per shard.
num_tfr = ceil(len(file_names)/SAMPLES_PER_FILES)
# Running index into file_names; it carries over from one shard to the next.
i = 0
for idx in range(num_tfr):
    tfr_file = '{}_{:03d}.tfrecord'.format(TFR_NAME, idx)
    sys.stdout.write("Writing file '{}'......\n".format(tfr_file))
    # Create the record writer for this shard.
    with tf.python_io.TFRecordWriter(tfr_file) as writer:
        while i < SAMPLES_PER_FILES * (idx + 1) and i < len(file_names):
            # splitext removes only the extension. str.strip('.jpg') strips
            # any of the characters '.', 'j', 'p', 'g' from BOTH ends, which
            # mangles names such as 'pig.jpg'.
            stem = os.path.splitext(file_names[i])[0]
            xml_file = os.path.join(ANNOTATION_PATH, stem + '.xml')
            image_file = os.path.join(IMAGE_PATH, file_names[i])
            _, box, label, _, _, _ = xml_parse(xml_file)
            # Context manager closes the image file handle after reading.
            with tf.gfile.FastGFile(image_file, 'rb') as image_fp:
                image_data = image_fp.read()
            i += 1

            # xml_parse yields (ymin, xmin, ymax, xmax) per box; split the
            # boxes into one list per coordinate.
            xmin, ymin, xmax, ymax = ([] for _ in range(4))
            for b in box:
                assert len(b) == 4
                ymin.append(b[0])
                xmin.append(b[1])
                ymax.append(b[2])
                xmax.append(b[3])
            image_format = b'JPEG'
            # Build the example.
            example = tf.train.Example(features=tf.train.Features(feature={
                    'image/object/bbox/xmin': float_feature(xmin),
                    'image/object/bbox/xmax': float_feature(xmax),
                    'image/object/bbox/ymin': float_feature(ymin),
                    'image/object/bbox/ymax': float_feature(ymax),
                    'image/object/bbox/label': int64_feature(label),
                    'image/format': bytes_feature(image_format),  # image encoding format
                    'image/encoded': bytes_feature(image_data)}))  # raw encoded image bytes
            # Write the serialized example to the shard.
            writer.write(example.SerializeToString())

Writing file './TFR_Data/voc2012_000.tfrecord'......
Writing file './TFR_Data/voc2012_001.tfrecord'......
Writing file './TFR_Data/voc2012_002.tfrecord'......
Writing file './TFR_Data/voc2012_003.tfrecord'......
Writing file './TFR_Data/voc2012_004.tfrecord'......
Writing file './TFR_Data/voc2012_005.tfrecord'......
Writing file './TFR_Data/voc2012_006.tfrecord'......
Writing file './TFR_Data/voc2012_007.tfrecord'......
Writing file './TFR_Data/voc2012_008.tfrecord'......


In [6]:
import tensorflow.contrib.slim as slim

def get_split(tfr_path, tfr_pattren, num_classes=21, num_samples=None):
    """Build a slim Dataset describing the TFRecord shards.

    Args:
        tfr_path: Directory containing the TFRecord files.
        tfr_pattren: File-name pattern, joined onto ``tfr_path`` and passed
            as the dataset's ``data_sources``.
        num_classes: Number of classes recorded on the dataset.
        num_samples: Number of examples recorded on the dataset. When None
            (the default) it falls back to ``len(file_names)``, i.e. the
            notebook-level list of image files, as before.

    Returns:
        A ``slim.dataset.Dataset`` wired to a ``tf.TFRecordReader`` and a
        ``TFExampleDecoder`` exposing the items 'image', 'object/bbox' and
        'object/label'.
    """
    # Full file pattern of the TFRecord shards.
    tfr_pattren = os.path.join(tfr_path, tfr_pattren)

    if num_samples is None:
        # Backward-compatible default: relies on the global file_names list.
        num_samples = len(file_names)

    # Reader class (not an instance).
    reader = tf.TFRecordReader

    # Decoder: first how each serialized key is parsed ...
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
    }
    # ... then how the parsed features are assembled into output items.
    items_to_handlers = {
        # The image handler decodes 'image/encoded' using 'image/format'.
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Human-readable item descriptions.
    # NOTE: 'shape' is described here but has no handler in items_to_handlers.
    items_to_descriptions = {
        'image': 'A color image of varying height and width.',
        'shape': 'Shape of the image',
        'object/bbox': 'A list of bounding boxes, one per each object.',
        'object/label': 'A list of labels, one per each object.',
    }

    return slim.dataset.Dataset(
            data_sources=tfr_pattren,                     # TFRecord file pattern
            reader=reader,                                # reader class
            decoder=decoder,                              # example decoder
            num_samples=num_samples,                      # number of examples
            items_to_descriptions=items_to_descriptions,  # decoder item descriptions
            num_classes=num_classes,                      # number of classes
            labels_to_names=None                          # no label-to-name mapping supplied
    )

# File-name pattern matching every shard written earlier in the notebook.
pattren = 'voc2012_*.tfrecord'
dataset = get_split(tfr_dir, pattren, num_classes=21)

# DatasetDataProvider takes a slim.dataset.Dataset as its first argument.
provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset,
    num_readers=2,
    common_queue_capacity=100,
    common_queue_min=50,
    shuffle=True,
)

# Request the three decoded items, in this order.
requested_items = ['image', 'object/label', 'object/bbox']
image, glabels, gbboxes = provider.get(requested_items)
(image, glabels, gbboxes)

(<tf.Tensor 'case_1/If_1/Merge:0' shape=(?, ?, 3) dtype=uint8>,
 <tf.Tensor 'SparseToDense_1:0' shape=(?,) dtype=int64>,
 <tf.Tensor 'transpose_1:0' shape=(?, 4) dtype=float32>)

In [5]:
# Pull one decoded sample (labels and boxes) through the input queues.
with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        print(sess.run([glabels, gbboxes]))
    finally:
        # Always stop and join the queue-runner threads, even if the run
        # above raises, so they do not outlive the session.
        coord.request_stop()
        coord.join(threads)

[array([15]), array([[ 0.40266666,  0.252     ,  1.        ,  0.76999998]], dtype=float32)]


In [None]:
class SSD:
    """Placeholder class; no behavior is implemented yet."""

    def __init__(self):
        """Create an instance; currently stores no state."""

    def inference(self):
        """Stub; currently does nothing and returns None."""