In [1]:
import os
import cv2
import xml.etree.ElementTree as ET
import numpy as np
import random
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
root_dir = "20171001/images"
test_ratio = 0.1
train_file_name = "head_train.txt"
test_file_name = "head_test.txt"

In [3]:
sub_dirs = os.listdir(root_dir)

In [4]:
# one class: head

In [9]:
def parse_xml(image_dir, xml_file):
    if os.path.isfile(xml_file) == False:
        return [], []
    tree = ET.parse(xml_file)
    root = tree.getroot()

    image_paths = []
    labels = []

    images = root.find("images")
    images = images.findall("image")

    for image in images:
        image_name = image.get('file').split('\\')[1]
        image_path = os.path.join(image_dir, image_name)
        cur_img = cv2.imread(image_path)
        if cur_img is None:
            print("wrong img name: " + xml_file)
            continue
        img_height, img_width, _ = cur_img.shape

        image_labels = []
        boxes = image.findall('box')
        if len(boxes) == 0:
            continue
        for box in boxes:
            obj_label = box.find('label')
            if obj_label.text.endswith("head"):
                klass = 0
            else:
                continue
            top = int(box.get('top'))
            left = int(box.get('left'))
            width = int(box.get('width'))
            height = int(box.get('height'))

            xmin = np.max([left, 0])
            xmax = np.min([left + width, img_width])
            ymin = np.max([top, 0])
            ymax = np.min([top + height, img_height])

            # all the objects are class 0
            image_labels.append([xmin, ymin, xmax, ymax, klass])
        if len(image_labels) == 0:
            continue
        image_paths.append(os.path.join(image_dir, image_name))
        labels.append(image_labels)
    return image_paths, labels

In [10]:
image_paths = []
labels = []
for sub_dir in sub_dirs:
    sub_dir_path = os.path.join(root_dir, sub_dir)
    segment_dirs = os.listdir(sub_dir_path)
    for segment_dir in segment_dirs:
        segment_dir_path = os.path.join(sub_dir_path, segment_dir)
        print(segment_dir_path)
        if os.path.isdir(segment_dir_path) == False:
            continue
        files = os.listdir(segment_dir_path)
        for file in files:
            if file.endswith("xml") != True:
                continue
            xml_path = os.path.join(segment_dir_path, file)
            img_dir_path = os.path.join(segment_dir_path, file.split('.')[0])
            
            cur_paths, cur_labels = parse_xml(img_dir_path, xml_path)
            image_paths.extend(cur_paths)
            labels.extend(cur_labels)
#         break
#     break

20171001/images/1/005e3abe-df6b-42d9-9aca-1798f9167498
20171001/images/1/3ea9a492-0369-4fff-90d3-f453c059dddd
20171001/images/1/010c8c75-9d84-40ca-9b3e-e5dbbd3e9886
20171001/images/1/e0f7d7c7-ac41-4f12-b873-7350ce011572
20171001/images/1/02001955-28e2-44c6-8a16-300195bf7cca
20171001/images/1/cc4a3ec4-e0c1-403f-8ab0-1aa79e7a6899
20171001/images/1/b0bc439e-558e-4cf1-8826-5b93dca150e5
20171001/images/1/0f6cb1bb-cbba-4448-a602-69ce7f23803e
20171001/images/1/2fed3667-45bd-4eec-bcb3-d4c170ba6c17
20171001/images/1/b769b2f9-8da6-420b-9207-23112a0a3bb0
20171001/images/1/1aeaefd1-8231-41d3-852c-2a922f80390e
20171001/images/1/4a57fd94-43ab-44c2-9941-2f4fc536c2a9
20171001/images/1/225040bc-89be-4979-93c2-22f06d06cad4
20171001/images/1/f5f87698-275c-465f-a33e-e31aebf85a90
20171001/images/1/1a11181e-cd3a-4c9b-931a-1000450feadc
20171001/images/1/d7fd29cd-7308-49d1-9d4f-8e78333b7491
20171001/images/1/1c7a50d9-26cf-4671-a616-a3063197cb14
20171001/images/1/2b104625-2b96-4a0c-95fd-1567b375fd30
20171001/i

In [11]:
# convert records to strings
def convert_to_string(image_path, labels):
    """convert image_path, lables to string 
    Returns:
    string 
    """
    out_string = ''
    out_string += image_path
    for label in labels:
        for i in label:
            out_string += ' ' + str(i)
    out_string += '\n'
    return out_string

records = []
for idx, image_path in enumerate(image_paths):
    label = labels[idx]
    records.append(convert_to_string(image_path, label))

random.shuffle(records)
print(len(records))

6990


In [12]:
# split into training set and test set
total_num = len(records)
test_num = int(test_ratio * total_num)
train_num = total_num - test_num
train_records = records[0:train_num]
test_records = records[train_num:]

In [13]:
# save to text file
train_out_file = open(train_file_name, "w")
for record in train_records:
    train_out_file.write(record)
train_out_file.close()
test_out_file = open(test_file_name, "w")
for record in test_records:
    test_out_file.write(record)
test_out_file.close()