In [None]:
from Padd import ObjectAdder
from PIL import Image
import torch
import numpy as np
import os
import json
import random
from PIL import Image
from tqdm import tqdm
from typing import List, Tuple, Dict, Any

from PATHS.PaddGaragePaths import IMAGES_DATASET_PATH, GENRATED_OBJECTS

%load_ext autoreload
%autoreload 2

### Entityseg

In [None]:
# Both locations come from the project's path configuration module.
generated_objects_directory = GENRATED_OBJECTS
scene_directory = IMAGES_DATASET_PATH

# List the scene directory once and sort in place so that positional
# slices of scene_files refer to the same entries on every run.
scene_files = os.listdir(scene_directory)
scene_files.sort()

# Cell output: number of entries found in the scene directory.
len(scene_files)

11580

### Necessary functions

In [10]:
def CombineImagesHorizontally(*images):
    """Paste the given images side by side onto one new RGB canvas.

    The canvas is as wide as the sum of all image widths and as tall as the
    tallest image. Images are placed left to right in argument order, each
    with its top edge at y = 0; canvas areas below shorter images are left
    at the canvas's initial fill.

    Args:
        *images: Objects exposing ``width`` and ``height`` that can be pasted
            onto a PIL image.

    Returns:
        The newly created combined image.

    Raises:
        ValueError: If called without any image.
    """
    if len(images) == 0:
        raise ValueError("No images provided")

    widths = [picture.width for picture in images]
    canvas_size = (sum(widths), max(picture.height for picture in images))
    canvas = Image.new('RGB', canvas_size)

    # Each image starts where the previous one ended.
    left_edge = 0
    for picture, width in zip(images, widths):
        canvas.paste(picture, (left_edge, 0))
        left_edge += width

    return canvas

In [11]:
class CustomDataset:
    """Index of generated object images, grouped into one directory per class.

    Layout read from ``root_dir``::

        <root_dir>/<class_name>/prompts.json
        <root_dir>/<class_name>/images/<prompt_key>/object_raw_image.jpg
        <root_dir>/<class_name>/images/<prompt_key>/mask.jpg

    ``prompts.json`` is used as a mapping whose keys name the per-object
    sub-directories of ``images`` and whose values are the prompts.

    Attributes:
        root_dir: Directory that contains the class sub-directories.
        class_names: Sorted names of the sub-directories of ``root_dir``.
        data: Per class name, a dict with ``'prompts'`` (the parsed
            ``prompts.json``) and ``'images_dir'`` (path of that class's
            ``images`` folder).
    """

    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        self.class_names = self._get_class_names(root_dir)
        self.data = self._load_data()

    def _get_class_names(self, root_dir: str) -> List[str]:
        """Return the names of all sub-directories of ``root_dir``, sorted.

        ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        candidate order (and therefore seeded batches) stable across runs
        and machines. Plain files in ``root_dir`` are ignored.
        """
        return sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

    def _load_data(self) -> Dict[str, Dict[str, Any]]:
        """Load ``prompts.json`` for every class and record its images folder.

        Raises:
            FileNotFoundError: If a class directory has no ``prompts.json``.
        """
        data = {}
        for class_name in self.class_names:
            class_dir = os.path.join(self.root_dir, class_name)
            prompts_path = os.path.join(class_dir, 'prompts.json')
            # Explicit encoding so parsing does not depend on the locale.
            with open(prompts_path, 'r', encoding='utf-8') as f:
                prompts = json.load(f)

            data[class_name] = {
                'prompts': prompts,
                'images_dir': os.path.join(class_dir, 'images'),
            }
        return data

    def _open_item(self, class_name: str, prompt_key: str) -> Tuple["Image.Image", "Image.Image"]:
        """Open the raw object image and its mask for one prompt key."""
        item_dir = os.path.join(self.data[class_name]['images_dir'], prompt_key)
        image = Image.open(os.path.join(item_dir, 'object_raw_image.jpg'))
        mask = Image.open(os.path.join(item_dir, 'mask.jpg'))
        return image, mask

    def get_batch(self, n: int, seed: int = None) -> Tuple[List["Image.Image"], List["Image.Image"], List[str]]:
        """Draw up to ``n`` distinct objects at random from all classes.

        Args:
            n: Maximum number of objects to return. Fewer are returned when
                the dataset holds fewer than ``n`` prompts.
            seed: When given, the draw is reproducible for that seed and the
                process-wide ``random`` state is left untouched. When
                ``None``, the global ``random`` module is used.

        Returns:
            Three parallel lists: object images, masks and prompts.
        """
        # A private generator yields the same shuffle as seeding the global
        # module would, without resetting randomness for every other caller.
        rng = random.Random(seed) if seed is not None else random

        candidates = [
            (class_name, prompt_key, prompt)
            for class_name in self.class_names
            for prompt_key, prompt in self.data[class_name]['prompts'].items()
        ]
        rng.shuffle(candidates)

        images = []
        masks = []
        prompts = []
        for _ in range(n):
            if not candidates:
                break
            class_name, prompt_key, prompt = candidates.pop()
            image, mask = self._open_item(class_name, prompt_key)
            images.append(image)
            masks.append(mask)
            prompts.append(prompt)

        return images, masks, prompts

    def get_item_by_class_and_index(self, class_name: str, index: int) -> Tuple["Image.Image", "Image.Image", str]:
        """Return the image, mask and prompt at position ``index`` of a class.

        Positions follow the key order of that class's ``prompts.json``.

        Raises:
            KeyError: If ``class_name`` is not a known class.
            IndexError: If ``index`` is out of range for that class.
        """
        class_prompts = self.data[class_name]['prompts']
        prompt_key = list(class_prompts.keys())[index]
        image, mask = self._open_item(class_name, prompt_key)
        return image, mask, class_prompts[prompt_key]


