In [1]:
import cv2
import numpy as np
import os, shutil
import glob
import random
from datetime import datetime
import tqdm
from tqdm import tqdm_notebook as tqdm
# from tqdm import tqdm
import params

In [2]:
# Seed the generator from the wall clock so every run draws a different set of
# crops. random.seed() only accepts None/int/float/str/bytes/bytearray; passing
# the datetime object itself raises TypeError on Python 3.11+, so hand it the
# numeric timestamp instead.
random.seed(datetime.now().timestamp())

In [3]:
# Directory holding the source images and their '<stem>_mask.png' companions.
base_path = params.source_data_path

# Number of crops to generate for the training set.
number_of_training_samples = 10000

# The validation set size is expressed as a fraction of the training set size.
number_of_validation_samples_coeff = 0.2

# Round before converting: int() truncates, so a product that lands just below
# a whole number because of float error (e.g. 0.57 * 100 == 56.99999999999999)
# would otherwise lose one sample.
number_of_validation_samples = int(round(number_of_validation_samples_coeff * number_of_training_samples))

In [4]:
# Source images are every PNG in base_path whose *file name* does not contain
# 'mask'. Only the basename is tested: checking the full path would silently
# discard every image whenever a parent directory name contains 'mask'.
# The list is sorted because glob returns files in arbitrary order.
origin_imgs = sorted(
    x for x in glob.glob(os.path.join(base_path, '*.png'))
    if 'mask' not in os.path.basename(x)
)

# Pair each image with the path of its mask: '<stem>_mask.png' next to it.
used_imgs = [[x, os.path.splitext(x)[0] + '_mask.png'] for x in origin_imgs]

In [5]:
# Output directories for the generated crops, as configured in the params module.
training_data_path, validation_data_path = (
    params.training_data_path,
    params.validation_data_path,
)

In [6]:
def GenerateSampleIndex(samples_count):
    """Return a random valid index into a collection of `samples_count` items.

    Args:
        samples_count: number of items in the collection being indexed.

    Returns:
        An int in the range [0, samples_count - 1].

    Raises:
        ValueError: if `samples_count` is not positive (there is no valid index).
    """
    # random.randint() includes its upper bound, so randint(0, samples_count)
    # could return samples_count itself - one past the last valid index.
    # randrange() excludes the upper bound.
    return random.randrange(samples_count)

In [7]:
def EstimateClassPresence(classes_mask, class_num):
    """Return the fraction of mask pixels that belong to class `class_num`.

    A pixel is counted when its mask value equals `class_num + 1`
    (presumably mask value 0 is reserved for background - TODO confirm).

    Args:
        classes_mask: 2-D numpy array of per-pixel class labels.
        class_num: zero-based class number to look for.

    Returns:
        A float in [0.0, 1.0]; 0.0 for a mask that contains no pixels.
    """
    pixel_count = classes_mask.shape[0] * classes_mask.shape[1]
    if pixel_count == 0:
        # An empty mask cannot contain the class; avoid dividing by zero.
        return 0.0

    # count_nonzero runs inside numpy; the builtin sum() would iterate the
    # array row by row in Python.
    matching_pixels = np.count_nonzero(classes_mask == class_num + 1)
    return matching_pixels / float(pixel_count)

In [8]:
def _RandomCrop(image, mask, sample_width, sample_height):
    """Cut the same randomly placed window out of `image` and `mask`."""
    height, width = image.shape[:2]
    # random.randint() is inclusive on both ends, so the largest valid offset
    # (window flush with the right/bottom edge) is exactly size - sample_size.
    x_offset = random.randint(0, width - sample_width)
    y_offset = random.randint(0, height - sample_height)

    rows = slice(y_offset, y_offset + sample_height)
    cols = slice(x_offset, x_offset + sample_width)
    return image[rows, cols], mask[rows, cols]


