In [1]:
# Imports for the train test val split
import os
import shutil
import random
from tqdm import tqdm

In [2]:
def create_directory_structure(output_dir):
    """
    Build the empty folder tree that receives the split dataset.

    Each of the three splits gets one folder per image variant:

    output_dir/
        train/
            grayscale/
            color/
        val/
            grayscale/
            color/
        test/
            grayscale/
            color/

    Folders that already exist are left as they are, so calling this
    more than once is harmless.
    """
    targets = [
        os.path.join(output_dir, split_name, variant)
        for split_name in ('train', 'val', 'test')
        for variant in ('grayscale', 'color')
    ]

    for target in targets:
        # exist_ok=True: an already-present folder is not an error.
        os.makedirs(target, exist_ok=True)

In [3]:
def get_all_files(root_dir):
    """
    Recursively collect the paths of every file below ``root_dir``.

    Args:
        root_dir: Directory to walk. Each returned path is the walked
            directory joined with the file name, so it starts with
            ``root_dir`` exactly as it was passed in.

    Returns:
        A list of file paths sorted lexicographically. The list is empty
        when no files are found.
    """
    all_files = [
        os.path.join(subdir, file)
        for subdir, _, files in os.walk(root_dir)
        for file in files
    ]
    # os.walk reports entries in whatever order the filesystem returns them,
    # which is not guaranteed to be stable or to match between two directory
    # trees. Sorting gives a deterministic order that index-based pairing of
    # two listings can depend on.
    all_files.sort()
    return all_files

In [4]:
def handle_collision(filepath):
    """
    Return a path that does not exist yet, derived from ``filepath``.

    If nothing exists at ``filepath`` it is returned unchanged. Otherwise a
    numeric suffix is inserted before the extension (``name_1.ext``,
    ``name_2.ext``, ...) and the first candidate that does not exist on disk
    is returned. Nothing is created on disk; the path is only checked.
    """
    if not os.path.exists(filepath):
        return filepath

    stem, ext = os.path.splitext(filepath)
    attempt = 1
    while True:
        candidate = f"{stem}_{attempt}{ext}"
        if not os.path.exists(candidate):
            return candidate
        attempt += 1

In [5]:
def move_files(indices, grayscale_images, color_images, output_dir, split):
    """
    Copy the selected grayscale/color image pairs into one split folder.

    Despite the name, the source files are copied with ``shutil.copy2``
    and are left in place.

    Each destination file name is ``<label>_<original file name>``, where
    ``<label>`` is the name of the directory that directly contains the
    source file. This assumes images are stored as
    ``<root>/<label>/<file>`` -- TODO confirm against the dataset layout.
    If the resulting name is already taken in the destination folder,
    ``handle_collision`` picks a suffixed alternative.

    Args:
        indices: Positions into ``grayscale_images`` / ``color_images`` that
            belong to this split.
        grayscale_images: List of grayscale source paths.
        color_images: List of color source paths; entry ``i`` is treated as
            the partner of ``grayscale_images[i]``.
        output_dir: Root of the split dataset.
        split: Name of the split folder under ``output_dir``
            (e.g. ``'train'``).
    """
    gray_dir = os.path.join(output_dir, split, 'grayscale')
    color_dir = os.path.join(output_dir, split, 'color')
    os.makedirs(gray_dir, exist_ok=True)
    os.makedirs(color_dir, exist_ok=True)

    for idx in tqdm(indices, desc=f"Moving {split} files"):
        pair = ((grayscale_images[idx], gray_dir), (color_images[idx], color_dir))
        for src, dst_dir in pair:
            # Take the label from the parent directory name via os.path rather
            # than splitting on "/" at a fixed position: that only worked for a
            # single-component relative root with POSIX separators.
            label = os.path.basename(os.path.dirname(os.path.normpath(src)))
            dst_name = label + "_" + os.path.basename(src)

            # Never overwrite an existing file; pick a free name instead.
            dst = handle_collision(os.path.join(dst_dir, dst_name))
            shutil.copy2(src, dst)

In [6]:
def split_data(grayscale_dir, color_dir, output_dir, train_ratio=0.7, val_ratio=0.15, seed=None):
    """
    Randomly split paired grayscale/color images into train, val and test.

    Resulting layout:
    output_dir/
        train/
            grayscale/
            color/
        val/
            grayscale/
            color/
        test/
            grayscale/
            color/

    Both directories are listed recursively and sorted. The i-th grayscale
    file is then paired with the i-th color file, so the two trees are
    assumed to mirror each other (same relative paths) -- TODO confirm.
    Files are copied, not moved.

    Args:
        grayscale_dir: Root directory of the grayscale images.
        color_dir: Root directory of the color images.
        output_dir: Root directory that receives the split dataset.
        train_ratio: Fraction of pairs assigned to the train split.
        val_ratio: Fraction of pairs assigned to the val split. Whatever
            remains after train and val goes to test.
        seed: Optional seed for a reproducible shuffle. With the default
            ``None`` the module-level ``random`` generator is used, so a
            prior ``random.seed(...)`` call by the caller is still honored.

    Raises:
        ValueError: If a ratio is negative or the ratios sum to more than 1.
        AssertionError: If the two directories hold different numbers of
            files.
    """
    # Small tolerance so that ratios meant to sum to exactly 1 are not
    # rejected because of floating-point rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    # Sort so that index-based pairing does not depend on the order in which
    # the filesystem happens to list directory entries.
    grayscale_images = sorted(get_all_files(grayscale_dir))
    color_images = sorted(get_all_files(color_dir))

    assert len(grayscale_images) == len(color_images), "Mismatch in the number of grayscale and color images"

    indices = list(range(len(grayscale_images)))
    if seed is None:
        random.shuffle(indices)
    else:
        # Private generator: reproducible without touching global RNG state.
        random.Random(seed).shuffle(indices)

    # Cut points into the shuffled index list; int() truncates, so any
    # rounding remainder ends up in the later splits.
    train_end = int(train_ratio * len(indices))
    val_end = int((train_ratio + val_ratio) * len(indices))

    partitions = {
        'train': indices[:train_end],
        'val': indices[train_end:val_end],
        'test': indices[val_end:],
    }

    # Copy each partition into its split folder.
    for split_name, split_indices in partitions.items():
        move_files(split_indices, grayscale_images, color_images, output_dir, split_name)

In [7]:
# Seed the shuffle so the split is reproducible. Files are copied (not moved)
# and name clashes get a numeric suffix, so an unseeded re-run of this cell
# could place the same image in a different split than the previous run did.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

create_directory_structure('newdata_dual_data')
split_data('newdata_preprocessed_50x50', 'newdata_landmark_preprocessed', 'newdata_dual_data')


Moving train files: 100%|██████████████| 133857/133857 [02:55<00:00, 764.86it/s]
Moving val files: 100%|██████████████████| 28684/28684 [00:34<00:00, 824.30it/s]
Moving test files: 100%|█████████████████| 28684/28684 [00:34<00:00, 820.01it/s]
