#### 功能说明：
+ 遍历原始数据，生成训练集txt文件和测试集txt文件
+ txt文件中每行对应一张图片，字段以空格分隔：第一个字段是图片路径，之后每5个字段描述一个物体，依次为xmin、ymin、xmax、ymax和类别编号（目前所有物体的类别均为0）

#### 原数据格式要求：
+ 数据存放在ROOT_DATA_DIR指向的文件夹中：每次收集的图片放在一个以日期命名的子文件夹中，对应的标注数据是与该子文件夹同级、同名的xml文件（即`<日期>.xml`）；另外有一个ipynb文件用于对数据进行预处理。标签为bar的标注框会被忽略。

In [6]:
import os
import xml.etree.ElementTree as ET 
import struct
import numpy as np
import cv2
import random
from matplotlib import pyplot as plt
%matplotlib inline
from IPython.core.debugger import Tracer

#### 可配置的参数：
+ 原数据文件夹（ROOT_DATA_DIR）
+ 测试集数量占数据总量的比例（test_ratio）
+ 训练集和测试集txt文件的名称（train_file_name、test_file_name）

In [7]:
# Folder holding the raw data: dated image folders plus a sibling
# "<folder>.xml" annotation file for each.
ROOT_DATA_DIR = "data_doc"
# Fraction of all records that is held out as the test set.
test_ratio = 0.1
# Output txt files for the training and test splits.
train_file_name, test_file_name = "doc_train.txt", "doc_test.txt"

In [21]:
def parse_xml(image_dir, xml_file):
    """Parse one annotation XML file into image paths and bounding boxes.

    Args:
        image_dir: folder containing the images referenced by ``xml_file``.
        xml_file: path to the annotation XML. It is read as an ``<images>``
            element whose ``<image file=...>`` children hold
            ``<box top= left= width= height=>`` elements, each optionally
            carrying a ``<label>`` child.

    Returns:
        ``(image_paths, labels)`` as parallel lists. ``labels[i]`` is a list
        of ``[xmin, ymin, xmax, ymax, 0]`` boxes for ``image_paths[i]``,
        clipped to the image bounds. Boxes labelled "bar" are dropped, and
        images that cannot be read or end up with no boxes are skipped.
        Both lists are empty if ``xml_file`` does not exist or has no
        ``<images>`` element.
    """
    if not os.path.isfile(xml_file):
        return [], []
    root = ET.parse(xml_file).getroot()

    image_paths = []
    labels = []

    images_node = root.find("images")
    if images_node is None:
        return image_paths, labels

    for image in images_node.findall("image"):
        # Only the base name of the annotated file is used; the image is
        # looked up inside image_dir.
        image_name = image.get('file').split('/')[-1]
        image_path = os.path.join(image_dir, image_name)
        cur_img = cv2.imread(image_path)
        # cv2.imread returns None on failure. ``cur_img == None`` would be an
        # element-wise ndarray comparison, making ``if`` raise ValueError.
        if cur_img is None:
            print("wrong img name: " + image_path + " (from " + xml_file + ")")
            continue
        img_height, img_width = cur_img.shape[:2]

        image_labels = []
        for box in image.findall('box'):
            obj_label = box.find('label')
            # "bar" objects are excluded from the training targets.
            if obj_label is not None and obj_label.text == "bar":
                continue
            top = int(box.get('top'))
            left = int(box.get('left'))
            width = int(box.get('width'))
            height = int(box.get('height'))

            # Clip the box to the image area.
            xmin = max(left, 0)
            xmax = min(left + width, img_width)
            ymin = max(top, 0)
            ymax = min(top + height, img_height)

            # all the objects are class 0
            image_labels.append([xmin, ymin, xmax, ymax, 0])

        # Images without any remaining boxes are not emitted.
        if not image_labels:
            continue

        image_paths.append(image_path)
        labels.append(image_labels)
    return image_paths, labels

In [22]:
# Collect (image path, boxes) pairs from every sub-folder of ROOT_DATA_DIR
# that has a sibling "<folder>.xml" annotation file.
image_paths = []
labels = []

# os.listdir order is arbitrary; sort it so the collected order is stable.
data_dirs = sorted(os.listdir(ROOT_DATA_DIR))

for data_dir in data_dirs:
    data_dir_path = os.path.join(ROOT_DATA_DIR, data_dir)
    label_path = data_dir_path + ".xml"

    # Skip plain files (the .xml / .ipynb themselves) and folders that have
    # no matching annotation file.
    if not (os.path.isdir(data_dir_path) and os.path.isfile(label_path)):
        continue
    cur_paths, cur_labels = parse_xml(data_dir_path, label_path)
    image_paths.extend(cur_paths)
    labels.extend(cur_labels)



In [23]:
# convert records to strings
def convert_to_string(image_path, labels):
    """Format one image record as a single line of the output txt file.

    Args:
        image_path: path of the image; written as the first field.
        labels: iterable of boxes; every value of every box is written in
            order (in this notebook each box is ``[xmin, ymin, xmax, ymax, 0]``).

    Returns:
        ``image_path`` followed by all box values, separated by single
        spaces and terminated by a newline.
    """
    fields = [image_path]
    fields.extend(str(value) for box in labels for value in box)
    # join avoids the repeated string concatenation of building with +=.
    return ' '.join(fields) + '\n'

# One formatted line per image; image_paths and labels are parallel lists.
records = [
    convert_to_string(path, boxes) for path, boxes in zip(image_paths, labels)
]

# Shuffle so the later head/tail split yields random train/test sets.
# NOTE(review): the shuffle is unseeded, so the split differs between runs.
random.shuffle(records)
print(len(records))

119


In [24]:
# split into training set and test set
# Split the shuffled records: the first train_num go to training and the
# remaining int(test_ratio * total_num) are held out for testing.
total_num = len(records)
test_num = int(total_num * test_ratio)
train_num = total_num - test_num
train_records, test_records = records[:train_num], records[train_num:]

In [25]:
# save to text file
# Write each split to its own txt file, one record (already newline-
# terminated) per line. ``with`` closes the file even if a write fails.
for out_name, out_records in ((train_file_name, train_records),
                              (test_file_name, test_records)):
    with open(out_name, "w") as out_file:
        out_file.writelines(out_records)