In [12]:
import cv2
from PIL import Image
import numpy as np
import glob
import json
import os

# ---- "home" side: background images that objects are pasted INTO ----------
# Alternative unlabeled background set:
#   home_image_root = "office_bg"; home_bbox_root = None
home_image_root = "../yolov5/dataset_person/images"
home_image_exts = ['.jpg', '.png']
# YOLO .txt label files; each file name should contain the name of its image.
# None means the backgrounds are used without labels.
home_bbox_root = "../yolov5/dataset_person/labels"
home_bbox_exts = ['.txt']
# Home-side masks are not referenced by the cells shown in this notebook.
home_mask_root = None
home_mask_exts = ['.png']
home_mask_dict_path = None

# ---- "away" side: object images whose masked region is cut OUT and pasted --
away_image_root = "image"
away_image_exts = ['.jpg', '.png']
away_bbox_root = None
away_bbox_exts = ['.txt']
# .png masks; each mask file name should contain the name of its image.
away_mask_root = "mask"
away_mask_exts = ['.png']
# Optional pre-built JSON {image_id: [mask paths]}; None -> build it from away_mask_root.
away_mask_dict_path = None


def make_mask_dict(image_paths, mask_paths, save=True, save_name="default.json"):
    """Map each image id to the mask files whose file names contain that id.

    The image id is the image file name up to its first '.'
    (``image/cat.01.jpg`` -> ``cat``). A mask belongs to an image when that id
    occurs as a substring of the mask's *file name* (its directory is ignored).
    Only ids with at least one matching mask appear in the result.

    If ``save_name`` already exists it is loaded and returned as-is, without
    looking at ``image_paths`` / ``mask_paths``; delete the file to rebuild the
    mapping after the image or mask sets change.

    Args:
        image_paths: iterable of image file paths.
        mask_paths: iterable of mask file paths (consumed once; iterators are fine).
        save: when True, write the freshly built dict to ``save_name`` as JSON.
        save_name: path of the JSON cache file.

    Returns:
        dict[str, list[str]]: image id -> matching mask paths, in ``mask_paths`` order.
    """
    if os.path.isfile(save_name):
        print("There is existing mask dict file. load...")
        with open(save_name, 'r') as f:
            # json.load parses the whole file, so hand-edited / pretty-printed
            # JSON works as well as the single-line file written below.
            return json.load(f)

    print(f"Making mask dict at {save_name}...")

    # Materialise once: the masks are scanned for every image, and a one-shot
    # iterator would otherwise be exhausted after the first image.
    masks = [(mask_path, os.path.basename(mask_path)) for mask_path in mask_paths]

    mask_dict = {}
    for path in image_paths:
        image_id = os.path.basename(path).split('.')[0]
        for mask_path, mask_name in masks:
            # Compare against the file name only; matching the full path let the
            # directory (e.g. "mask/") match short ids such as "k" or "m".
            if image_id in mask_name:
                mask_dict.setdefault(image_id, []).append(mask_path)

    if save:
        with open(save_name, "w") as outfile:
            json.dump(mask_dict, outfile)

    print("Done.")
    return mask_dict

def get_ext_paths(root, exts):
    """Return the files directly under ``root`` whose names end with one of ``exts``.

    Args:
        root: directory to search (not recursive). Glob metacharacters in it
            (``[``, ``*``, ``?``) are taken literally.
        exts: iterable of suffixes such as ``['.jpg', '.png']``.

    Returns:
        list[str]: matching paths, de-duplicated and sorted. glob's own order is
        filesystem dependent; sorting makes runs reproducible and keeps lists of
        same-stem files (e.g. images and their label files) index-aligned when
        both directories hold the same set of stems.
    """
    pattern_root = glob.escape(root)
    paths = set()
    for ext in exts:
        paths.update(glob.glob(os.path.join(pattern_root, f'*{ext}')))
    return sorted(paths)

# Collect the input file lists and the away-image-id -> mask-paths mapping.
home_image_paths = get_ext_paths(home_image_root, home_image_exts)
if home_bbox_root is not None:
    home_bbox_paths = get_ext_paths(home_bbox_root, home_bbox_exts)

away_image_paths = get_ext_paths(away_image_root, away_image_exts)

# Fail fast with a readable message; with an empty list the generation loop
# would only fail later inside np.random.randint.
if not home_image_paths:
    raise FileNotFoundError(
        f"No images with extensions {home_image_exts} under {home_image_root!r}")
if not away_image_paths:
    raise FileNotFoundError(
        f"No images with extensions {away_image_exts} under {away_image_root!r}")

if away_mask_dict_path is None:
    # Built once, then reloaded from away_mask.json on later runs (delete it to rebuild).
    away_mask_dict = make_mask_dict(away_image_paths,
                                    get_ext_paths(away_mask_root, away_mask_exts),
                                    save=True, save_name="away_mask.json")
