## Creating the dataset for WavePaint training

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import os.path as osp
from tqdm import tqdm
import cv2
from PIL import Image
import albumentations as A

In [2]:
def reading_in_background_image(index, image_path, file_names) :
    """Open the background image stored at position ``index`` of ``file_names``.

    Args:
        index: Position of the wanted entry inside ``file_names``.
        image_path: Directory that contains the image files.
        file_names: Indexable collection of file names; the selected entry
            is passed through ``str()`` before it is joined onto ``image_path``.

    Returns:
        The object returned by ``PIL.Image.open`` for the selected file.
    """
    selected_name = str(file_names[index])
    return Image.open(osp.join(image_path, selected_name))

In [3]:
def get_resized_background_transform(image_width, image_height):
    """Build an albumentations pipeline that takes a random crop of a fixed size.

    Despite the "resized" in the name, nothing is rescaled: the pipeline's
    only step is an always-applied (``p=1``) ``A.RandomCrop``.

    Args:
        image_width: Width of the crop, forwarded as ``width``.
        image_height: Height of the crop, forwarded as ``height``.

    Returns:
        An ``A.Compose`` object wrapping the single random-crop step.
    """
    # NOTE(review): presumably the input must be at least as large as the
    # requested crop - confirm against the installed albumentations version.
    random_crop = A.RandomCrop(width=image_width, height=image_height, p=1)
    pipeline_steps = [random_crop]
    return A.Compose(pipeline_steps)

In [4]:
def show_image_with_bounding_box (xcentre, ycentre, width, heigth):
    """Plot the outline of a box, given by its centre and size, in red.

    Only the four edges are drawn with ``plt.plot``; no image is loaded or
    displayed by this function itself.

    Args:
        xcentre: x coordinate of the box centre.
        ycentre: y coordinate of the box centre.
        width: Full width of the box.
        heigth: Full height of the box (parameter name kept as spelled).
    """
    half_width = width/2
    half_height = heigth/2

    xmin, xmax = xcentre - half_width, xcentre + half_width
    ymin, ymax = ycentre - half_height, ycentre + half_height

    # One (xs, ys) pair per edge, drawn in the order left, right, ymin, ymax.
    edges = (
        ([xmin, xmin], [ymin, ymax]),  # Left edge
        ([xmax, xmax], [ymin, ymax]),  # Right edge
        ([xmin, xmax], [ymin, ymin]),  # Edge at ymin (top when the y axis is inverted)
        ([xmin, xmax], [ymax, ymax]),  # Edge at ymax (bottom when the y axis is inverted)
    )
    for edge_xs, edge_ys in edges:
        plt.plot(edge_xs, edge_ys, '-', color = 'red')

In [5]:
# Source folders holding the background images for each split.
bg_train_path = "../datasets/OpenImages_data_example/train"
# os.listdir returns names in arbitrary order; sort them so that the
# index -> file mapping used by the generation loops is the same on every run.
bg_train_images = sorted(os.listdir(bg_train_path))

bg_val_path = "../datasets/OpenImages_data_example/val"
bg_val_images = sorted(os.listdir(bg_val_path))


# Root folder of the generated WavePaint dataset.
path = '../datasets/making_WavePaint_data'

# Images and their masks are written below <path>/images/{train,val}.
images_path_per_model = os.path.join(path, "images")
train_path = os.path.join(images_path_per_model, "train")
test_path = os.path.join(images_path_per_model, "val")

# exist_ok=True makes re-running this cell a no-op and avoids the separate
# (race-prone) existence check before every makedirs call.
for output_dir in (path, images_path_per_model, train_path, test_path):
    os.makedirs(output_dir, exist_ok=True)


## Train

In [6]:
# Generates the training split: for each of the first nr_of_train background
# images, saves a random image_w_h x image_w_h crop as <id>.png and a centred
# rectangular mask (255 inside the box, 0 elsewhere) as <id>_mask000.png.
id_length = 6
nr_of_train = 24
image_w_h = 256

# The crop transform does not depend on the image, so build it once.
resizing_image = get_resized_background_transform(image_w_h, image_w_h)

