In [1]:
import csv
import os
import random
import shutil

from PIL import Image
from tqdm import tqdm

# Split Train and Test

In [2]:
def split(origin_image_path, origin_mask_path, seed=None):
    """Randomly copy image/mask pairs into ``./data/train`` and ``./data/test``.

    The file names found in ``origin_image_path`` are shuffled; the first
    ``num_data // 10 + 1`` of them go to the test subset and the rest go to the
    train subset. For each name, the file with that name is copied from both
    ``origin_image_path`` and ``origin_mask_path``.

    Args:
        origin_image_path: Directory containing the original images.
        origin_mask_path: Directory containing the original masks. Each mask
            is looked up under the same file name as its image.
        seed: Optional seed for a reproducible split. When ``None`` (default)
            the module-level ``random`` state is used, as before.

    Raises:
        AssertionError: If the two directories hold a different number of entries.
        FileExistsError: If any of the output directories already exists.
        FileNotFoundError: If a mask with the image's file name is missing.
    """
    # Sorting first makes a seeded shuffle reproducible, since os.listdir()
    # returns entries in arbitrary order.
    images = sorted(os.listdir(origin_image_path))
    masks = os.listdir(origin_mask_path)
    assert len(images) == len(masks)
    num_data = len(images)
    if seed is None:
        random.shuffle(images)
    else:
        random.Random(seed).shuffle(images)
    # Deliberately no exist_ok=True: failing when a previous split exists also
    # prevents mixing two different random splits in the same directories.
    # NOTE(review): confirm this fail-fast behaviour is what is wanted.
    os.makedirs("./data/train/image")
    os.makedirs("./data/train/mask")
    os.makedirs("./data/test/image")
    os.makedirs("./data/test/mask")
    for idx, image_name in enumerate(tqdm(images)):
        # NOTE(review): `<=` sends num_data // 10 + 1 files to the test subset
        # (e.g. 291 of 2900); kept as-is so existing splits do not change.
        subset = "test" if idx <= num_data // 10 else "train"
        # shutil.copy avoids building a shell command, so names containing
        # spaces or shell metacharacters are handled and failures raise
        # instead of being silently ignored.
        shutil.copy(
            os.path.join(origin_image_path, image_name),
            os.path.join(f"./data/{subset}/image", image_name),
        )
        shutil.copy(
            os.path.join(origin_mask_path, image_name),
            os.path.join(f"./data/{subset}/mask", image_name),
        )

In [3]:
# Copy the original image/mask pairs into the ./data/train and ./data/test folders.
split("./data/original/image/", "./data/original/mask/")

100%|██████████| 2900/2900 [00:06<00:00, 435.87it/s]


# Split Image

In [None]:
def split_images(input_dir, output_dir):
    """Split each image in ``input_dir`` into a left half and a mirrored right half.

    Every file whose name ends in .png/.jpg/.jpeg/.gif/.bmp (case-insensitive)
    is cut at ``width // 2``. The left part is saved as ``left_<name>`` and the
    right part, flipped horizontally, as ``right_<name>`` in ``output_dir``.

    Args:
        input_dir: Directory containing the images to split.
        output_dir: Directory that receives the halves; created if missing.
    """
    # exist_ok replaces the separate exists() check and its check-then-create race.
    os.makedirs(output_dir, exist_ok=True)
    image_files = [f for f in os.listdir(input_dir) if f.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp"))]
    for image_file in tqdm(image_files):
        input_path = os.path.join(input_dir, image_file)
        # The context manager closes the source file handle after each image,
        # so handles do not pile up when a directory holds many files.
        with Image.open(input_path) as original_image:
            width, height = original_image.size
            split_point = width // 2
            left_image = original_image.crop((0, 0, split_point, height))
            right_image = original_image.crop((split_point, 0, width, height))
            # Mirror the right half horizontally before saving.
            right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)
            left_image.save(os.path.join(output_dir, f"left_{image_file}"))
            right_image.save(os.path.join(output_dir, f"right_{image_file}"))

In [None]:
# split_images("./data/train/image", "./data/train/image_splited")
# split_images("./data/train/mask", "./data/train/mask_splited")
# split_images("./data/test/image", "./data/test/image_splited")
# split_images("./data/test/mask", "./data/test/mask_splited")
# split_images("./data/predict/image", "./data/predict/image_splited")

