### Load libraries and prepare output directories

In [None]:
import json
import os
import random
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor
from glob import glob

import cv2
import numpy as np
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm


dataset_source = "/home/ubuntu/workspace/ywshin/construct/train_folder/hybridnets_dataset"
new_img_dir = os.path.join(dataset_source, "Images")
new_label_dir = os.path.join(dataset_source, "Labels")
seg_dir = os.path.join(dataset_source, "Seg")
parsings = [new_img_dir, new_label_dir, seg_dir]

# Recreate the output tree from scratch so files from a previous run never leak
# into the new dataset. makedirs (unlike mkdir) also creates dataset_source
# itself when it does not exist yet.
for p in parsings:
    if os.path.isdir(p):
        shutil.rmtree(p)
    os.makedirs(p)

# Every immediate sub-directory of `root` is treated as a task folder, except
# the "주행가능구역" ("drivable area") folder, which is excluded.
root = "/mnt/vitasoft/2022_Patrasche/Images_label_processing/"
tasks = next(os.walk(root))[1]
tasks = [t for t in tasks if t != "주행가능구역"]

# Image paths gathered by the "Collect labeled data" cell.
total_img_list = []

### Define the per-image conversion function

In [None]:
def create_dataset(transfer_img_list, img_dst=None, label_dst=None, seg_dst=None, canvas=None):
    """Copy image/label pairs into one split folder and render a road mask for each.

    For every image path, the sibling ``.json`` annotation (same path with the
    extension swapped) is loaded. Images whose annotation is missing or
    unreadable are skipped. For the rest, the image and the JSON are copied into
    the image and label destination folders. A mask PNG with the image's base
    name is written to the segmentation folder. The first shape labelled
    ``"road"`` is filled white on a copy of the blank canvas; if there is no such
    shape, the mask stays all black.

    Args:
        transfer_img_list: Paths of the source images to convert.
        img_dst: Image destination folder. Defaults to the global ``img_subdir``.
        label_dst: JSON destination folder. Defaults to the global
            ``label_subdir``.
        seg_dst: Mask destination folder. Defaults to
            ``os.path.join(seg_dir, split_key)`` built from the globals.
        canvas: Background image the mask is drawn on. It is copied, never
            modified. Defaults to the global ``blank``.
    """
    # Resolve destinations once, before the loop, so a worker thread keeps
    # writing to the same split even if the notebook rebinds these globals.
    img_dst = img_subdir if img_dst is None else img_dst
    label_dst = label_subdir if label_dst is None else label_dst
    seg_dst = os.path.join(seg_dir, split_key) if seg_dst is None else seg_dst
    canvas = blank if canvas is None else canvas

    for img_path in tqdm(transfer_img_list):
        # splitext swaps only the extension. str.replace(".png", ...) would also
        # rewrite any ".png" that happens to appear earlier in the path.
        stem = os.path.splitext(img_path)[0]
        json_path = stem + ".json"
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Best effort: images without a readable annotation are skipped.
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            continue

        # Copy in-process instead of shelling out to `cp`. This avoids a
        # subprocess per file and any quoting problems with spaces or shell
        # metacharacters in the paths.
        shutil.copy(img_path, img_dst)
        shutil.copy(json_path, label_dst)

        # Only the first "road" polygon is drawn; other shapes are ignored.
        points = next(
            (obj["points"] for obj in data["shapes"] if obj["label"] == "road"), []
        )
        seg = canvas.copy()
        if len(points):
            seg = cv2.fillPoly(seg, [np.array(points, np.int32)], (255, 255, 255))

        seg_path = os.path.join(seg_dst, os.path.basename(stem) + ".png")
        # cv2.imwrite reports failure via its return value rather than raising.
        if not cv2.imwrite(seg_path, seg):
            tqdm.write(f"failed to write mask: {seg_path}")

### Collect labeled data

In [None]:
for task in tasks:
    source = os.path.join(root, task)
    dir_list = sorted(next(os.walk(source))[1])

    for f_i, d in enumerate(dir_list):
        img_dir = os.path.join(source, d)
        label_txt_path = os.path.join(img_dir, "label.txt")

        # Folders without a label.txt are skipped.
        if not os.path.isfile(label_txt_path):
            continue
        # The first line of label.txt holds an integer label count.
        with open(label_txt_path) as f:
            label_count = int(f.readline())

        # NOTE(review): the labelled images of this folder are taken to be the
        # first `label_count - f_i * 500` after sorting. Presumably label.txt
        # stores a running count and every earlier folder holds 500 images;
        # confirm. f_i also counts folders skipped above.
        # Clamp at 0: a negative slice bound would select all but the last |n|
        # images instead of none.
        n_labeled = max(0, label_count - f_i * 500)

        # Sort key: the second-to-last "-"-separated field of the path, compared
        # as a string. TODO confirm it is zero-padded so that lexicographic
        # order matches numeric order.
        img_list = sorted(
            glob(os.path.join(img_dir, "*.png")), key=lambda x: x.split("-")[-2]
        )
        total_img_list += img_list[:n_labeled]

random.shuffle(total_img_list)
split_list = ["val", "train"]
# 3 % of the shuffled images are held out for validation.
split_amount = int(len(total_img_list) * 0.03)

# Every segmentation mask is drawn on this 1920x1088 black canvas.
# NOTE(review): assumes all source images are 1920x1088; confirm.
width, height = 1920, 1088
blank = np.zeros((height, width, 3), dtype=np.uint8)

print(len(total_img_list))

### Build the train/val splits (copy images and labels, render road masks)

In [None]:
for split_key in split_list:
    img_subdir = os.path.join(new_img_dir, split_key)
    label_subdir = os.path.join(new_label_dir, split_key)
    seg_subdir = os.path.join(seg_dir, split_key)

    # Start each split from empty output folders.
    for sdir in (img_subdir, label_subdir, seg_subdir):
        if os.path.isdir(sdir):
            shutil.rmtree(sdir)
        os.makedirs(sdir)

    # The last `split_amount` shuffled images form val; the rest form train.
    # Slice at an explicit index. With split_amount == 0 (33 images or fewer),
    # `[:-0]` / `[-0:]` would give train nothing and val everything.
    n_train = len(total_img_list) - split_amount

    if split_key == "train":
        transfer_img_list = total_img_list[:n_train]

        thread_count = 10
        # Round-robin chunks: every image lands in exactly one chunk, even when
        # there are fewer images than workers (empty chunks are harmless).
        chunks = [transfer_img_list[i::thread_count] for i in range(thread_count)]
        # Block until all workers finish. This keeps the globals create_dataset
        # reads (img_subdir, label_subdir, split_key) stable while it runs.
        # list() also re-raises a worker's exception instead of losing it in a
        # background thread.
        with ThreadPoolExecutor(max_workers=thread_count) as pool:
            list(pool.map(create_dataset, chunks))

    elif split_key == "val":
        transfer_img_list = total_img_list[n_train:]

        create_dataset(transfer_img_list)