else:
    with open(away_mask_dict_path, 'r') as f:
        # json.load parses the whole file, so pretty-printed / multi-line JSON works too.
        away_mask_dict = json.load(f)

if not away_mask_dict:
    raise ValueError("away_mask_dict is empty: no away image has a matching mask file")

There is existing mask dict file. load...


In [13]:
# Peek at the YOLO label lines of the last generated image. `home_bboxes` is
# created by the generation loop in the next cell, so look it up without
# raising: on a fresh kernel (Restart & Run All) this cell shows nothing
# instead of failing with NameError.
globals().get("home_bboxes")

['0 0.12395833333333334 0.29953703703703705 0.10833333333333334 0.125\n',
 '0 0.34453125 0.9458333333333333 0.0296875 0.06018518518518518\n']

In [14]:
# Copy-paste augmentation: paste a few masked "away" objects onto random "home"
# backgrounds and write YOLO-format images/labels under `save_root`.
bbox_format = 'yolo' # center xywh
save = True
show = True # should be False if you wanna save
save_root = 'human_phone'
amount = 5000
min_person_width = 100       # px; backgrounds whose widest labelled box is narrower are skipped
resize_range = (0.05, 0.6)   # uniform scale range applied to each cut-out object
paste_count_range = (1, 4)   # np.random.randint bounds -> 1..3 pasted objects per image
# YOLO class id written for pasted objects. Home label lines are copied
# unchanged and the person labels shown earlier also use class 0 -- TODO
# confirm that sharing one class is intended.
paste_class_id = 0
seed = None                  # set to an int for a reproducible dataset


def _stem(path):
    """File name without directory and final extension ('a/b.c.jpg' -> 'b.c')."""
    return os.path.splitext(os.path.basename(path))[0]


def _image_id(path):
    """Mask-dict key of an image: its file name up to the first '.'."""
    return os.path.basename(path).split('.')[0]


def _imread_any(path, flags=cv2.IMREAD_COLOR):
    """Decode an image via np.fromfile + cv2.imdecode; None if it cannot be decoded."""
    data = np.fromfile(path, np.uint8)
    return cv2.imdecode(data, flags) if data.size else None


def _imwrite_any(path, img):
    """Encode by file extension and write via ndarray.tofile (mirror of _imread_any).

    Raises IOError when encoding fails (cv2.imwrite would only return False).
    """
    ok, buf = cv2.imencode(os.path.splitext(path)[1], img)
    if not ok:
        raise IOError(f"Could not encode image for {path}")
    buf.tofile(path)


def _read_yolo_lines(path):
    """Non-blank lines of a YOLO label file, each newline-terminated."""
    with open(path, 'r') as f:
        return [line.rstrip('\n') + '\n' for line in f if line.strip()]


def _yolo_to_xyxy(line, img_w, img_h):
    """'cls cx cy w h' (normalised) -> integer pixel corners (x1, y1, x2, y2)."""
    cx, cy, bw, bh = (float(v) for v in line.split()[1:5])
    return (int((cx - bw / 2) * img_w), int((cy - bh / 2) * img_h),
            int((cx + bw / 2) * img_w), int((cy + bh / 2) * img_h))


def _xyxy_to_yolo_line(class_id, x1, y1, x2, y2, img_w, img_h):
    """Pixel corners -> one newline-terminated YOLO line 'cls cx cy w h' (normalised)."""
    return (f'{class_id} {(x1 + x2) / 2 / img_w} {(y1 + y2) / 2 / img_h} '
            f'{(x2 - x1) / img_w} {(y2 - y1) / img_h}\n')


def _draw_background(image_paths, bbox_by_stem, use_labels, min_width, draw_boxes,
                     max_tries=10000):
    """Pick random background images until one is usable.

    An image is usable when it decodes and, if `use_labels`, its widest
    labelled box is at least `min_width` px wide (an image without a label
    file counts as having no boxes). Returns (path, image, label_lines), with
    label_lines == [] when labels are not used. Raises RuntimeError after
    `max_tries` rejections so a bad dataset/threshold cannot hang the loop.
    """
    for _ in range(max_tries):
        path = image_paths[np.random.randint(len(image_paths))]
        img = _imread_any(path)
        if img is None:
            continue
        if not use_labels:
            return path, img, []
        label_path = bbox_by_stem.get(_stem(path))
        lines = _read_yolo_lines(label_path) if label_path else []
        img_h, img_w = img.shape[:2]
        boxes = [_yolo_to_xyxy(line, img_w, img_h) for line in lines]
        if max((x2 - x1 for x1, _, x2, _ in boxes), default=0) < min_width:
            continue
        if draw_boxes:
            for x1, y1, x2, y2 in boxes:
                cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return path, img, lines
    raise RuntimeError(f"No usable background found in {max_tries} draws "
                       f"(min_person_width={min_width})")


