In [10]:
import os
import random
from PIL import Image
import math
import cv2
import matplotlib.pyplot as plt
import numpy as np
from enum import Enum
# Load the helper functions this notebook relies on (pre_scale_images, boxes_overlap,
# autocrop_image — presumably all defined in util/generate.py) into the notebook namespace.
# `execfile` no longer exists in Python 3 (and this notebook needs Python 3 for its f-strings),
# so read, compile and exec the file directly; compiling with the real filename keeps
# tracebacks pointing at util/generate.py.
# NOTE(review): making util/ a package and importing names explicitly would be cleaner.
with open("util/generate.py") as _generate_source:
    exec(compile(_generate_source.read(), "util/generate.py", "exec"))

In [11]:
# Search budget for place_objects: how many nudges an object gets before it is skipped,
# and after how many failed nudges it also starts to shrink.
MAX_PLACE_TRIES = 1_000
SCALE_TRY_THRESHOLD = 100

# Shrinking limits: shrinking continues only while the object is wider than this fraction
# of its pre-jitter width, and a side is never pushed below this many pixels.
MIN_SIZE_RATIO = 0.6
MIN_OBJECT_SIZE = 30

class Color(Enum):
    """Named colours as (R, G, B) tuples with 0-255 channels."""

    # primaries
    RED = (0xFF, 0x00, 0x00)
    GREEN = (0x00, 0xFF, 0x00)
    BLUE = (0x00, 0x00, 0xFF)
    # secondaries
    YELLOW = (0xFF, 0xFF, 0x00)
    CYAN = (0x00, 0xFF, 0xFF)
    MAGENTA = (0xFF, 0x00, 0xFF)
    # extremes
    BLACK = (0x00, 0x00, 0x00)
    WHITE = (0xFF, 0xFF, 0xFF)


In [12]:
def _resize_scaled(image, scale_factor, max_width, max_height):
    """Return ``image`` resized by ``scale_factor``, capped so it fits in ``max_width`` x ``max_height``.

    Both output sides are at least 1 pixel so ``cv2.resize`` never receives an empty size.
    """
    height, width = image.shape[:2]
    # Never let an object outgrow the canvas: random placement and pasting would both fail.
    scale_factor = min(scale_factor, max_width / width, max_height / height)
    new_width = max(1, int(width * scale_factor))
    new_height = max(1, int(height * scale_factor))
    return cv2.resize(image, (new_width, new_height))  # cv2 dsize is (width, height)


def _random_step(extent):
    """Return a random non-zero pixel step of about 1-10% of ``extent``, in either direction."""
    percent = random.choice([p for p in range(-10, 11) if p != 0])
    step = int(percent * extent / 100)
    # Small extents can truncate to 0; keep the direction but move at least one pixel.
    return step if step != 0 else (1 if percent > 0 else -1)


def _step_within(pos, step, size, limit):
    """Advance ``pos`` by ``step``, bouncing so the span ``[pos, pos + size]`` stays inside ``[0, limit]``.

    Returns the new ``(pos, step)``; ``step`` is negated whenever an edge is hit.
    """
    max_pos = max(0, limit - size)
    pos += step
    if pos < 0 or pos > max_pos:
        step = -step
        pos = min(max(pos, 0), max_pos)
    return pos, step