### Augmentation

In [12]:
# Build the object index: scans generated_objects_directory for class folders
# and loads each class's prompts.json (see CustomDataset).
dataset = CustomDataset(generated_objects_directory)

In [None]:
# Instantiate the ObjectAdder imported from Padd. It is called in the
# augmentation loop with a scene, object images, masks, prompts and a seed.
model = ObjectAdder()

In [None]:
# Augment the scenes at positions [first_scene, last_scene) of scene_files.
first_scene = 800
last_scene = 1000
images_out_dir = "augmentations/images/"
combined_out_dir = "augmentations/combine_images/"

# Saving fails when the target folders are missing, so create them up front.
os.makedirs(images_out_dir, exist_ok=True)
os.makedirs(combined_out_dir, exist_ok=True)

# `num` is the scene's position in scene_files and doubles as the seed.
# Tying it to the position (instead of advancing it only after a successful
# save) gives every scene its own object batch, so one failing batch is not
# retried on all following scenes.
num = first_scene
for num, file in enumerate(tqdm(scene_files[first_scene:last_scene]), start=first_scene):
    try:
        # Opened inside the try block so an unreadable scene file skips this
        # scene instead of aborting the whole run.
        scene = Image.open(os.path.join(scene_directory, file))
        batch = dataset.get_batch(n=5, seed=num)
        new_images, controlnet_images = model(scene, batch[0], batch[1], batch[2], seed=num)
        if not new_images:
            # The model produced nothing for this scene.
            continue
        combine_image = CombineImagesHorizontally(*new_images, *controlnet_images)
        # Only the last returned image is stored as the augmented scene
        # (presumably the scene after all objects were added - TODO confirm).
        new_image = new_images[-1]
        new_image.save(os.path.join(images_out_dir, file))
        combine_image.save(os.path.join(combined_out_dir, file))
    except Exception as e:
        # Best effort: report what failed and move on to the next scene.
        print(f"Error on {file} (seed={num}): {e!r}; start new iteration")

In [3]:
# Count the entries currently present in the augmented-image output folder.
len(os.listdir("augmentations/images/"))

905