#### 功能说明：
+ 遍历原始数据，生成训练集txt文件和测试集txt文件
+ txt文件中一行为一张图片的数据，数据以空格分隔：第一个元素是图片路径，之后每5个元素对应一个物体，依次为xmin，ymin，xmax，ymax和类别编号（当前所有物体的类别编号均为0）

#### 原数据格式要求：
+ 数据存放在ROOT_DATA_DIR指向的文件夹中，每次收集的数据以日期命名为一个文件夹
+ 每个日期文件夹内是以分类编号命名的文件夹，品类数用N表示，其中0到N-1为单独拍摄的单一品类物品，N及N以上为混合拍摄的物品
+ 每个分类文件夹内，包含0-4命名的共5个文件夹，为相应摄像头的图像数据，此外还包括0.xml-4.xml共5个文件，为相应的标注文件
+ 对0到N-1文件夹内的数据，标注的分类包括good，null和bar（生成txt文件时，标注为bar的框会被忽略）
+ 对N及N以上的文件夹内的数据，标注的分类为分类编号（0到N-1的数字）
+ 要求进入到这个文件夹的数据，顶部摄像头的图像都必须是裁剪过的，顶部摄像头图像对应的标注文件，都必须是针对裁剪过的图像进行标注的文件

In [2]:
import os
import xml.etree.ElementTree as ET 
import struct
import numpy as np
import cv2
import random
from matplotlib import pyplot as plt
%matplotlib inline
from IPython.core.debugger import Tracer

#### 可配置的参数：
+ 原数据文件夹
+ 测试集数量占数据总量的比例
+ 训练集和测试集txt文件的名称

In [3]:
# Root folder that holds the date-named collection folders.
ROOT_DATA_DIR = "data_cmdt"
# Fraction of all records held out as the test set.
test_ratio = 0.1
# Output txt files for the training and test splits.
train_file_name, test_file_name = "cmdt_train.txt", "cmdt_test.txt"

In [7]:
def parse_xml(image_dir, xml_file):
    """Parse one annotation XML file into image paths and bounding boxes.

    The XML is read as ``root/images/image`` elements, each carrying a
    ``file`` attribute and zero or more ``box`` children with integer
    ``top``/``left``/``width``/``height`` attributes and a ``label`` child.

    Args:
        image_dir: directory holding the images referenced by the XML; only
            the last path component of each ``file`` attribute is used.
        xml_file: path to the annotation XML.

    Returns:
        ``(image_paths, labels)`` as parallel lists. ``labels[i]`` is a list
        of ``[xmin, ymin, xmax, ymax, 0]`` boxes for ``image_paths[i]``,
        clipped to the image bounds; every object is assigned class 0.
        Boxes labelled ``"bar"`` are dropped, images left with no boxes are
        skipped, and images OpenCV cannot read are reported and skipped.
        A missing XML file, or one without an ``images`` node, yields
        ``([], [])``.
    """
    if not os.path.isfile(xml_file):
        return [], []
    root = ET.parse(xml_file).getroot()

    image_paths = []
    labels = []

    images_node = root.find("images")
    if images_node is None:
        # Nothing annotated in this file.
        return image_paths, labels

    for image in images_node.findall("image"):
        # Gather the usable raw boxes first, so images that contribute
        # nothing are skipped before paying for a full image decode.
        raw_boxes = []
        for box in image.findall("box"):
            obj_label = box.find("label")
            if obj_label.text == "bar":
                continue
            raw_boxes.append((int(box.get("left")), int(box.get("top")),
                              int(box.get("width")), int(box.get("height"))))
        if not raw_boxes:
            continue

        # The "file" attribute appears to be a backslash-separated relative
        # path; keep only its final component, accepting either separator.
        image_name = image.get("file").replace("\\", "/").rsplit("/", 1)[-1]
        image_path = os.path.join(image_dir, image_name)
        cur_img = cv2.imread(image_path)
        # cv2.imread returns None on failure; "== None" on an ndarray is an
        # elementwise comparison, so an identity check is required here.
        if cur_img is None:
            print("wrong img name: " + xml_file)
            continue
        img_height, img_width = cur_img.shape[:2]

        image_labels = []
        for left, top, width, height in raw_boxes:
            # Clip each box to the image area.
            xmin = max(left, 0)
            ymin = max(top, 0)
            xmax = min(left + width, img_width)
            ymax = min(top + height, img_height)
            # all the objects are class 0
            image_labels.append([xmin, ymin, xmax, ymax, 0])

        image_paths.append(image_path)
        labels.append(image_labels)
    return image_paths, labels