def _paste_random_object(canvas, obj_img, obj_mask_bgr, scale_range, max_tries=100):
    """Cut the masked object out of obj_img, rescale it randomly and paste it at a
    random position of canvas (modified in place).

    Assumes obj_mask_bgr has the same height/width as obj_img -- TODO confirm for
    the mask set. Returns the pasted box as pixel (x1, y1, x2, y2), or None when
    the mask is empty or no scale drawn from scale_range fits inside canvas.
    """
    gray = cv2.cvtColor(obj_mask_bgr, cv2.COLOR_BGR2GRAY)
    points = cv2.findNonZero(gray)
    if points is None:  # all-black mask: nothing to paste
        return None
    bx, by, bw, bh = cv2.boundingRect(points)
    crop = obj_img[by:by + bh, bx:bx + bw]
    # Binary 0/255 mask so every composited pixel comes from exactly one source.
    crop_mask = np.where(gray[by:by + bh, bx:bx + bw] > 0, 255, 0).astype(np.uint8)

    canvas_h, canvas_w = canvas.shape[:2]
    for _ in range(max_tries):
        factor = np.random.uniform(*scale_range)
        # Scale the ORIGINAL crop on every attempt (no compounding) and never to 0 px.
        new_w, new_h = max(1, int(bw * factor)), max(1, int(bh * factor))
        if new_w < canvas_w and new_h < canvas_h:
            break
    else:  # no attempt fit
        return None

    obj = cv2.resize(crop, dsize=(new_w, new_h))
    # Nearest-neighbour keeps the resized mask strictly binary.
    mask = cv2.resize(crop_mask, dsize=(new_w, new_h), interpolation=cv2.INTER_NEAREST)

    x1 = np.random.randint(canvas_w - new_w)
    y1 = np.random.randint(canvas_h - new_h)
    x2, y2 = x1 + new_w, y1 + new_h
    region = canvas[y1:y2, x1:x2]
    canvas[y1:y2, x1:x2] = np.where(mask[..., None] > 0, obj, region)
    return x1, y1, x2, y2


if save:
    print("force show==False during saving mode..")
    show = False

if seed is not None:
    np.random.seed(seed)

# Pair images with labels by file stem, not list position: the two path lists
# are globbed independently, so equal indices need not be the same sample.
home_bbox_by_stem = ({_stem(p): p for p in home_bbox_paths}
                     if home_bbox_root is not None else {})
# Only away images that actually have masks can be pasted.
away_candidates = [p for p in away_image_paths if _image_id(p) in away_mask_dict]
if not away_candidates:
    raise ValueError("None of the away images has an entry in away_mask_dict")
index_width = len(str(amount))

cnt = 0
while cnt < amount:
    if save:
        print(f"saving {cnt}/{amount} images..", end='\r')

    # home_bboxes: label lines of this background; pasted boxes are appended below.
    home_img_path, home_img, home_bboxes = _draw_background(
        home_image_paths, home_bbox_by_stem, home_bbox_root is not None,
        min_person_width, show)
    home_h, home_w = home_img.shape[:2]

    for _ in range(np.random.randint(*paste_count_range)):
        away_img_path = away_candidates[np.random.randint(len(away_candidates))]
        away_img = _imread_any(away_img_path)
        away_mask = _imread_any(np.random.choice(away_mask_dict[_image_id(away_img_path)]))
        if away_img is None or away_mask is None:
            continue
        box = _paste_random_object(home_img, away_img, away_mask, resize_range)
        if box is None:
            continue
        if show:
            cv2.rectangle(home_img, box[:2], box[2:], (0, 255, 0), 2)
        home_bboxes.append(_xyxy_to_yolo_line(paste_class_id, *box, home_w, home_h))

    if show:
        # OpenCV arrays are BGR; PIL expects RGB.
        display(Image.fromarray(cv2.cvtColor(home_img, cv2.COLOR_BGR2RGB)))
    if save:
        file_name = f"{str(cnt).zfill(index_width)}_" + os.path.basename(home_img_path)
        save_path = os.path.join(save_root, 'images', file_name)
        # Label name from the stem, so .png backgrounds also get a .txt label file.
        label_save_path = os.path.join(save_root, 'labels', _stem(file_name) + '.txt')

        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        os.makedirs(os.path.dirname(label_save_path), exist_ok=True)

        _imwrite_any(save_path, home_img)
        with open(label_save_path, "w") as f:
            f.writelines(home_bboxes)

    cnt += 1

print("\ncomplete!")

force show==False during saving mode..
saving 4999/5000 images..
complete!