def place_objects(background, object_collection, num_objects):
    """Paste up to ``num_objects`` randomly chosen, randomly scaled objects onto a copy of ``background``.

    Objects are drawn with replacement from ``object_collection`` after ``pre_scale_images``.
    Each starts at a random position. While it overlaps an already-placed object, it is nudged
    by a fixed random step that bounces off the image borders. After ``SCALE_TRY_THRESHOLD``
    failed tries it is also shrunk, as long as it is still wider than ``MIN_SIZE_RATIO`` of its
    pre-jitter width, and its smaller side is kept at ``MIN_OBJECT_SIZE`` pixels or more.
    Objects that find no free spot within ``MAX_PLACE_TRIES`` tries are skipped, so fewer than
    ``num_objects`` may appear.

    Args:
        background: image array indexed ``[row, col]`` (height first); not modified.
        object_collection: list of image arrays to sample from; not modified.
        num_objects: number of objects to attempt to place.

    Returns:
        A new image array with the objects pasted in (plain rectangular copy, no blending).

    Relies on ``pre_scale_images`` and ``boxes_overlap``, presumably defined in
    util/generate.py. ``boxes_overlap`` is assumed to take two ``(x, y, width, height)``
    tuples, and every object is assumed to have the same channel count as ``background``.
    TODO confirm.
    """
    # Work on copies so the caller's image and list are left untouched.
    background = background.copy()
    object_collection = object_collection.copy()

    object_collection = pre_scale_images(background, object_collection, num_objects)

    bg_height, bg_width = background.shape[:2]
    positions = []  # (x, y, width, height) of every object pasted so far

    for _ in range(num_objects):
        # Sampling with replacement: one source image may be pasted several times.
        base_image = random.choice(object_collection)
        base_height, base_width = base_image.shape[:2]

        # Two times out of three keep the size, otherwise jitter it by 0.7x-1.1x.
        scale_factor = random.choice((random.uniform(0.7, 1.1), 1, 1))
        object_image = _resize_scaled(base_image, scale_factor, bg_width, bg_height)
        object_height, object_width = object_image.shape[:2]

        x = random.randint(0, bg_width - object_width)
        y = random.randint(0, bg_height - object_height)

        # Horizontal steps scale with the width, vertical steps with the height.
        dx = _random_step(bg_width)
        dy = _random_step(bg_height)

        for attempt in range(MAX_PLACE_TRIES):
            box = (x, y, object_width, object_height)
            if not any(boxes_overlap(box, placed) for placed in positions):
                break

            if attempt >= SCALE_TRY_THRESHOLD and object_width / base_width > MIN_SIZE_RATIO:
                # Shrink the current size by up to 20% more, always resampling from the
                # untouched source so repeated shrinking does not compound interpolation blur.
                scale_factor = object_width / base_width * random.uniform(0.8, 1)
                smallest_side = min(base_width, base_height)
                if smallest_side * scale_factor < MIN_OBJECT_SIZE:
                    scale_factor = MIN_OBJECT_SIZE / smallest_side
                object_image = _resize_scaled(base_image, scale_factor, bg_width, bg_height)
                object_height, object_width = object_image.shape[:2]

            # Move both axes independently; each bounce also re-clamps after a resize.
            x, dx = _step_within(x, dx, object_width, bg_width)
            y, dy = _step_within(y, dy, object_height, bg_height)
        else:
            # No free spot found within MAX_PLACE_TRIES: skip this object.
            continue

        positions.append((x, y, object_width, object_height))
        background[y:y + object_height, x:x + object_width] = object_image

    return background

In [13]:
import os
import cv2
import seaborn as sns

def _next_image_number(directory, prefix='generated_image', suffix='.jpg'):
    """Return one more than the largest N among ``<prefix><N><suffix>`` files in ``directory`` (0 if none).

    Names whose middle part is not a plain decimal integer are ignored rather than crashing ``int()``.
    """
    numbers = []
    for name in os.listdir(directory):
        if name.startswith(prefix) and name.endswith(suffix):
            middle = name[len(prefix):len(name) - len(suffix)]
            if middle.isdecimal():
                numbers.append(int(middle))
    return max(numbers, default=-1) + 1


def generate_images(num_images, image_size, num_objects, color_pallette=None):
    """Generate ``num_images`` synthetic images and save them as JPEGs in ``generated_images/``.

    Each image is a white canvas with auto-cropped objects from ``./collection`` (``.jpg`` /
    ``.png`` files) pasted on by ``place_objects``. Files are named ``generated_image<N>.jpg``,
    with numbering that continues after the highest ``N`` already present, so earlier output is
    not overwritten.

    Args:
        num_images: number of images to write.
        image_size: ``(height, width)`` of the canvas, or ``None`` for a random square side in [500, 1000].
        num_objects: objects to attempt to place per image (forwarded to ``place_objects``).
        color_pallette: reserved for the disabled compound-shape augmentation; currently unused.

    Raises:
        FileNotFoundError: if ``./collection`` does not exist.
        ValueError: if ``./collection`` holds no readable .jpg/.png image.
        OSError: if an output image cannot be written.
    """
    collection_directory = './collection'
    output_directory = 'generated_images'

    # Load every decodable .jpg/.png (case-insensitive). cv2.imread returns None for files it
    # cannot read, so those are dropped instead of crashing autocrop_image.
    collection = []
    for filename in sorted(os.listdir(collection_directory)):
        if not filename.lower().endswith(('.jpg', '.png')):
            continue
        image = cv2.imread(os.path.join(collection_directory, filename))
        if image is not None:
            collection.append(autocrop_image(image, background_color=[255, 255, 255], margin=10))
    if not collection:
        raise ValueError(f"no readable .jpg/.png images found in {collection_directory!r}")

    # TODO: color_pallette is only needed by the disabled compound-shape augmentation
    # (create_compound_shape); wire it in again when that feature is re-enabled.

    if image_size is None:
        # A square canvas; image_size must be indexable as (height, width) below.
        side = random.randint(500, 1000)
        image_size = (side, side)
    # White canvas; place_objects copies it, so the same array is reused for every image.
    background = np.full((image_size[0], image_size[1], 3), 255, dtype=np.uint8)

    os.makedirs(output_directory, exist_ok=True)
    first_number = _next_image_number(output_directory)

    for offset in range(num_images):
        image = place_objects(background, collection, num_objects)
        path = f'generated_images/generated_image{first_number + offset}.jpg'
        # cv2.imwrite reports failure through its return value rather than raising.
        if not cv2.imwrite(path, image):
            raise OSError(f"cv2.imwrite failed to write {path!r}")

In [16]:
# Render one 1000 x 1000 image with up to 18 objects placed on it.
generate_images(num_images=1, image_size=(1000, 1000), num_objects=18)