In [8]:
# Imports for the train test val split
import os
import shutil
import random
from tqdm import tqdm

In [1]:
def create_directory_structure(output_dir):
    """
    Build the train/val/test folder tree, each split holding a grayscale and a color folder.

    Args:
    output_dir (str): Base directory under which the split directories are created.

    Resulting layout:
    output_dir/
        train/
            grayscale/
            color/
        val/
            grayscale/
            color/
        test/
            grayscale/
            color/

    Directories that already exist are left as they are (exist_ok=True).
    """
    for split_name in ('train', 'val', 'test'):
        for modality in ('grayscale', 'color'):
            # makedirs also creates output_dir and the split folder when missing.
            os.makedirs(os.path.join(output_dir, split_name, modality), exist_ok=True)

In [2]:
def get_all_files(root_dir):
    """
    Recursively gets all files from the nested directory structure, in sorted order.

    Args:
    root_dir (str): Root directory to search for files.

    Returns:
    List of file paths (each is root_dir joined with the file's relative location),
    sorted lexicographically. An empty list is returned when nothing is found.
    """
    all_files = []
    for subdir, _, files in os.walk(root_dir):
        all_files.extend(os.path.join(subdir, name) for name in files)

    # os.walk yields entries in arbitrary, filesystem-dependent order. Sorting makes
    # the result deterministic, so two trees with identical relative layouts produce
    # lists whose entries line up index-by-index.
    all_files.sort()
    return all_files

In [3]:
def handle_collision(filepath):
    """
    Pick a destination path that is not already taken on disk.

    Args:
    filepath (str): The desired file path.

    Returns:
    str: `filepath` itself when nothing exists there; otherwise the first of
    `<base>_1<ext>`, `<base>_2<ext>`, ... that does not exist yet.
    """
    # Fast path: the requested name is free.
    if not os.path.exists(filepath):
        return filepath

    stem, ext = os.path.splitext(filepath)
    suffix = 1
    while True:
        candidate = f"{stem}_{suffix}{ext}"
        if not os.path.exists(candidate):
            return candidate
        suffix += 1

In [45]:
def _label_from_path(path):
    """
    Return the second component of a source path, which is used as the label.

    The path is normalized first, so a leading "./", doubled separators and
    Windows separators do not shift which component is picked.

    NOTE(review): this still assumes paths shaped like <root>/<label>/...,
    where <root> is a single relative directory name. Absolute or
    multi-component roots yield a different component — confirm against the
    dataset layout before using such roots.
    """
    return os.path.normpath(path).split(os.sep)[1]


def move_files(indices, grayscale_images, color_images, output_dir, split):
    """
    Copy paired files into the specified directory split and handle filename collisions.

    Despite the name, the source files are copied (shutil.copy2), not moved, so
    the source directories are left untouched.

    Args:
    indices (list): List of indices for the files to be copied.
    grayscale_images (list): List of grayscale image file paths.
    color_images (list): List of color image file paths; entry i is treated as
        the partner of grayscale_images[i].
    output_dir (str): Base directory where the split data will be saved.
    split (str): The split type ('train', 'val', or 'test').

    Each destination file is named "<label>_<original filename>", where the
    label is the second component of the source path. If that name already
    exists, handle_collision appends a numeric suffix.
    """
    gray_out_dir = os.path.join(output_dir, split, 'grayscale')
    color_out_dir = os.path.join(output_dir, split, 'color')
    os.makedirs(gray_out_dir, exist_ok=True)
    os.makedirs(color_out_dir, exist_ok=True)

    for idx in tqdm(indices, desc=f"Moving {split} files"):
        gray_src = grayscale_images[idx]
        color_src = color_images[idx]

        # Add letter to beginning of filename for label identification
        gray_filename = _label_from_path(gray_src) + "_" + os.path.basename(gray_src)
        color_filename = _label_from_path(color_src) + "_" + os.path.basename(color_src)

        # Handle filename collisions (resolved independently for each modality)
        gray_dst = handle_collision(os.path.join(gray_out_dir, gray_filename))
        color_dst = handle_collision(os.path.join(color_out_dir, color_filename))

        shutil.copy2(gray_src, gray_dst)
        shutil.copy2(color_src, color_dst)

In [46]:
def split_data(grayscale_dir, color_dir, output_dir, train_ratio=0.7, val_ratio=0.15, seed=None):
    """
    Splits the dataset into train, validation, and test sets and copies the images to the corresponding directories.

    Args:
    grayscale_dir (str): Directory containing grayscale images.
    color_dir (str): Directory containing color images.
    output_dir (str): Base directory where the split data will be saved.
    train_ratio (float): Proportion of the data to be used for training. Default is 0.7.
    val_ratio (float): Proportion of the data to be used for validation. Default is 0.15.
    seed (int or None): When given, the shuffle uses a dedicated random.Random(seed)
        so the split is reproducible. When None (default), the global `random`
        state is used, as before.

    The remaining data will be used for testing.

    Raises:
    ValueError: If a ratio is outside [0, 1] or train_ratio + val_ratio exceeds 1.
    AssertionError: If the two directories contain a different number of files.

    Directory structure created:
    output_dir/
        train/
            grayscale/
            color/
        val/
            grayscale/
            color/
        test/
            grayscale/
            color/
    """
    # Reject ratios that would silently produce empty or overlapping-looking splits.
    # A small tolerance absorbs floating-point error in sums such as 0.7 + 0.3.
    if not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1):
        raise ValueError("train_ratio and val_ratio must each be between 0 and 1")
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    # Get list of all images. Files are paired by index, so both lists are sorted:
    # directory traversal order is otherwise arbitrary and could differ between
    # the two trees, pairing a grayscale image with the wrong color image.
    # NOTE(review): this relies on both trees having matching relative layouts
    # and file names — confirm for the datasets in use.
    grayscale_images = sorted(get_all_files(grayscale_dir))
    color_images = sorted(get_all_files(color_dir))

    # Ensure the lists are the same length
    assert len(grayscale_images) == len(color_images), "Mismatch in the number of grayscale and color images"

    # Shuffle the indices
    indices = list(range(len(grayscale_images)))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)

    # Split the indices: [0, train_end) train, [train_end, val_end) val, rest test
    train_end = int(train_ratio * len(indices))
    val_end = int((train_ratio + val_ratio) * len(indices))

    train_indices = indices[:train_end]
    val_indices = indices[train_end:val_end]
    test_indices = indices[val_end:]

    # Copy the files
    move_files(train_indices, grayscale_images, color_images, output_dir, 'train')
    move_files(val_indices, grayscale_images, color_images, output_dir, 'val')
    move_files(test_indices, grayscale_images, color_images, output_dir, 'test')

In [47]:
# Pre-create the split folders, then populate them with the paired images.
create_directory_structure(output_dir='newdata_dual_data')
split_data(
    grayscale_dir='newdata_preprocessed_50x50',
    color_dir='newdata_landmark_preprocessed',
    output_dir='newdata_dual_data',
)


Moving train files: 100%|█████████████| 133856/133856 [01:46<00:00, 1257.93it/s]
Moving val files: 100%|█████████████████| 28684/28684 [00:21<00:00, 1362.18it/s]
Moving test files: 100%|████████████████| 28684/28684 [00:19<00:00, 1461.83it/s]
