In [1]:
import os
from glob import glob
import json
from shutil import copy2
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [2]:
# Detection classes. The class id written to each YOLO label line is the
# label's position in this list (looked up with class_names.index below).
class_names = [
    'Zebra_Cross',  # class id 0
    'R_Signal',     # class id 1
    'G_Signal',     # class id 2
]

In [3]:
# Root of the raw dataset: each immediate sub-folder holds .jpg images with a
# same-named .json annotation next to them. Set the CROSSWALK_DATA_PATH
# environment variable to point elsewhere without editing this cell.
DATA_PATH = os.environ.get('CROSSWALK_DATA_PATH', 'C:/Users/Brad/Documents/Dataset/crosswalk-dataset')

# glob() returns matches in arbitrary, filesystem-dependent order; sorting makes
# the seeded train/val split reproducible on any machine.
img_list = sorted(glob(os.path.join(DATA_PATH, '*', '*.jpg')))

# Images from all sub-folders are copied into one flat folder by basename, so a
# repeated basename would silently overwrite another image and its labels.
n_unique_names = len({os.path.basename(p) for p in img_list})
if n_unique_names != len(img_list):
    print('WARNING: %d duplicate image basenames found' % (len(img_list) - n_unique_names))

len(img_list)

35244

In [4]:
# Output layout: <DATA_PATH>/{train,val}/{images,labels}; existing folders are kept.
for split_name in ('train', 'val'):
    for sub_dir in ('images', 'labels'):
        os.makedirs(os.path.join(DATA_PATH, split_name, sub_dir), exist_ok=True)

In [5]:
# Fraction of images held out for validation and the seed fixing the shuffle.
# The split is only reproducible if img_list itself is in a stable order.
VAL_SIZE = 0.2
SPLIT_SEED = 2022

train_img_list, val_img_list = train_test_split(img_list, test_size=VAL_SIZE, random_state=SPLIT_SEED)

# Sanity check: the two splits partition the image list without overlap.
assert len(train_img_list) + len(val_img_list) == len(img_list)
assert not set(train_img_list) & set(val_img_list)

len(train_img_list), len(val_img_list)

(28195, 7049)

In [8]:
# Convert each training image's JSON annotation (imageWidth/imageHeight plus a
# list of labelled shapes) into a YOLO label file: one "class cx cy w h" line
# per shape, coordinates normalised by the image size. The image is copied to
# train/images, its label file written to train/labels, and its relative path
# listed in train.txt.
# If any shape of an image fails to convert (e.g. its label is not in
# class_names), the whole image is skipped and reported, so no image ends up
# with a partial label file.
file_list = []

for img_path in tqdm(train_img_list):
    img_name = os.path.basename(img_path)
    # Swap only the extension; str.replace('.jpg', '.json') would also rewrite
    # any '.jpg' that happens to occur earlier in the path.
    json_path = os.path.splitext(img_path)[0] + '.json'
    label_name = os.path.splitext(img_name)[0] + '.txt'

    try:
        # Read inside the try so one missing/corrupt annotation is reported and
        # skipped instead of aborting the run before train.txt is written.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        w = data['imageWidth']
        h = data['imageHeight']

        lines = []
        for shape in data['shapes']:
            points = shape['points']
            if len(points) < 2:
                raise ValueError(f"shape {shape.get('label')!r} has fewer than 2 points")

            # Corner order depends on how the box was drawn, so take min/max
            # rather than assuming points[0] is top-left (which would yield
            # negative widths/heights). For >2 points this is the enclosing box.
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            x1, x2 = min(xs), max(xs)
            y1, y2 = min(ys), max(ys)

            cx = (x1 + x2) / 2. / w
            cy = (y1 + y2) / 2. / h
            bw = (x2 - x1) / w
            bh = (y2 - y1) / h

            # Class id is the label's index in class_names; an unknown label
            # raises ValueError and the image is skipped.
            label = class_names.index(shape['label'])

            lines.append('%d %f %f %f %f\n' % (label, cx, cy, bw, bh))

        copy2(img_path, os.path.join(DATA_PATH, 'train', 'images', img_name))

        with open(os.path.join(DATA_PATH, 'train', 'labels', label_name), 'w') as f:
            f.writelines(lines)

        file_list.append(os.path.join('train', 'images', img_name))
    except Exception as e:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write('%s %s' % (e, img_path))

