In [3]:
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm
import shutil
import cv2
import numpy as np
import random

random.seed(123)

In [4]:
dataset_root = '/home/boyan/sf_files/datasets/chenyang_dataset/0530/'
images_dirname = 'images'
labels_dirname = 'xml'

output_dirname = 'chenyang_0530'
images_dst_dirname = 'images'
labels_dst_dirname = 'labels'
allowed_formats = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP']

In [5]:
images_raw_path = os.path.join(dataset_root, images_dirname)
labels_raw_path = os.path.join(dataset_root, labels_dirname)

output_dataset_root = os.path.join(dataset_root, output_dirname)
images_dst_path = os.path.join(output_dataset_root, images_dst_dirname)
labels_dst_path = os.path.join(output_dataset_root, labels_dst_dirname)

rotated_xml_path = os.path.join(output_dataset_root, 'rotated_xml')
rotated_images_path = os.path.join(output_dataset_root, 'rotated_images')

all_images_path = os.path.join(output_dataset_root, 'all_images')
all_xmls_path = os.path.join(output_dataset_root, 'all_xmls')
all_labels_path = os.path.join(output_dataset_root, 'all_labels')

if os.path.exists(output_dataset_root):
    shutil.rmtree(output_dataset_root)
os.mkdir(output_dataset_root)
os.mkdir(images_dst_path)
os.mkdir(labels_dst_path)
os.mkdir(all_images_path)
os.mkdir(all_xmls_path)
os.mkdir(all_labels_path)
os.mkdir(rotated_xml_path)
os.mkdir(rotated_images_path)
    

In [6]:
xml_files = [i for i in os.listdir(labels_raw_path) if os.path.splitext(i)[-1] == '.xml']

In [7]:
jpg_files = [i for i in os.listdir(images_raw_path) if os.path.splitext(i)[-1] in allowed_formats]

In [8]:
print('xml_files: ', len(xml_files), 'jpg_files: ', len(jpg_files))

xml_files:  349 jpg_files:  353


In [9]:
do_retate = True
fail_images = []
if do_retate:
    for xml_file in tqdm(xml_files):
        with open(os.path.join(labels_raw_path, xml_file), encoding='utf-8') as f:
            tree = ET.parse(f)
            # root = ET.fromstring(f.read().decode('utf-8'))
            root = tree.getroot()

            size = root.find('size')
            img_w = float(size.find('width').text)
            img_h = float(size.find('height').text)

            size.find('width').text = str(round(img_h))
            size.find('height').text = str(round(img_w))


            file_name = os.path.splitext(os.path.split(xml_file)[-1])[0]
            image_format = ''
            for allowed_format in allowed_formats:
                if file_name + allowed_format in jpg_files:
                    image_format = allowed_format
            if image_format == '':
                print('xml_file: ', xml_file)
                print('file_name: ', file_name)
                print('os.path.join(images_raw_path, file_name + image_format): ', os.path.join(images_raw_path, file_name + image_format))
            img = cv2.imdecode(np.fromfile(os.path.join(images_raw_path, file_name + image_format), dtype=np.uint8), -1)
            #img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            #img = cv2.imread(os.path.join(images_raw_path, file_name + image_format))
            
            rotated_img = np.rot90(img)

            direction = 90
            if random.random() < 0.5:
                rotated_img = np.rot90(rotated_img)
                rotated_img = np.rot90(rotated_img)
                direction = 270

            cv2.imencode(image_format, rotated_img)[1].tofile(os.path.join(rotated_images_path, file_name + '_rotated' + image_format))
            #cv2.imwrite(os.path.join(rotated_images_path, file_name + '_rotated' + image_format), rotated_img)

            for obj in root.iter('object'):
                bbox = obj.find('bndbox')
                xmin = float(bbox.find('xmin').text)
                xmax = float(bbox.find('xmax').text)
                ymin = float(bbox.find('ymin').text)
                ymax = float(bbox.find('ymax').text)

                if direction == 90:
                    bbox.find('xmin').text = str(round(ymin))
                    bbox.find('xmax').text = str(round(ymax))
                    bbox.find('ymin').text = str(img_w - round(xmax) - 1)
                    bbox.find('ymax').text = str(img_w - round(xmin) - 1)
                elif direction == 270:
                    bbox.find('xmin').text = str(img_h - round(ymax) - 1)
                    bbox.find('xmax').text = str(img_h - round(ymin) - 1)
                    bbox.find('ymin').text = str(round(xmin))
                    bbox.find('ymax').text = str(round(xmax))
                else:
                    pass

            tree.write(os.path.join(rotated_xml_path, file_name + '_rotated' + '.xml'))

100%|██████████| 349/349 [00:46<00:00,  7.46it/s]


In [10]:
keep_class_names = ['0', '1', '2', '3', '4', '5']
# keep_class_names = ['bsl', "yw"]
# keep_class_names = ['gms', 'gmq']