for i in tqdm(range(nr_of_train)):

    background_image = reading_in_background_image(i, bg_train_path, bg_train_images)
    # Convert to RGB *before* going through NumPy: a raw array carries no mode
    # information, so palette, CMYK or 1-bit inputs would otherwise be rebuilt
    # by Image.fromarray with the wrong mode and end up with wrong colours.
    background_rgb = np.array(background_image.convert('RGB'))
    background_resized = resizing_image(image=background_rgb)
    background_image = Image.fromarray(background_resized['image'])

    # Generate file name: the loop index zero-padded to id_length digits
    image_id = str(i).zfill(id_length)

    # Save image
    image_path = os.path.join(path, 'images', 'train', '{}.png'.format(image_id))
    background_image.save(image_path, format='png')

    # Randomly generating bbox: the size is a random fraction of the image
    # side, and the box is centred in the image.
    box_width_ratio = random.uniform(0.3, 0.6)
    box_height_ratio = random.uniform(0.65, 0.85)
    box_width = int(image_w_h * box_width_ratio)
    box_height = int(image_w_h * box_height_ratio)
    start_x = int((image_w_h - box_width) / 2)
    start_y = int((image_w_h - box_height) / 2)

    # Generating mask directly as 8-bit: 255 inside the box, 0 outside
    mask = np.zeros((image_w_h, image_w_h), dtype=np.uint8)
    mask[start_y:start_y + box_height, start_x:start_x + box_width] = 255

    # Saving a mask (a 2-D uint8 array becomes a single-channel 'L' image)
    mask_path = os.path.join(path, 'images', 'train', '{}_mask000.png'.format(image_id))
    mask_image = Image.fromarray(mask)
    mask_image.save(mask_path, format='png')

    

100%|███████████████████████████████████████████| 24/24 [00:00<00:00, 43.84it/s]


## Val

In [7]:

# Generates the validation split: for each of the first nr_of_val background
# images, saves a random image_w_h x image_w_h crop as <id>.png and a centred
# rectangular mask (255 inside the box, 0 elsewhere) as <id>_mask000.png.
id_length = 6
nr_of_val = 12
image_w_h = 256

# The crop transform does not depend on the image, so build it once.
resizing_image = get_resized_background_transform(image_w_h, image_w_h)

for i in tqdm(range(nr_of_val)):

    background_image = reading_in_background_image(i, bg_val_path, bg_val_images)
    # Convert to RGB *before* going through NumPy: a raw array carries no mode
    # information, so palette, CMYK or 1-bit inputs would otherwise be rebuilt
    # by Image.fromarray with the wrong mode and end up with wrong colours.
    background_rgb = np.array(background_image.convert('RGB'))
    background_resized = resizing_image(image=background_rgb)
    background_image = Image.fromarray(background_resized['image'])

    # Generate file name: the loop index zero-padded to id_length digits
    image_id = str(i).zfill(id_length)

    # Save image
    image_path = os.path.join(path, 'images', 'val', '{}.png'.format(image_id))
    background_image.save(image_path, format='png')

    # Randomly generating bbox: the size is a random fraction of the image
    # side, and the box is centred in the image.
    box_width_ratio = random.uniform(0.3, 0.6)
    box_height_ratio = random.uniform(0.65, 0.85)
    box_width = int(image_w_h * box_width_ratio)
    box_height = int(image_w_h * box_height_ratio)
    start_x = int((image_w_h - box_width) / 2)
    start_y = int((image_w_h - box_height) / 2)

    # Generating mask directly as 8-bit: 255 inside the box, 0 outside
    mask = np.zeros((image_w_h, image_w_h), dtype=np.uint8)
    mask[start_y:start_y + box_height, start_x:start_x + box_width] = 255

    # Saving a mask (a 2-D uint8 array becomes a single-channel 'L' image)
    mask_path = os.path.join(path, 'images', 'val', '{}_mask000.png'.format(image_id))
    mask_image = Image.fromarray(mask)
    mask_image.save(mask_path, format='png')


100%|███████████████████████████████████████████| 12/12 [00:00<00:00, 44.79it/s]
