In [1]:
import os
import random as rnd
from typing import List, Tuple

import cv2
import numpy as np
import torch
from tqdm import tqdm

In [2]:
BASE_DIR = "../data/processed"  # source root; expected to contain input/ and target/ sub-folders
OUTPUT_DIR = "../data/splits"  # destination root for the train/, eval/ and test/ split folders
TRAIN_PCT = 0.7  # fraction of image pairs assigned to the training split
EVAL_PCT = 0.15  # fraction of image pairs assigned to the evaluation split
TEST_PCT = 0.15  # fraction for the test split (which in practice receives the remainder)

# Fail fast if the fractions are edited inconsistently: together they must cover the whole dataset.
if abs(TRAIN_PCT + EVAL_PCT + TEST_PCT - 1.0) > 1e-9:
    raise ValueError(
        f"TRAIN_PCT + EVAL_PCT + TEST_PCT must equal 1.0, got {TRAIN_PCT + EVAL_PCT + TEST_PCT}"
    )

In [3]:
def split_image(image: np.ndarray, patch_size=256):
    """
    Split an image into non-overlapping square patches of a specified size.

    The image is zero-padded on the bottom and right so that its height and
    width become multiples of ``patch_size``. Every returned patch therefore
    has exactly ``patch_size`` rows and columns.

    Parameters:
    - image: Input image (numpy array). Axis 0 is treated as height and axis 1
      as width; any further axes (e.g. colour channels) are carried through.
    - patch_size: Side length of each square patch in pixels (positive int).

    Returns:
    - List of patches (list of numpy arrays) in row-major order: left to right
      within a row of patches, then top to bottom. Patches are views into a
      padded copy, so ``image`` itself is never modified.

    Raises:
    - ValueError: if ``image`` is None (e.g. ``cv2.imread`` could not read the
      file) or ``patch_size`` is not positive.
    """
    if image is None:
        raise ValueError("image is None; the source file could not be read")
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size!r}")

    h, w = image.shape[:2]
    # Round each spatial dimension up to the next multiple of patch_size.
    new_h = ((h + patch_size - 1) // patch_size) * patch_size
    new_w = ((w + patch_size - 1) // patch_size) * patch_size

    # Zero-pad only the bottom/right edges; trailing axes (channels) are left as-is.
    # np.pad works for any dtype and number of axes without OpenCV's Mat conversion.
    pad_width = [(0, new_h - h), (0, new_w - w)] + [(0, 0)] * (image.ndim - 2)
    padded = np.pad(image, pad_width, mode="constant", constant_values=0)

    # After padding every slice is a full patch, so no per-patch size check is needed.
    return [
        padded[i : i + patch_size, j : j + patch_size]
        for i in range(0, new_h, patch_size)
        for j in range(0, new_w, patch_size)
    ]

In [4]:
def get_image_pairs(path: str, extensions: Tuple[str, ...] = (".jpg",)) -> List[str]:
    """
    List the image files directly inside a directory.

    Despite its name, this returns individual file paths rather than
    (input, target) pairs; the caller is responsible for pairing them.

    Parameters:
    - path: Directory to scan (not recursive).
    - extensions: Accepted file-name suffix(es), matched case-sensitively.
      A single string is also accepted. Defaults to ``(".jpg",)``.

    Returns:
    - Sorted list of full paths (``os.path.join(path, name)``) for regular files
      whose names end with one of ``extensions``. Sorting makes the result
      independent of the arbitrary ordering of ``os.listdir``.
    """
    # str.endswith only accepts a str or a tuple, so normalise other iterables.
    suffixes = extensions if isinstance(extensions, str) else tuple(extensions)
    return sorted(
        os.path.join(path, file_name)
        for file_name in os.listdir(path)
        # Skip directories that merely happen to have an image-like name.
        if file_name.endswith(suffixes) and os.path.isfile(os.path.join(path, file_name))
    )

In [5]:
def split_images(
    image_path_pairs: List[Tuple[str, str]],
    save_dir: str,
    patch_size: int = 256,
    dark_threshold: int = 15,
    max_dark_fraction: float = 0.6,
):
    """
    Cut each (input, target) image pair into aligned patches and save them as JPEGs.

    Patches go to ``OUTPUT_DIR/save_dir/input/<k>.jpg`` and
    ``OUTPUT_DIR/save_dir/target/<k>.jpg`` with one shared running index ``k``,
    so the k-th input patch corresponds to the k-th target patch.

    A patch pair is skipped when more than ``max_dark_fraction`` of the target
    patch's values are below ``dark_threshold`` (e.g. mostly black patches or
    the zero-padded border patches produced by ``split_image``).

    Parameters:
    - image_path_pairs: Sequence of ``(input_path, target_path)`` file paths.
    - save_dir: Name of the split sub-folder inside ``OUTPUT_DIR`` (e.g. "train").
    - patch_size: Side length of each square patch in pixels.
    - dark_threshold: Values strictly below this count as "dark".
    - max_dark_fraction: Maximum tolerated fraction of dark values in a target patch.

    Returns:
    - Number of patch pairs written.

    Raises:
    - FileNotFoundError: if either image of a pair cannot be read.
    - ValueError: if an input image and its target differ in height/width,
      since their patch grids would not line up.
    - OSError: if a patch cannot be written.
    """
    input_dir = os.path.join(OUTPUT_DIR, save_dir, "input")
    target_dir = os.path.join(OUTPUT_DIR, save_dir, "target")
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    image_index = 0
    for input_path, target_path in tqdm(image_path_pairs, desc=f"Packing images for {save_dir}"):
        input_image = cv2.imread(input_path)
        target_image = cv2.imread(target_path)
        # cv2.imread signals failure by returning None rather than raising.
        for path, image in ((input_path, input_image), (target_path, target_image)):
            if image is None:
                raise FileNotFoundError(f"Could not read image: {path}")
        # Different sizes would give different patch grids, and zip() below would
        # then silently pair unrelated regions.
        if input_image.shape[:2] != target_image.shape[:2]:
            raise ValueError(
                f"Size mismatch between {input_path} {input_image.shape[:2]} "
                f"and {target_path} {target_image.shape[:2]}"
            )

        inputs = split_image(input_image, patch_size)
        targets = split_image(target_image, patch_size)
        for ip, tp in zip(inputs, targets):
            if (np.count_nonzero(tp < dark_threshold) / tp.size) > max_dark_fraction:
                continue
            file_name = f"{image_index}.jpg"
            wrote_input = cv2.imwrite(os.path.join(input_dir, file_name), ip)
            wrote_target = cv2.imwrite(os.path.join(target_dir, file_name), tp)
            # cv2.imwrite reports many failures via a False return value.
            if not (wrote_input and wrote_target):
                raise OSError(f"Failed to write patch {file_name} under {os.path.join(OUTPUT_DIR, save_dir)}")
            image_index += 1

    return image_index


def save_test_images(image_path_pairs: List[Tuple[str, str]], save_dir: str):
    """
    Save whole (input, target) image pairs, without patching, into a split folder.

    Each pair is decoded with ``cv2.imread`` and re-written with ``cv2.imwrite`` to
    ``OUTPUT_DIR/save_dir/input/<k>.jpg`` and ``OUTPUT_DIR/save_dir/target/<k>.jpg``,
    where ``k`` is the pair's position in ``image_path_pairs``.

    Parameters:
    - image_path_pairs: Sequence of ``(input_path, target_path)`` file paths.
    - save_dir: Name of the split sub-folder inside ``OUTPUT_DIR`` (e.g. "test").

    Returns:
    - Number of image pairs written.

    Raises:
    - FileNotFoundError: if either image of a pair cannot be read.
    - OSError: if either image cannot be written.
    """
    input_dir = os.path.join(OUTPUT_DIR, save_dir, "input")
    target_dir = os.path.join(OUTPUT_DIR, save_dir, "target")
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(target_dir, exist_ok=True)

    image_index = 0
    for input_path, target_path in tqdm(image_path_pairs, desc=f"Saving images for {save_dir}"):
        input_image = cv2.imread(input_path)
        target_image = cv2.imread(target_path)
        # cv2.imread signals failure by returning None rather than raising.
        for path, image in ((input_path, input_image), (target_path, target_image)):
            if image is None:
                raise FileNotFoundError(f"Could not read image: {path}")

        # NOTE(review): re-encoding to JPEG is lossy; if test images must be
        # byte-identical to the sources, consider copying the files instead.
        file_name = f"{image_index}.jpg"
        wrote_input = cv2.imwrite(os.path.join(input_dir, file_name), input_image)
        wrote_target = cv2.imwrite(os.path.join(target_dir, file_name), target_image)
        if not (wrote_input and wrote_target):
            raise OSError(f"Failed to write image {file_name} under {os.path.join(OUTPUT_DIR, save_dir)}")
        image_index += 1

    return image_index

In [6]:
# Fixed seed so the train/eval/test assignment is reproducible on every run.
SPLIT_SEED = 42

input_images = sorted(get_image_pairs(os.path.join(BASE_DIR, "input")))
target_images = sorted(get_image_pairs(os.path.join(BASE_DIR, "target")))

# Inputs and targets are paired by sorted position, so both folders must hold the
# same number of files whose names sort in the same order. A count mismatch would
# otherwise be silently truncated by zip() and shift every subsequent pair.
# NOTE(review): equal counts do not prove matching names; confirm the naming convention.
if not input_images:
    raise FileNotFoundError(f"No .jpg images found in {os.path.join(BASE_DIR, 'input')}")
if len(input_images) != len(target_images):
    raise ValueError(
        f"Found {len(input_images)} input images but {len(target_images)} target images; "
        "every input needs exactly one target."
    )

image_path_pairs = list(zip(input_images, target_images))
# A private RNG keeps the split independent of any other use of the global `random` state.
rnd.Random(SPLIT_SEED).shuffle(image_path_pairs)

train_size = int(TRAIN_PCT * len(image_path_pairs))
eval_size = int(EVAL_PCT * len(image_path_pairs))

train_paths = image_path_pairs[:train_size]
eval_paths = image_path_pairs[train_size : train_size + eval_size]
# The test split takes the remainder, so int() rounding never drops a pair.
test_paths = image_path_pairs[train_size + eval_size :]

split_images(eval_paths, "eval")
split_images(train_paths, "train")
save_test_images(test_paths, "test")

Packing images for eval: 100%|██████████| 185/185 [07:57<00:00,  2.58s/it]
Packing images for train: 100%|██████████| 866/866 [29:32<00:00,  2.05s/it]  
Saving images for test: 100%|██████████| 187/187 [01:59<00:00,  1.56it/s]