with open(os.path.join(DATA_PATH, 'train.txt'), 'w', encoding='utf-8') as f:
    # One path per line; an empty list gives an empty file, not a lone newline.
    f.writelines(p + '\n' for p in file_list)

print(len(file_list))

 18%|█████████████▊                                                              | 5135/28195 [00:08<00:41, 555.28it/s]

'1' is not in list C:/Users/Brad/Documents/Dataset/(2차_최종) 교차로정보 데이터셋_20210720\교차로정보 데이터셋_bbox_2\MP_SEL_047407.jpg


 25%|███████████████████▎                                                        | 7149/28195 [00:12<00:34, 614.42it/s]

'1' is not in list C:/Users/Brad/Documents/Dataset/(2차_최종) 교차로정보 데이터셋_20210720\교차로정보 데이터셋_bbox_1\MP_SEL_017712.jpg


 26%|████████████████████                                                        | 7456/28195 [00:12<00:34, 607.39it/s]

'1' is not in list C:/Users/Brad/Documents/Dataset/(2차_최종) 교차로정보 데이터셋_20210720\교차로정보 데이터셋_bbox_1\MP_SEL_037601.jpg
'1' is not in list C:/Users/Brad/Documents/Dataset/(2차_최종) 교차로정보 데이터셋_20210720\교차로정보 데이터셋_bbox_2\MP_SEL_012921.jpg


 65%|████████████████████████████████████████████████▍                          | 18226/28195 [00:31<00:16, 611.11it/s]

'1' is not in list C:/Users/Brad/Documents/Dataset/(2차_최종) 교차로정보 데이터셋_20210720\교차로정보 데이터셋_bbox_1\MP_SEL_053791.jpg


100%|███████████████████████████████████████████████████████████████████████████| 28195/28195 [00:48<00:00, 580.48it/s]

28190





In [9]:
# Convert each validation image's JSON annotation (imageWidth/imageHeight plus
# a list of labelled shapes) into a YOLO label file: one "class cx cy w h" line
# per shape, coordinates normalised by the image size. The image is copied to
# val/images, its label file written to val/labels, and its relative path
# listed in val.txt.
# If any shape of an image fails to convert (e.g. its label is not in
# class_names), the whole image is skipped and reported, so no image ends up
# with a partial label file.
file_list = []

for img_path in tqdm(val_img_list):
    img_name = os.path.basename(img_path)
    # Swap only the extension; str.replace('.jpg', '.json') would also rewrite
    # any '.jpg' that happens to occur earlier in the path.
    json_path = os.path.splitext(img_path)[0] + '.json'
    label_name = os.path.splitext(img_name)[0] + '.txt'

    try:
        # Read inside the try so one missing/corrupt annotation is reported and
        # skipped instead of aborting the run before val.txt is written.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        w = data['imageWidth']
        h = data['imageHeight']

        lines = []
        for shape in data['shapes']:
            points = shape['points']
            if len(points) < 2:
                raise ValueError(f"shape {shape.get('label')!r} has fewer than 2 points")

            # Corner order depends on how the box was drawn, so take min/max
            # rather than assuming points[0] is top-left (which would yield
            # negative widths/heights). For >2 points this is the enclosing box.
            xs = [p[0] for p in points]
            ys = [p[1] for p in points]
            x1, x2 = min(xs), max(xs)
            y1, y2 = min(ys), max(ys)

            cx = (x1 + x2) / 2. / w
            cy = (y1 + y2) / 2. / h
            bw = (x2 - x1) / w
            bh = (y2 - y1) / h

            # Class id is the label's index in class_names; an unknown label
            # raises ValueError and the image is skipped.
            label = class_names.index(shape['label'])

            lines.append('%d %f %f %f %f\n' % (label, cx, cy, bw, bh))

        copy2(img_path, os.path.join(DATA_PATH, 'val', 'images', img_name))

        with open(os.path.join(DATA_PATH, 'val', 'labels', label_name), 'w') as f:
            f.writelines(lines)

        file_list.append(os.path.join('val', 'images', img_name))
    except Exception as e:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write('%s %s' % (e, img_path))

with open(os.path.join(DATA_PATH, 'val.txt'), 'w', encoding='utf-8') as f:
    # One path per line; an empty list gives an empty file, not a lone newline.
    f.writelines(p + '\n' for p in file_list)

print(len(file_list))

100%|█████████████████████████████████████████████████████████████████████████████| 7049/7049 [00:17<00:00, 402.53it/s]

7049