def GenerateImage(image, mask, sample_width, sample_height):
    """Draw one randomly placed, randomly rotated crop from an image/mask pair.

    Roughly 70% of the calls take a plain random crop. The remaining calls
    re-draw the crop until class 2 (as measured by EstimateClassPresence)
    covers more than 20% of it, giving up after 102 attempts and keeping the
    last crop drawn. The crop is then rotated by 0, 90, 180 or 270 degrees;
    image and mask always receive the same window and the same rotation.

    Args:
        image: numpy array with rows on axis 0 and columns on axis 1; an
            optional trailing channel axis is allowed.
        mask: numpy array of class labels indexed the same way as `image`.
        sample_width: crop width in pixels.
        sample_height: crop height in pixels.

    Returns:
        [sample, sample_mask]. Both may be views into the inputs. For a
        non-square crop a 90 or 270 degree rotation swaps width and height.

    Raises:
        ValueError: if the requested crop does not fit inside `image`.
    """
    # shape[:2] works for both colour (H, W, C) and single-channel (H, W) images.
    height, width = image.shape[:2]
    if sample_width > width or sample_height > height:
        raise ValueError(
            'Sample size {}x{} does not fit into image of size {}x{}'.format(
                sample_width, sample_height, width, height))

    if random.randint(0, 100) > 30:
        sample, sample_mask = _RandomCrop(image, mask, sample_width, sample_height)
    else:
        # Bounded retry so an image without enough class-2 pixels cannot hang
        # the generator; the last crop drawn is used when no attempt qualifies.
        for _ in range(102):
            sample, sample_mask = _RandomCrop(image, mask, sample_width, sample_height)
            if EstimateClassPresence(sample_mask, 2) > 0.2:
                break

    angle = random.choice([0, 90, 180, 270])
    quarter_turns = angle // 90
    sample = np.rot90(sample, quarter_turns)
    sample_mask = np.rot90(sample_mask, quarter_turns)

    return [sample, sample_mask]
    
def GenerateImageFile(image_fn, mask_fn, sample_image_fn, sample_mask_fn, sample_width, sample_height):
    """Read an image/mask pair, cut one random sample from it and save it.

    Args:
        image_fn: path of the source image.
        mask_fn: path of the source mask; it is read as a single-channel image.
        sample_image_fn: path the image sample is written to.
        sample_mask_fn: path the mask sample is written to.
        sample_width: sample width in pixels.
        sample_height: sample height in pixels.

    Raises:
        IOError: if a source file cannot be read or a sample cannot be written.
        ValueError: if image and mask differ in size, or the sample does not
            fit into the image.
    """
    # cv2.imread() reports failure by returning None rather than raising, so
    # check explicitly to name the offending file.
    image = cv2.imread(image_fn)
    if image is None:
        raise IOError('Cannot read image file: {}'.format(image_fn))

    mask = cv2.imread(mask_fn, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise IOError('Cannot read mask file: {}'.format(mask_fn))

    # The same window is cut from both arrays, so they must cover the same area.
    if mask.shape[:2] != image.shape[:2]:
        raise ValueError('Image {} and mask {} differ in size: {} vs {}'.format(
            image_fn, mask_fn, image.shape[:2], mask.shape[:2]))

    [sample, sample_mask] = GenerateImage(image, mask, sample_width, sample_height)

    # cv2.imwrite() returns False when the file could not be written.
    if not cv2.imwrite(sample_image_fn, sample):
        raise IOError('Cannot write sample image: {}'.format(sample_image_fn))
    if not cv2.imwrite(sample_mask_fn, sample_mask):
        raise IOError('Cannot write sample mask: {}'.format(sample_mask_fn))
    
def GenerateImageFiles(used_files, target_dir, required_images_cnt, sample_width, sample_height):
    """Write `required_images_cnt` random samples into `target_dir`.

    For every sample one [image, mask] pair is picked at random from
    `used_files` and handed to GenerateImageFile. Outputs are numbered from 0:
    'sample_<n>.png' and 'sample_<n>_mask.png'.

    Returns:
        A list of [sample_path, sample_mask_path] pairs in creation order.
    """
    produced = []
    for sample_no in tqdm(range(required_images_cnt)):
        # Pick the source pair with replacement; one source can yield many samples.
        source_image_fn, source_mask_fn = used_files[random.randint(0, len(used_files) - 1)]

        output_pair = [
            os.path.join(target_dir, 'sample_{}.png'.format(sample_no)),
            os.path.join(target_dir, 'sample_{}_mask.png'.format(sample_no)),
        ]

        GenerateImageFile(source_image_fn, source_mask_fn,
                          output_pair[0], output_pair[1],
                          sample_width, sample_height)

        produced.append(output_pair)
    return produced

# Start from empty output directories: anything left over from a previous run
# is deleted before the directory is recreated.
for output_dir in (training_data_path, validation_data_path):
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

# Generate square samples whose edge length comes from params.GetImageSize().
# NOTE(review): both sets are drawn from the same source images (used_imgs),
# so validation samples can overlap training samples - confirm this is intended.
training_samples = GenerateImageFiles(
    used_imgs,
    training_data_path,
    number_of_training_samples,
    params.GetImageSize(),
    params.GetImageSize(),
)
validation_samples = GenerateImageFiles(
    used_imgs,
    validation_data_path,
    number_of_validation_samples,
    params.GetImageSize(),
    params.GetImageSize(),
)
    


HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




In [9]:
# Display the value passed as both sample width and sample height to GenerateImageFiles.
params.GetImageSize()

256