In [7]:
import os
import random
import cv2

In [8]:
# Tunable constants for the dataset preparation below.
SEED = 462
TRAIN_RATIO = 1  # fraction of each generated class written to train; 1 means no val set

# Seed the standard-library RNG that the shuffles in later cells draw from.
random.seed(SEED)

In [9]:
# Root of the dataset on disk; every image folder lives under <data_path>/image.
data_path = "data"
image_data_path = os.path.join(data_path, "image")

# Input folders read by the export loops: generated images and photographs.
generated_path, photograph_path = (
    os.path.join(image_data_path, folder) for folder in ("generated", "photograph")
)

# Output folders for the processed splits.
# No validation folder is defined because the current SVM implementation does
# not use a validation set (using one to select C would be an improvement).
# To add one:
# validation_path = os.path.join(image_data_path, "validation")
train_path, test_path = (
    os.path.join(image_data_path, folder) for folder in ("train", "test")
)

In [10]:
def enumerate_dataset(dataset_path):
    """Map each class sub-directory of ``dataset_path`` to the names it contains.

    ``os.listdir`` returns entries in arbitrary, filesystem-dependent order.
    Both the class names and the per-class listings are therefore sorted, so
    iterating the result visits classes in the same order on every machine.
    That matters to callers that shuffle each class in turn with a shared
    seeded RNG: the class order decides which shuffle each class receives.

    Args:
        dataset_path: Directory whose immediate sub-directories are classes.

    Returns:
        dict: ``{class_name: [entry_name, ...]}`` with keys inserted in sorted
        order and each list sorted. Entries of ``dataset_path`` that are not
        directories are ignored; entries inside a class directory are listed
        as-is (no file-type filtering).
    """
    return {
        cls: sorted(os.listdir(os.path.join(dataset_path, cls)))
        for cls in sorted(os.listdir(dataset_path))
        if os.path.isdir(os.path.join(dataset_path, cls))
    }

In [11]:
def read_and_resize_image(image_path, size=(512, 512)):
    """Load an image with OpenCV and resize it to ``size``.

    Args:
        image_path: Path handed to ``cv2.imread``.
        size: Target size handed to ``cv2.resize`` as its ``dsize`` argument,
            i.e. ``(width, height)``. Defaults to ``(512, 512)``, the size
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        The resized image, or ``None`` when ``cv2.imread`` returns ``None``
        (OpenCV's way of reporting a file it could not read).
    """
    if (img := cv2.imread(image_path)) is None:
        return None
    return cv2.resize(img, size)

In [12]:
# Export every class of generated images into the train folder as resized PNGs.
for cls, image_names in enumerate_dataset(generated_path).items():
    class_path = os.path.join(generated_path, cls)
    train_class_path = os.path.join(train_path, cls)

    # Sort first so the shuffle starts from a known order, then shuffle with
    # the module-level `random` generator.
    image_names.sort()
    random.shuffle(image_names)

    # Split point: the first `train_split` names go to train, the rest to validation.
    total_images = len(image_names)
    train_split = int(total_images * TRAIN_RATIO)
    train_images, validation_images = image_names[:train_split], image_names[train_split:]

    os.makedirs(train_class_path, exist_ok=True)
    for i, image_name in enumerate(train_images):
        image_path = os.path.join(class_path, image_name)
        image = read_and_resize_image(image_path)
        if image is None:
            # Unreadable file: skip it. Its index `i` is not reused, so the
            # output numbering can contain gaps.
            continue
        cv2.imwrite(os.path.join(train_class_path, f"{i:04d}.png"), image)

    # Validation export is disabled (no validation_path is defined). To enable it:
    # os.makedirs(os.path.join(validation_path, cls), exist_ok=True)
    # for i, image_name in enumerate(validation_images):
    #     image_path = os.path.join(class_path, image_name)
    #     image = read_and_resize_image(image_path)
    #     if image is None:
    #         continue
    #     cv2.imwrite(os.path.join(validation_path, cls, f"{i:04d}.png"), image)

In [13]:
# Export every class of photographs into the test folder as resized PNGs.
for cls, image_names in enumerate_dataset(photograph_path).items():
    class_path = os.path.join(photograph_path, cls)
    test_class_path = os.path.join(test_path, cls)

    # Sort first so the shuffle starts from a known order, then shuffle with
    # the module-level `random` generator.
    image_names.sort()
    random.shuffle(image_names)

    os.makedirs(test_class_path, exist_ok=True)
    for i, image_name in enumerate(image_names):
        image_path = os.path.join(class_path, image_name)
        image = read_and_resize_image(image_path)
        if image is None:
            # Unreadable file: skip it. Its index `i` is not reused, so the
            # output numbering can contain gaps.
            continue
        cv2.imwrite(os.path.join(test_class_path, f"{i:04d}.png"), image)