In [None]:
def split_images_overlap(input_dir, output_dir, overlap=30):
    """Split each image into two overlapping parts, mirroring the right one.

    With ``split_point = width // 2``, the left part covers columns
    ``[overlap, split_point + overlap)`` and the right part covers columns
    ``[split_point - overlap, width - overlap)``. The two parts therefore share
    a band of ``2 * overlap`` columns around the middle, and the outermost
    ``overlap`` columns on each side of the source image are not included.
    The right part is flipped horizontally before it is saved.

    Args:
        input_dir: Directory containing the images to split. Only files ending
            in .png/.jpg/.jpeg/.gif/.bmp (case-insensitive) are processed.
        output_dir: Directory that receives ``left_<name>`` and
            ``right_<name>``; created if missing.
        overlap: Number of pixels each crop window is shifted towards the
            centre. Defaults to 30, the previously hard-coded value.
    """
    # exist_ok replaces the separate exists() check and its check-then-create race.
    os.makedirs(output_dir, exist_ok=True)
    image_files = [f for f in os.listdir(input_dir) if f.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp"))]
    for image_file in tqdm(image_files):
        input_path = os.path.join(input_dir, image_file)
        # The context manager closes the source file handle after each image.
        with Image.open(input_path) as original_image:
            width, height = original_image.size
            split_point = width // 2
            left_image = original_image.crop((overlap, 0, split_point + overlap, height))
            right_image = original_image.crop((split_point - overlap, 0, width - overlap, height))
            # Mirror the right part horizontally before saving.
            right_image = right_image.transpose(Image.FLIP_LEFT_RIGHT)
            left_image.save(os.path.join(output_dir, f"left_{image_file}"))
            right_image.save(os.path.join(output_dir, f"right_{image_file}"))

In [None]:
# split_images_overlap("./data/train/image", "./data/train/image_splited")
# split_images_overlap("./data/train/mask", "./data/train/mask_splited")
# split_images_overlap("./data/test/image", "./data/test/image_splited")
# split_images_overlap("./data/test/mask", "./data/test/mask_splited")
# split_images_overlap("./data/predict/image", "./data/predict/image_splited")

# Filter Illegal Data

In [None]:
def is_binary_image_all_zeros(image_path):
    """Return True if every pixel of the image is 0 after conversion to mode "L".

    Args:
        image_path: Path of the image file to inspect.

    Returns:
        bool: True when the grayscale version of the image has no non-zero pixel.
    """
    with Image.open(image_path) as img:
        # getbbox() is the bounding box of the non-zero region and is None when
        # there is none; this avoids materialising every pixel as a Python list.
        return img.convert("L").getbbox() is None
    

def get_illegal_data(mask_dir):
    """Collect the file names of masks in ``mask_dir`` that are entirely zero.

    Args:
        mask_dir: Directory containing the mask images.

    Returns:
        list[str]: Names (not full paths) of the image files for which
        ``is_binary_image_all_zeros`` returns True.
    """
    illegal_data = []
    for filename in tqdm(os.listdir(mask_dir)):
        # Case-insensitive match, consistent with the filter used by the
        # split_images functions, so e.g. "MASK.PNG" is not silently skipped.
        if filename.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
            image_path = os.path.join(mask_dir, filename)
            if is_binary_image_all_zeros(image_path):
                illegal_data.append(filename)
    return illegal_data


def perform_delete(iliiegal_data, image_dir, mask_dir):
    """Delete each listed file from ``image_dir`` and then from ``mask_dir``.

    Args:
        iliiegal_data: Iterable of file names to remove.
        image_dir: Directory holding the images.
        mask_dir: Directory holding the masks stored under the same names.
    """
    for name in tqdm(iliiegal_data):
        # Image first, then mask, for every name.
        for directory in (image_dir, mask_dir):
            target = os.path.join(directory, name)
            os.remove(target)


def filter_illegal_data(image_dir, mask_dir):
    """Find all-zero masks and, after confirmation, delete them with their images.

    The names found by ``get_illegal_data(mask_dir)`` are printed. Deletion
    only happens when the user answers the prompt with Enter, "y" or "yes"
    (case-insensitive); any other answer aborts without deleting anything.

    Args:
        image_dir: Directory holding the images.
        mask_dir: Directory holding the masks to inspect.
    """
    illegal_data = get_illegal_data(mask_dir)
    print(f"Found {len(illegal_data)} illegal data:")
    print(illegal_data)
    if not illegal_data:
        # Nothing to delete, so there is nothing to confirm.
        return
    # The answer used to be discarded, so typing "n" still deleted the files.
    # Enter alone still confirms, as before.
    answer = input("Confirm? [Y/n] ").strip().lower()
    if answer not in ("", "y", "yes"):
        print("Aborted; no files were deleted.")
        return
    perform_delete(illegal_data, image_dir, mask_dir)

In [None]:
# filter_illegal_data("./data/train/image", "./data/train/mask")
# filter_illegal_data("./data/test/image", "./data/test/mask")

# Generate CSV

In [None]:
def generate_csv(type, image_dir, mask_dir = None):
    """Write ``./data/<type>.csv`` listing image paths and their mask paths.

    The file has a header row ``#, img, seg`` followed by one row per entry of
    ``image_dir``: a running index, the image path and the mask path. When
    ``mask_dir`` is None the mask column is filled with ``./placeholder.png``.

    Both directory listings are sorted before they are paired by position,
    because ``os.listdir`` returns entries in arbitrary order and two
    independent listings are not guaranteed to line up.
    NOTE(review): pairing is still positional; it assumes image and mask file
    names sort into the same order (e.g. identical names) - confirm.

    Args:
        type: Base name of the CSV file (e.g. "train"). The name shadows the
            built-in ``type`` but is kept so keyword callers keep working.
        image_dir: Directory containing the images.
        mask_dir: Directory containing the masks, or None for unlabeled data.

    Raises:
        AssertionError: If ``mask_dir`` is given and holds a different number
            of entries than ``image_dir``.
    """
    images = sorted(os.listdir(image_dir))

    if mask_dir is None:
        seg_paths = ["./placeholder.png"] * len(images)
    else:
        masks = sorted(os.listdir(mask_dir))
        assert len(images) == len(masks)
        seg_paths = [os.path.join(mask_dir, mask) for mask in masks]

    with open(f"./data/{type}.csv", "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["#", "img", "seg"])
        for i, (image, seg_path) in enumerate(zip(images, seg_paths)):
            writer.writerow([i, os.path.join(image_dir, image), seg_path])

In [None]:
# generate_csv("train", "./data/train/image", "./data/train/mask")
# generate_csv("test", "./data/test/image", "./data/test/mask")
# generate_csv("predict", "./data/predict/image")