In [8]:
# Date-named collection folders under ROOT_DATA_DIR. Sorted so traversal order
# does not depend on the filesystem; stray files are ignored because the loop
# below lists each entry as a directory.
data_dirs = sorted(
    entry for entry in os.listdir(ROOT_DATA_DIR)
    if os.path.isdir(os.path.join(ROOT_DATA_DIR, entry))
)

In [9]:
# Walk ROOT_DATA_DIR/<date>/<class>/<camera> and collect every labelled image
# together with its boxes (see parse_xml for the per-file format).
image_paths = []
labels = []
for data_dir in data_dirs:
    data_dir_path = os.path.join(ROOT_DATA_DIR, data_dir)
    if not os.path.isdir(data_dir_path):
        continue
    # Only directories are class folders; sorted for a deterministic walk.
    class_dirs = sorted(
        entry for entry in os.listdir(data_dir_path)
        if os.path.isdir(os.path.join(data_dir_path, entry))
    )
    for class_dir_idx, class_dir in enumerate(class_dirs):
        # Progress log: "<date dir>: <index of class dir>".
        print(data_dir + ": " + str(class_dir_idx))
        class_path = os.path.join(data_dir_path, class_dir)
        for img_dir in sorted(os.listdir(class_path)):
            img_dir_path = os.path.join(class_path, img_dir)
            # Camera image folders sit next to their "<folder>.xml" annotation
            # files; skip the XML files (and any other stray file) here.
            if not os.path.isdir(img_dir_path):
                continue
            label_path = img_dir_path + ".xml"
            cur_paths, cur_labels = parse_xml(img_dir_path, label_path)
            image_paths.extend(cur_paths)
            labels.extend(cur_labels)

0




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
wrong img name: data_cmdt/20170510/120/4.xml
27
28
29
30
31
32
33
34
35
36
wrong img name: data_cmdt/20170510/162/3.xml
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133


In [10]:
# convert records to strings
def convert_to_string(image_path, labels):
    """Serialize one image's annotations as a line of the output txt file.

    Args:
        image_path: path written as the first field.
        labels: iterable of per-object value sequences (in this notebook
            ``[xmin, ymin, xmax, ymax, class]``); each value is written
            with ``str()``.

    Returns:
        The space-separated fields terminated by a newline.
    """
    # NOTE(review): the format is space-delimited, so an image_path that
    # contains spaces would be ambiguous to downstream readers.
    fields = [image_path]
    fields.extend(str(value) for label in labels for value in label)
    return ' '.join(fields) + '\n'

# Build one text line per image, then shuffle before the train/test split.
# Sorting first and using a dedicated seeded RNG makes the order (and hence
# the split) reproducible across runs, independent of directory listing order
# and of the global random state.
SHUFFLE_SEED = 0
records = sorted(
    convert_to_string(image_path, image_labels)
    for image_path, image_labels in zip(image_paths, labels)
)
random.Random(SHUFFLE_SEED).shuffle(records)
print(len(records))

24801


In [11]:
# split into training set and test set
# The last test_num shuffled records form the test set; the rest are training.
total_num = len(records)
test_num = int(total_num * test_ratio)
train_num = total_num - test_num
train_records, test_records = records[:train_num], records[train_num:]

In [12]:
# save to text file
# Write each split to its own txt file (records already end with "\n").
# `with` guarantees the handles are closed even if a write fails.
with open(train_file_name, "w") as train_out_file:
    train_out_file.writelines(train_records)
with open(test_file_name, "w") as test_out_file:
    test_out_file.writelines(test_records)