## 数据格式转化脚本

#### 将voc目标检测数据转换为TFRecord格式，方便TensorFlow读取

In [1]:
import os
import sys
import random
from math import ceil
import tensorflow as tf
import xml.etree.ElementTree as ET

TFR_NAME = './TFR_Data/hir2019'
IMAGE_PATH = './HIR2019/JPEGImages'
ANNOTATION_PATH = './HIR2019/Annotations'
SAMPLES_PER_FILES = 500
HIR_LABELS = {
    'none': (0, 'Background'),
    'scratch': (1, 'Front'),
}
tfr_dir = os.path.split(TFR_NAME)[0]
if not os.path.exists(tfr_dir):
    os.makedirs(tfr_dir)
if not os.path.exists(IMAGE_PATH):
    raise BaseException('file {} is not exists'.format(IMAGE_PATH))
file_names = sorted(os.listdir(IMAGE_PATH))
random.seed = 10
random.shuffle(file_names)
sys.stdout.write('Number of images is {}'.format(len(file_names)))

Number of images is 1500

In [5]:
def int64_feature(value):
    """Wrapper for inserting int64 features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
 
def float_feature(value):
    """Wrapper for inserting float features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
def bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

import re
import linecache
from PIL import Image

def annotation_parse(image_name):
#     tree = ET.parse(xml_path)
#     root = tree.getroot()
    # Image shape.
    im = Image.open(IMAGE_PATH+'/'+image_name+'.jpg')
    shape = [im.size[1], im.size[0], 3]
#     size = root.find('size')
#     shape = [int(size.find('height').text),
#              int(size.find('width').text),
#              int(size.find('depth').text)]
    # Find annotations.
    line_num = int(re.sub('\D','',image_name))
    line = linecache.getline('./HIR2019/train_labels.txt', line_num)
    
    bboxes = []
    labels = []
#     if 'OK' in line:
#         return shape, bboxes, labels
    if 'NG' in line:
        line = line[19:-33]
        for obj in line.split('],['):
            labels.append(HIR_LABELS['scratch'][0])
            
            ymin = obj.split(',')[1]
            xmin = obj.split(',')[0]
            ymax = obj.split(',')[3]
            xmax = obj.split(',')[2]
            bboxes.append((float(ymin) / shape[0],
                           float(xmin) / shape[1],
                           float(ymax) / shape[0],
                           float(xmax) / shape[1]
                           ))
#             bbox = obj.find('bndbox')
#             bboxes.append((float(bbox.find('ymin').text) / shape[0],
#                            float(bbox.find('xmin').text) / shape[1],
#                            float(bbox.find('ymax').text) / shape[0],
#                            float(bbox.find('xmax').text) / shape[1]
#                            ))
        return shape, bboxes, labels
    else:
        print('error')
        exit(0)

In [6]:
num_tfr = ceil(len(file_names)/SAMPLES_PER_FILES)
i = 0
for idx in range(num_tfr):
    tfr_file = '{}_{:03d}.tfrecord'.format(TFR_NAME, idx)
    sys.stdout.write("Writing file '{}'......\n".format(tfr_file))
    # 建立书写器
    with tf.python_io.TFRecordWriter(tfr_file) as writer:
        while i < SAMPLES_PER_FILES * (idx + 1) and i < len(file_names):
            image_file = os.path.join(file_names[i].strip('.jpg'))
            
            line_num = int(image_file[-4:])
            line = linecache.getline('./HIR2019/train_labels.txt', line_num)
            i += 1
            if 'NG' in line:
                _, box, label = annotation_parse(image_file)
                image_file = os.path.join(IMAGE_PATH, file_names[i])
                image_data = tf.gfile.FastGFile(image_file, 'rb').read()

                xmin, ymin, xmax, ymax = ([] for _ in range(4))
                for b in box:
                    assert len(b) == 4
                    [coord.append(point) for coord, point in zip([ymin, xmin, ymax, xmax], b)]
                image_format = b'JPEG'
                # 建立example
                example = tf.train.Example(features=tf.train.Features(feature={
                        'image/object/bbox/xmin': float_feature(xmin),
                        'image/object/bbox/xmax': float_feature(xmax),
                        'image/object/bbox/ymin': float_feature(ymin),
                        'image/object/bbox/ymax': float_feature(ymax),
                        'image/object/bbox/label': int64_feature(label),
                        'image/format': bytes_feature(image_format),  # 图像编码格式
                        'image/encoded': bytes_feature(image_data)}))  # 二进制图像数据
                # 书写入文件
                writer.write(example.SerializeToString())

Writing file './TFR_Data/hir2019_000.tfrecord'......
Writing file './TFR_Data/hir2019_001.tfrecord'......
Writing file './TFR_Data/hir2019_002.tfrecord'......


In [4]:
import tensorflow.contrib.slim as slim

def get_split(tfr_path, tfr_pattren, num_classes=2):
    
    # ===============TFR文件名匹配模板===============
    tfr_pattren = os.path.join(tfr_path, tfr_pattren)
    
    # =========阅读器=========
    reader = tf.TFRecordReader()
    
    # ===================解码器===================
    keys_to_features = {  # 解码TFR文件方式
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
    }
    items_to_handlers = {  # 解码二进制数据
        # 图像解码设置蛮有意思的
        'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
        'object/bbox': slim.tfexample_decoder.BoundingBox(
            ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
        'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    
    # =======描述字段=======
    items_to_descriptions={
        'image': 'A color image of varying height and width.',
        'shape': 'Shape of the image',
        'object/bbox': 'A list of bounding boxes, one per each object.',
        'object/label': 'A list of labels, one per each object.',
    }
    
    return slim.dataset.Dataset(
            data_sources=tfr_pattren,                     # TFR文件名
            reader=reader,                                # 阅读器
            decoder=decoder,                              # 解码器
            num_samples=len(file_names),       # 数目
            items_to_descriptions=items_to_descriptions,  # decoder条目描述字段
            num_classes=num_classes,                      # 类别数
            labels_to_names=None                          # 字典{图片:类别,……}
    )

pattren = 'hir2019_*.tfrecord'
dataset = get_split(tfr_dir, pattren, num_classes=2)
provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,  # DatasetDataProvider 需要 slim.dataset.Dataset 做参数
        num_readers=2,
        common_queue_capacity=20 * 5,
        common_queue_min=10 * 5,
        shuffle=True)
image, glabels, gbboxes = provider.get(['image',
                                        'object/label',
                                        'object/bbox'])
image, glabels, gbboxes

TypeError: 'TFRecordReader' object is not callable

In [None]:
with tf.Session() as sess:
    init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run([glabels, gbboxes]))
    coord.request_stop()
    coord.join(threads)

In [None]:
class SSD:
    def __init__(self):
        pass
    def inference(self):
        pass