class_names = dict(zip(keep_class_names, range(len(keep_class_names))))
auto_scan = False
class_index = 0
if len(class_names) == 0:
    auto_scan = True
    print('Auto scan classnames enabled.')


for xml_file in tqdm(xml_files):
    
    file_path = os.path.join(labels_raw_path, xml_file)
    tree = ET.parse(file_path)
    root = tree.getroot()
    
    img_size = root.find('size')
    width = float(img_size.find('width').text)
    height = float(img_size.find('height').text)
    
    objs = root.findall('object')
    
    copied = False
    
    with open(os.path.join(all_labels_path, os.path.splitext(os.path.split(xml_file)[-1])[0] + '.txt'), 'w+') as txt:
        if len(objs) == 0:
            file_name = os.path.splitext(os.path.split(xml_file)[-1])[0]
            image_format = ''
            for allowed_format in allowed_formats:
                if file_name + allowed_format in jpg_files:
                    image_format = allowed_format
            shutil.copyfile(src=os.path.join(images_raw_path, file_name + image_format), dst=os.path.join(all_images_path, file_name + image_format))
            copied = True
        for obj in objs:
            if not copied:
                file_name = os.path.splitext(os.path.split(xml_file)[-1])[0]
                image_format = ''
                for allowed_format in allowed_formats:
                    if file_name + allowed_format in jpg_files:
                        image_format = allowed_format
                shutil.copyfile(src=os.path.join(images_raw_path, file_name + image_format), dst=os.path.join(all_images_path, file_name + image_format))
                copied = True
            name = obj.find('name').text.strip()
            if name == 'default' or name == 'S' or name == 'l' or name == 'Z' or name == 'G':
                print(xml_file, name)
            if name not in class_names.keys():
                if auto_scan:
                    print('auto append class name: "' + name + '" from ', xml_file)
                    class_names[name] = class_index
                    class_index = class_index + 1
                else:
                    continue
            bbox = obj.find('bndbox')
            if not bbox:
                print('no robndbox in %s' % (xml_file))
            xmin = float(bbox.find('xmin').text)
            xmax = float(bbox.find('xmax').text)
            ymin = float(bbox.find('ymin').text)
            ymax = float(bbox.find('ymax').text)

            x = (xmin + xmax) / 2 / width
            y = (ymin + ymax) / 2 / height
            w = (xmax - xmin) / width
            h = (ymax - ymin) / height
            
            txt.write('%s %s %s %s %s\n' % (class_names[name], x, y, w, h))
    

100%|██████████| 349/349 [00:01<00:00, 199.85it/s]


In [11]:
rotated_xml_files = [i for i in os.listdir(rotated_xml_path) if os.path.splitext(i)[-1] == '.xml']
rotated_jpg_files = [i for i in os.listdir(rotated_images_path) if os.path.splitext(i)[-1] in allowed_formats]
print('xml_files: ', len(rotated_xml_files), 'jpg_files: ', len(rotated_jpg_files))

xml_files:  349 jpg_files:  349


In [12]:
count = 0
for xml_file in tqdm(rotated_xml_files):
    
    file_path = os.path.join(rotated_xml_path, xml_file)
    tree = ET.parse(file_path)
    root = tree.getroot()
    
    img_size = root.find('size')
    width = float(img_size.find('width').text)
    height = float(img_size.find('height').text)
    
    objs = root.findall('object')
    
    copied = False
    
    with open(os.path.join(all_labels_path, os.path.splitext(os.path.split(xml_file)[-1])[0] + '.txt'), 'w+') as txt:
        if len(objs) == 0:
            file_name = os.path.splitext(os.path.split(xml_file)[-1])[0]
            image_format = ''
            for allowed_format in allowed_formats:
                if file_name + allowed_format in rotated_jpg_files:
                    image_format = allowed_format
            shutil.copyfile(src=os.path.join(rotated_images_path, file_name + image_format), dst=os.path.join(all_images_path, file_name + image_format))
            copied = True
        for obj in objs:
            if not copied:
                file_name = os.path.splitext(os.path.split(xml_file)[-1])[0]
                image_format = ''
                for allowed_format in allowed_formats:
                    if file_name + allowed_format in rotated_jpg_files:
                        image_format = allowed_format
                shutil.copyfile(src=os.path.join(rotated_images_path, file_name + image_format), dst=os.path.join(all_images_path, file_name + image_format))
                copied = True
                count = count + 1
            name = obj.find('name').text.strip()
            if name not in class_names.keys():
                continue
            bbox = obj.find('bndbox')
            if not bbox:
                print('no robndbox in %s' % (xml_file))
            xmin = float(bbox.find('xmin').text)
            xmax = float(bbox.find('xmax').text)
            ymin = float(bbox.find('ymin').text)
            ymax = float(bbox.find('ymax').text)

            x = (xmin + xmax) / 2 / width
            y = (ymin + ymax) / 2 / height
            w = (xmax - xmin) / width
            h = (ymax - ymin) / height
            
            txt.write('%s %s %s %s %s\n' % (class_names[name], x, y, w, h))
    

100%|██████████| 349/349 [00:00<00:00, 552.46it/s]


In [13]:
images_train = os.path.join(images_dst_path, 'train')
images_val = os.path.join(images_dst_path, 'val')

In [14]:
labels_train = os.path.join(labels_dst_path, 'train')
labels_val = os.path.join(labels_dst_path, 'val')

In [15]:
for d in [images_train, images_val, labels_train, labels_val]:
    if not os.path.exists(d):
        os.mkdir(d)

In [16]:
txt_files = [i for i in os.listdir(all_labels_path) if os.path.splitext(i)[-1] == '.txt']
all_jpg_files = [i for i in os.listdir(all_images_path) if os.path.splitext(i)[-1] in allowed_formats]

In [17]:
print('txt_files: ', len(txt_files))

txt_files:  698


In [18]:
print('before shuffle: ', txt_files[:5])

before shuffle:  ['202305170910_36_698_rotated.txt', '202305230029_37_575.txt', '202305170235_44_807.txt', '202305170800_06_662_rotated.txt', '202305221559_44_269_rotated.txt']


In [19]:
random.shuffle(txt_files)

In [20]:
print('after shuffle: ', txt_files[:5])

after shuffle:  ['202305221701_15_780.txt', '202305221404_44_041.txt', '202305221559_47_379_rotated.txt', '202305170235_05_250.txt', '202305221638_28_201.txt']


In [21]:
train_factor = 0.9

In [22]:
train_txt_files = txt_files[:int(len(txt_files)*train_factor)]
val_txt_files = txt_files[int(len(txt_files)*train_factor):]

In [23]:
print('train count: ', len(train_txt_files))
print('val count: ', len(val_txt_files))

train count:  628
val count:  70


In [24]:
for i in tqdm(train_txt_files):
    file_name = os.path.splitext(os.path.split(i)[-1])[0]
    
    image_format = '.bmp'
    for allowed_format in allowed_formats:
        if file_name + allowed_format in all_jpg_files:
            image_format = allowed_format
    
    shutil.copyfile(src=os.path.join(all_images_path, file_name + image_format), dst=os.path.join(images_train, file_name + image_format))
    shutil.copyfile(src=os.path.join(all_labels_path, i), dst=os.path.join(labels_train, i))

100%|██████████| 628/628 [00:02<00:00, 308.18it/s]


In [25]:
for i in tqdm(val_txt_files):
    file_name = os.path.splitext(os.path.split(i)[-1])[0]
    image_format = '.bmp'
    for allowed_format in allowed_formats:
        if file_name + allowed_format in all_jpg_files:
            image_format = allowed_format
    shutil.copyfile(src=os.path.join(all_images_path, file_name + image_format), dst=os.path.join(images_val, file_name + image_format))
    shutil.copyfile(src=os.path.join(all_labels_path, i), dst=os.path.join(labels_val, i))

100%|██████████| 70/70 [00:00<00:00, 301.27it/s]


In [26]:
for d in [rotated_xml_path, rotated_images_path, all_images_path, all_xmls_path, all_labels_path]:
    if os.path.exists(d):
        shutil.rmtree(d)

In [27]:
with open(os.path.join(output_dataset_root, 'data.yaml'), 'w+') as yaml_file:
    yaml_file.write('path: ')
    yaml_file.write(os.path.abspath(output_dataset_root))
    yaml_file.write('\n')
    
    yaml_file.write('train: %s\n' % (os.path.join('images', 'train')))
    yaml_file.write('val: %s\n' % (os.path.join('images', 'val')))
    yaml_file.write('test: %s\n' % (os.path.join('images', 'val')))
    
    yaml_file.write('nc: %d\n' % (len(class_names)))
    
    yaml_file.write('names: [')
    for i in range(len(class_names)):
        for k, v in class_names.items():
            if v == i:
                if i+1 != len(class_names):
                    yaml_file.write('"%s", ' % (k))
                else:
                    yaml_file.write('"%s"]' % (k))
    

In [28]:
class_names

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5}

In [29]:
names = list(class_names.keys())

In [30]:
'", "'.join(names)

'0", "1", "2", "3", "4", "5'

In [31]:
len(class_names)

6

In [32]:
for i, name in enumerate(names):
    print('id2Name.insert(std::pair<int, std::string>(%d, "%s"));' % (i, name))

id2Name.insert(std::pair<int, std::string>(0, "0"));
id2Name.insert(std::pair<int, std::string>(1, "1"));
id2Name.insert(std::pair<int, std::string>(2, "2"));
id2Name.insert(std::pair<int, std::string>(3, "3"));
id2Name.insert(std::pair<int, std::string>(4, "4"));
id2Name.insert(std::pair<int, std::string>(5, "5"));
